From 3cc9b4bcc788798139a4fee497c6ef7ae603a9fd Mon Sep 17 00:00:00 2001 From: Guy David Date: Wed, 20 Nov 2019 10:07:04 +0000 Subject: [PATCH 001/492] Missing activation bind of built-in Mul operation --- .../delegates/gpu/common/model_builder.cc | 47 ++++++++++++------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index f9bce0b5542..1bf9810061e 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1368,8 +1368,9 @@ class MulOperationParser : public TFLiteOperationParser { if (tflite_node->inputs->size != 2) { return UnimplementedError("MUL requires two input tensors."); } - // TODO(eignasheva): Add params check. - return OkStatus(); + TfLiteMulParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + return IsActivationSupported(tf_options->activation); } Status Parse(const TfLiteNode* tflite_node, @@ -1392,6 +1393,8 @@ class MulOperationParser : public TFLiteOperationParser { const bool runtime_tensor0 = !constant_tensor0; const bool runtime_tensor1 = !constant_tensor1; + Node* node = graph->NewNode(); + // Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st // input and the "smaller" input tensor ("mask") must be bound to 2nd input. if (runtime_tensor0 && runtime_tensor1) { @@ -1406,27 +1409,36 @@ class MulOperationParser : public TFLiteOperationParser { input_tensor0 = 1; input_tensor1 = 0; } - return ParseApplyMask(input_tensor0, input_tensor1, graph, reader); + RETURN_IF_ERROR(ParseApplyMask(input_tensor0, input_tensor1, node, graph, reader)); + } + else + { + // Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st + // input and the constant input tensor must be bound to 2nd input. + int runtime_tensor = 0; + int constant_tensor = 1; + TfLiteIntArray* constant_dims = input1->dims; + if (constant_tensor0 && runtime_tensor1) { + runtime_tensor = 1; + constant_tensor = 0; + constant_dims = input0->dims; + } + RETURN_IF_ERROR(ParseMultiplyScalar(runtime_tensor, constant_tensor, + constant_dims, node, graph, reader)); } - // Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st - // input and the constant input tensor must be bound to 2nd input. 
- int runtime_tensor = 0; - int constant_tensor = 1; - TfLiteIntArray* constant_dims = input1->dims; - if (constant_tensor0 && runtime_tensor1) { - runtime_tensor = 1; - constant_tensor = 0; - constant_dims = input0->dims; + const auto* tf_options = + reinterpret_cast(tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing tflite params"); } - return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims, - graph, reader); + return MaybeFuseActivationToTheSingleOutput( + tf_options->activation, graph, node); } private: Status ParseApplyMask(int input_tensor0, int input_tensor1, - GraphFloat32* graph, ObjectReader* reader) { - Node* node = graph->NewNode(); + Node* node, GraphFloat32* graph, ObjectReader* reader) { node->operation.type = ToString(OperationType::APPLY_MASK); RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); @@ -1435,8 +1447,7 @@ class MulOperationParser : public TFLiteOperationParser { Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor, const TfLiteIntArray* constant_dims, - GraphFloat32* graph, ObjectReader* reader) { - Node* node = graph->NewNode(); + Node* node, GraphFloat32* graph, ObjectReader* reader) { node->operation.type = ToString(OperationType::MULTIPLY_SCALAR); RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor)); MultiplyScalarAttributes attr; From 3f950776c751a18ed3cee26c2bfa73264ec5dfb9 Mon Sep 17 00:00:00 2001 From: Guy David Date: Tue, 3 Dec 2019 22:08:17 +0200 Subject: [PATCH 002/492] Added clarity to error messages --- .../delegates/gpu/common/model_builder.cc | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 1bf9810061e..2f08cb077b1 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -710,7 +710,7 @@ class AddOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLiteAddParams"); } return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node); @@ -787,7 +787,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLiteConcatenationParams"); } RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node)); @@ -870,7 +870,7 @@ class Conv2DOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLiteConvParams"); } attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); attr.dilations = HW(tf_options->dilation_height_factor, @@ -1294,7 +1294,7 @@ class LSTMOperationParser : public TFLiteOperationParser { const auto* params = reinterpret_cast(tflite_node->builtin_data); if (!params) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLiteLSTMParams"); } RETURN_IF_ERROR(CheckParameters(params)); @@ -1410,9 +1410,7 @@ class MulOperationParser : public TFLiteOperationParser { input_tensor1 = 0; } 
RETURN_IF_ERROR(ParseApplyMask(input_tensor0, input_tensor1, node, graph, reader)); - } - else - { + } else { // Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st // input and the constant input tensor must be bound to 2nd input. int runtime_tensor = 0; @@ -1430,7 +1428,7 @@ class MulOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLiteMulParams"); } return MaybeFuseActivationToTheSingleOutput( tf_options->activation, graph, node); @@ -1593,7 +1591,7 @@ class Pooling2DOperationParser : public TFLiteOperationParser { reinterpret_cast(tflite_node->builtin_data); } if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLitePoolParams"); } std::vector max_tensor_id{0}; @@ -1711,7 +1709,7 @@ class ResizeBilinearOperationParser : public TFLiteOperationParser { reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLiteResizeBilinearParams"); } Upsample2DAttributes attr; attr.align_corners = tf_options->align_corners; @@ -1751,7 +1749,7 @@ class SoftmaxOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLiteSoftmaxParams"); } if (tf_options->beta != 1) { // there is multiply by scalar operation fused in softmax. Make a layer @@ -1882,7 +1880,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser { tflite_node->builtin_data); auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLiteStridedSliceParams"); } RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); @@ -2055,7 +2053,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite options."); + return InternalError("Missing TfLiteTransposeConvParams"); } ConvolutionTransposedAttributes attr; attr.stride = tf_options @@ -2137,7 +2135,7 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->custom_initial_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return InternalError("Missing TfLitePoolParams (Unpooling)"); } attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width); attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); From 2349aca08b0e4de14c28b8f4fe501c32c3b8c2de Mon Sep 17 00:00:00 2001 From: Guy David Date: Tue, 3 Dec 2019 22:26:24 +0200 Subject: [PATCH 003/492] Create a new node just when it's required --- .../delegates/gpu/common/model_builder.cc | 51 ++++++++++--------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 2f08cb077b1..353c70c276e 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1393,7 +1393,11 @@ class MulOperationParser : public TFLiteOperationParser { const bool runtime_tensor0 = !constant_tensor0; const bool runtime_tensor1 
= !constant_tensor1; - Node* node = graph->NewNode(); + const auto* tf_options = + reinterpret_cast(tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing TfLiteMulParams"); + } // Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st // input and the "smaller" input tensor ("mask") must be bound to 2nd input. @@ -1409,43 +1413,40 @@ class MulOperationParser : public TFLiteOperationParser { input_tensor0 = 1; input_tensor1 = 0; } - RETURN_IF_ERROR(ParseApplyMask(input_tensor0, input_tensor1, node, graph, reader)); - } else { - // Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st - // input and the constant input tensor must be bound to 2nd input. - int runtime_tensor = 0; - int constant_tensor = 1; - TfLiteIntArray* constant_dims = input1->dims; - if (constant_tensor0 && runtime_tensor1) { - runtime_tensor = 1; - constant_tensor = 0; - constant_dims = input0->dims; - } - RETURN_IF_ERROR(ParseMultiplyScalar(runtime_tensor, constant_tensor, - constant_dims, node, graph, reader)); + return ParseApplyMask( + input_tensor0, input_tensor1, tf_options, graph, reader); } - - const auto* tf_options = - reinterpret_cast(tflite_node->builtin_data); - if (!tf_options) { - return InternalError("Missing TfLiteMulParams"); + + // Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st + // input and the constant input tensor must be bound to 2nd input. + int runtime_tensor = 0; + int constant_tensor = 1; + TfLiteIntArray* constant_dims = input1->dims; + if (constant_tensor0 && runtime_tensor1) { + runtime_tensor = 1; + constant_tensor = 0; + constant_dims = input0->dims; } - return MaybeFuseActivationToTheSingleOutput( - tf_options->activation, graph, node); + return ParseMultiplyScalar(runtime_tensor, constant_tensor, + constant_dims, tf_options, graph, reader); } private: Status ParseApplyMask(int input_tensor0, int input_tensor1, - Node* node, GraphFloat32* graph, ObjectReader* reader) { + const TfLiteMulParams* params, GraphFloat32* graph, ObjectReader* reader) { + Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::APPLY_MASK); RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); + RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput( + params->activation, graph, node)); return reader->AddOutputs(node); } Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor, const TfLiteIntArray* constant_dims, - Node* node, GraphFloat32* graph, ObjectReader* reader) { + const TfLiteMulParams* params, GraphFloat32* graph, ObjectReader* reader) { + Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::MULTIPLY_SCALAR); RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor)); MultiplyScalarAttributes attr; @@ -1459,6 +1460,8 @@ class MulOperationParser : public TFLiteOperationParser { attr.param = std::move(tensor); } node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput( + params->activation, graph, node)); return reader->AddOutputs(node); } }; From 767fb451c40796c5971293af2b15d6037476e07e Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Thu, 19 Dec 2019 01:51:53 +0000 Subject: [PATCH 004/492] Fix typo in TFLiteConverter.from_concrete_function error message --- tensorflow/lite/python/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 83e97f156eb..b4bb9630ec2 100644 --- 
a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -344,7 +344,7 @@ class TFLiteConverterV2(TFLiteConverterBase): message = "This function takes in a list of ConcreteFunction." if isinstance(func, _def_function.Function): message += (" To get the ConcreteFunction from a Function," - " call from_concrete_function.") + " call get_concrete_function.") raise ValueError(message) return cls(funcs) From 359f22eef2e6c6e0adc804f88b1a94b91f7e2372 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Fri, 10 Jan 2020 01:25:04 +0100 Subject: [PATCH 005/492] Fix internal unittest --- tensorflow/lite/python/lite_v2_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index f4a6a4e6d19..e8048be2865 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -115,7 +115,7 @@ class FromConcreteFunctionTest(TestModels): root = self._getSimpleVariableModel() with self.assertRaises(ValueError) as error: _ = lite.TFLiteConverterV2.from_concrete_functions([root.f]) - self.assertIn('call from_concrete_function', str(error.exception)) + self.assertIn('call get_concrete_function', str(error.exception)) @parameterized.named_parameters( ('EnableMlirConverter', True), # enable mlir From 9bbc41c55433cbd7749ad50a2882fe07c2babf63 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Wed, 11 Sep 2019 12:06:23 +0100 Subject: [PATCH 006/492] 16-bit support for reference kernels: MAX/MIN element-wise reference operators PACK UNPACK --- tensorflow/lite/kernels/maximum_minimum.cc | 3 + .../lite/kernels/maximum_minimum_test.cc | 11 + tensorflow/lite/kernels/pack.cc | 4 + tensorflow/lite/kernels/pack_test.cc | 75 ++--- tensorflow/lite/kernels/unpack.cc | 6 +- tensorflow/lite/kernels/unpack_test.cc | 262 +++++------------- 6 files changed, 110 insertions(+), 251 deletions(-) diff --git a/tensorflow/lite/kernels/maximum_minimum.cc b/tensorflow/lite/kernels/maximum_minimum.cc index c51d7d07aff..29ac306311b 100644 --- a/tensorflow/lite/kernels/maximum_minimum.cc +++ b/tensorflow/lite/kernels/maximum_minimum.cc @@ -119,6 +119,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TFLiteOperation(context, node, op_context); break; + case kTfLiteInt16: + TFLiteOperation(context, node, op_context); + break; default: context->ReportError(context, "Type %d is currently not supported by Maximum.", diff --git a/tensorflow/lite/kernels/maximum_minimum_test.cc b/tensorflow/lite/kernels/maximum_minimum_test.cc index 669135bcf51..b421dd3b3ea 100644 --- a/tensorflow/lite/kernels/maximum_minimum_test.cc +++ b/tensorflow/lite/kernels/maximum_minimum_test.cc @@ -123,6 +123,17 @@ TEST(MaxMinOpTest, Int8Test) { data1, data2, {0, 0, 1, 11, 2, 1}); } +TEST(MaxMinOpTest, Int16Test) { + std::initializer_list data1 = {-32768, 0, 2, 11, 2, 23}; + std::initializer_list data2 = {0, 0, 1, 32767, 123, 1}; + TestModel(BuiltinOperator_MAXIMUM, {TensorType_INT16, {3, 1, 2}}, + {TensorType_INT16, {3, 1, 2}}, {TensorType_INT16, {3, 1, 2}}, + data1, data2, {0, 0, 2, 32767, 123, 23}); + TestModel(BuiltinOperator_MINIMUM, {TensorType_INT16, {3, 1, 2}}, + {TensorType_INT16, {3, 1, 2}}, {TensorType_INT16, {3, 1, 2}}, + data1, data2, {-32768, 0, 1, 11, 2, 1}); +} + TEST(MaximumOpTest, FloatWithBroadcastTest) { std::initializer_list data1 = {1.0, 0.0, -1.0, -2.0, -1.44, 11.0}; std::initializer_list data2 = {0.5, 2.0}; diff --git a/tensorflow/lite/kernels/pack.cc 
b/tensorflow/lite/kernels/pack.cc index 8e30dce8009..ebc3381dae8 100644 --- a/tensorflow/lite/kernels/pack.cc +++ b/tensorflow/lite/kernels/pack.cc @@ -116,6 +116,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return PackImpl(context, node, output, data->values_count, data->axis); } + case kTfLiteInt16: { + return PackImpl(context, node, output, data->values_count, + data->axis); + } case kTfLiteInt32: { return PackImpl(context, node, output, data->values_count, data->axis); diff --git a/tensorflow/lite/kernels/pack_test.cc b/tensorflow/lite/kernels/pack_test.cc index 9b181e465cf..7d18bb6c34f 100644 --- a/tensorflow/lite/kernels/pack_test.cc +++ b/tensorflow/lite/kernels/pack_test.cc @@ -191,9 +191,22 @@ TEST(PackOpTest, Int64MultilDimensions) { 4LL, 5LL, 6LL, 10LL, 11LL, 12LL})); } -// uint8 -TEST(PackOpTest, Uint8ThreeInputs) { - PackOpModel model({TensorType_UINT8, {2}}, 0, 3); +template +struct PackOpTestInt : public ::testing::Test { + using TypeToTest = InputType; + TensorType TENSOR_TYPE = + std::is_same::value + ? TensorType_INT16 + : (std::is_same::value ? TensorType_UINT8 + : TensorType_INT8); +}; + +using TestTypes = testing::Types; +TYPED_TEST_CASE(PackOpTestInt, TestTypes); + +TYPED_TEST(PackOpTestInt, ThreeInputs) { + PackOpModel model( + {TestFixture::TENSOR_TYPE, {2}}, 0, 3); model.SetInput(0, {1, 4}); model.SetInput(1, {2, 5}); model.SetInput(2, {3, 6}); @@ -202,8 +215,9 @@ TEST(PackOpTest, Uint8ThreeInputs) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); } -TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) { - PackOpModel model({TensorType_UINT8, {2}}, 1, 3); +TYPED_TEST(PackOpTestInt, ThreeInputsDifferentAxis) { + PackOpModel model( + {TestFixture::TENSOR_TYPE, {2}}, 1, 3); model.SetInput(0, {1, 4}); model.SetInput(1, {2, 5}); model.SetInput(2, {3, 6}); @@ -212,8 +226,9 @@ TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(PackOpTest, Uint8ThreeInputsNegativeAxis) { - PackOpModel model({TensorType_UINT8, {2}}, -1, 3); +TYPED_TEST(PackOpTestInt, ThreeInputsNegativeAxis) { + PackOpModel model( + {TestFixture::TENSOR_TYPE, {2}}, -1, 3); model.SetInput(0, {1, 4}); model.SetInput(1, {2, 5}); model.SetInput(2, {3, 6}); @@ -222,49 +237,9 @@ TEST(PackOpTest, Uint8ThreeInputsNegativeAxis) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(PackOpTest, Uint8MultilDimensions) { - PackOpModel model({TensorType_UINT8, {2, 3}}, 1, 2); - model.SetInput(0, {1, 2, 3, 4, 5, 6}); - model.SetInput(1, {7, 8, 9, 10, 11, 12}); - model.Invoke(); - EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3)); - EXPECT_THAT(model.GetOutput(), - ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); -} - -// int8 -TEST(PackOpTest, Int8ThreeInputs) { - PackOpModel model({TensorType_INT8, {2}}, 0, 3); - model.SetInput(0, {1, 4}); - model.SetInput(1, {2, 5}); - model.SetInput(2, {3, 6}); - model.Invoke(); - EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2)); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); -} - -TEST(PackOpTest, Int8ThreeInputsDifferentAxis) { - PackOpModel model({TensorType_INT8, {2}}, 1, 3); - model.SetInput(0, {1, 4}); - model.SetInput(1, {2, 5}); - model.SetInput(2, {3, 6}); - model.Invoke(); - EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); -} - -TEST(PackOpTest, Int8ThreeInputsNegativeAxis) { - PackOpModel 
model({TensorType_INT8, {2}}, -1, 3); - model.SetInput(0, {1, 4}); - model.SetInput(1, {2, 5}); - model.SetInput(2, {3, 6}); - model.Invoke(); - EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); -} - -TEST(PackOpTest, Int8MultilDimensions) { - PackOpModel model({TensorType_INT8, {2, 3}}, 1, 2); +TYPED_TEST(PackOpTestInt, MultilDimensions) { + PackOpModel model( + {TestFixture::TENSOR_TYPE, {2, 3}}, 1, 2); model.SetInput(0, {1, 2, 3, 4, 5, 6}); model.SetInput(1, {7, 8, 9, 10, 11, 12}); model.Invoke(); diff --git a/tensorflow/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc index 8e66432e9cd..ebbc5e6472c 100644 --- a/tensorflow/lite/kernels/unpack.cc +++ b/tensorflow/lite/kernels/unpack.cc @@ -44,7 +44,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input)); if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 && input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 && - input->type != kTfLiteBool) { + input->type != kTfLiteInt16 && input->type != kTfLiteBool) { context->ReportError(context, "Type '%s' is not supported by unpack.", TfLiteTypeGetName(input->type)); return kTfLiteError; @@ -117,6 +117,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { UnpackImpl(context, node, input, data->num, data->axis); break; } + case kTfLiteInt16: { + UnpackImpl(context, node, input, data->num, data->axis); + break; + } default: { context->ReportError(context, "Type '%s' is not supported by unpack.", TfLiteTypeGetName(input->type)); diff --git a/tensorflow/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc index 88eb706e969..40fb0597893 100644 --- a/tensorflow/lite/kernels/unpack_test.cc +++ b/tensorflow/lite/kernels/unpack_test.cc @@ -83,213 +83,75 @@ void Check(int axis, const std::initializer_list& input_shape, EXPECT_THAT(m.GetOutputDatas(), ElementsAreArray(exp_output_data)); } -// float32 tests. -TEST(UnpackOpTest, FloatThreeOutputs) { - Check(/*axis=*/0, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{2}, {2}, {2}}, - /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}}); +template +struct UnpackOpTest : public ::testing::Test { + using TypeToTest = InputType; + TensorType TENSOR_TYPE = + std::is_same::value + ? TensorType_INT16 + : std::is_same::value + ? TensorType_UINT8 + : std::is_same::value + ? TensorType_INT8 + : std::is_same::value + ? 
TensorType_INT32 + : TensorType_FLOAT32; +}; + +using TestTypes = testing::Types; +TYPED_TEST_CASE(UnpackOpTest, TestTypes); + +TYPED_TEST(UnpackOpTest, ThreeOutputs) { + Check( + /*axis=*/0, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*exp_output_shape=*/{{2}, {2}, {2}}, + /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, TestFixture::TENSOR_TYPE); } -TEST(UnpackOpTest, FloatThreeOutputsAxisOne) { - Check(/*axis=*/1, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{3}, {3}}, - /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}); +TYPED_TEST(UnpackOpTest, ThreeOutputsAxisOne) { + Check( + /*axis=*/1, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*exp_output_shape=*/{{3}, {3}}, + /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}, TestFixture::TENSOR_TYPE); } -TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisOne) { - Check(/*axis=*/-1, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{3}, {3}}, - /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}); +TYPED_TEST(UnpackOpTest, ThreeOutputsNegativeAxisOne) { + Check( + /*axis=*/-1, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*exp_output_shape=*/{{3}, {3}}, + /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}, TestFixture::TENSOR_TYPE); } -TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisTwo) { - Check(/*axis=*/-2, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{2}, {2}, {2}}, - /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}}); +TYPED_TEST(UnpackOpTest, OneOutput) { + Check( + /*axis=*/0, /*input_shape=*/{1, 6}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*exp_output_shape=*/{{6}}, + /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}}, TestFixture::TENSOR_TYPE); } -TEST(UnpackOpTest, FloatOneOutput) { - Check(/*axis=*/0, /*input_shape=*/{1, 6}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{6}}, - /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}}); +TYPED_TEST(UnpackOpTest, ThreeDimensionsOutputs) { + Check( + /*axis=*/2, /*input_shape=*/{2, 2, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8}, + /*exp_output_shape=*/{{2, 2}, {2, 2}}, + /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}}, + TestFixture::TENSOR_TYPE); } -TEST(UnpackOpTest, FloatThreeDimensionsOutputs) { - Check(/*axis=*/2, /*input_shape=*/{2, 2, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8}, - /*exp_output_shape=*/{{2, 2}, {2, 2}}, - /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}}); -} - -TEST(UnpackOpTest, FloatVectorToScalar) { - Check(/*axis=*/0, /*input_shape=*/{5}, - /*input_data=*/{1, 2, 3, 4, 5}, - /*exp_output_shape=*/{{}, {}, {}, {}, {}}, - /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}}); -} - -// int32 tests. 
-TEST(UnpackOpTest, IntThreeOutputs) { - Check(/*axis=*/0, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{2}, {2}, {2}}, - /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, - /*type=*/TensorType_INT32); -} - -TEST(UnpackOpTest, IntThreeOutputsAxisOne) { - Check(/*axis=*/1, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{3}, {3}}, - /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}, - /*type=*/TensorType_INT32); -} - -TEST(UnpackOpTest, IntOneOutput) { - Check(/*axis=*/0, /*input_shape=*/{1, 6}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{6}}, - /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}}, - /*type=*/TensorType_INT32); -} - -TEST(UnpackOpTest, IntThreeDimensionsOutputs) { - Check(/*axis=*/2, /*input_shape=*/{2, 2, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8}, - /*exp_output_shape=*/{{2, 2}, {2, 2}}, - /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}}, - /*type=*/TensorType_INT32); -} - -TEST(UnpackOpTest, IntVectorToScalar) { - Check(/*axis=*/0, /*input_shape=*/{5}, - /*input_data=*/{1, 2, 3, 4, 5}, - /*exp_output_shape=*/{{}, {}, {}, {}, {}}, - /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}}, - /*type=*/TensorType_INT32); -} - -// uint8 tests. -TEST(UnpackOpTest, Uint8ThreeOutputs) { - Check(/*axis=*/0, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{2}, {2}, {2}}, - /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, - /*type=*/TensorType_UINT8); -} - -TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) { - Check(/*axis=*/1, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{3}, {3}}, - /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}, - /*type=*/TensorType_UINT8); -} - -TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) { - Check(/*axis=*/-1, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{3}, {3}}, - /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}, - /*type=*/TensorType_UINT8); -} - -TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) { - Check(/*axis=*/-2, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{2}, {2}, {2}}, - /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, - /*type=*/TensorType_UINT8); -} - -TEST(UnpackOpTest, Uint8OneOutput) { - Check(/*axis=*/0, /*input_shape=*/{1, 6}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{6}}, - /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}}, - /*type=*/TensorType_UINT8); -} - -TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) { - Check(/*axis=*/2, /*input_shape=*/{2, 2, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8}, - /*exp_output_shape=*/{{2, 2}, {2, 2}}, - /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}}, - /*type=*/TensorType_UINT8); -} - -TEST(UnpackOpTest, Uint8VectorToScalar) { - Check(/*axis=*/0, /*input_shape=*/{5}, - /*input_data=*/{1, 2, 3, 4, 5}, - /*exp_output_shape=*/{{}, {}, {}, {}, {}}, - /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}}, - /*type=*/TensorType_UINT8); -} - -// int8 tests. 
-TEST(UnpackOpTest, Int8ThreeOutputs) { - Check(/*axis=*/0, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{2}, {2}, {2}}, - /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, - /*type=*/TensorType_INT8); -} - -TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) { - Check(/*axis=*/1, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{3}, {3}}, - /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}, - /*type=*/TensorType_INT8); -} - -TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) { - Check(/*axis=*/-1, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{3}, {3}}, - /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}}, - /*type=*/TensorType_INT8); -} - -TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) { - Check(/*axis=*/-2, /*input_shape=*/{3, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{2}, {2}, {2}}, - /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, - /*type=*/TensorType_INT8); -} - -TEST(UnpackOpTest, Int8OneOutput) { - Check(/*axis=*/0, /*input_shape=*/{1, 6}, - /*input_data=*/{1, 2, 3, 4, 5, 6}, - /*exp_output_shape=*/{{6}}, - /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}}, - /*type=*/TensorType_INT8); -} - -TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) { - Check(/*axis=*/2, /*input_shape=*/{2, 2, 2}, - /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8}, - /*exp_output_shape=*/{{2, 2}, {2, 2}}, - /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}}, - /*type=*/TensorType_INT8); -} - -TEST(UnpackOpTest, Int8VectorToScalar) { - Check(/*axis=*/0, /*input_shape=*/{5}, - /*input_data=*/{1, 2, 3, 4, 5}, - /*exp_output_shape=*/{{}, {}, {}, {}, {}}, - /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}}, - /*type=*/TensorType_INT8); +TYPED_TEST(UnpackOpTest, VectorToScalar) { + Check( + /*axis=*/0, /*input_shape=*/{5}, + /*input_data=*/{1, 2, 3, 4, 5}, + /*exp_output_shape=*/{{}, {}, {}, {}, {}}, + /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}}, TestFixture::TENSOR_TYPE); } // bool tests. 
-TEST(UnpackOpTest, BoolThreeOutputs) { +TEST(UnpackOpTestBool, BoolThreeOutputs) { Check( /*axis=*/0, /*input_shape=*/{3, 2}, /*input_data=*/{true, false, true, false, true, false}, @@ -298,7 +160,7 @@ TEST(UnpackOpTest, BoolThreeOutputs) { /*type=*/TensorType_BOOL); } -TEST(UnpackOpTest, BoolThreeOutputsAxisOne) { +TEST(UnpackOpTestBool, BoolThreeOutputsAxisOne) { Check( /*axis=*/1, /*input_shape=*/{3, 2}, /*input_data=*/{true, false, true, false, true, false}, @@ -307,7 +169,7 @@ TEST(UnpackOpTest, BoolThreeOutputsAxisOne) { /*type=*/TensorType_BOOL); } -TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisOne) { +TEST(UnpackOpTestBool, BoolThreeOutputsNegativeAxisOne) { Check( /*axis=*/-1, /*input_shape=*/{3, 2}, /*input_data=*/{true, false, true, false, true, false}, @@ -316,7 +178,7 @@ TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisOne) { /*type=*/TensorType_BOOL); } -TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisTwo) { +TEST(UnpackOpTestBool, BoolThreeOutputsNegativeAxisTwo) { Check( /*axis=*/-2, /*input_shape=*/{3, 2}, /*input_data=*/{true, false, true, false, true, false}, @@ -325,7 +187,7 @@ TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisTwo) { /*type=*/TensorType_BOOL); } -TEST(UnpackOpTest, BoolOneOutput) { +TEST(UnpackOpTestBool, BoolOneOutput) { Check( /*axis=*/0, /*input_shape=*/{1, 6}, /*input_data=*/{true, false, true, false, true, false}, @@ -334,7 +196,7 @@ TEST(UnpackOpTest, BoolOneOutput) { /*type=*/TensorType_BOOL); } -TEST(UnpackOpTest, BoolThreeDimensionsOutputs) { +TEST(UnpackOpTestBool, BoolThreeDimensionsOutputs) { Check( /*axis=*/2, /*input_shape=*/{2, 2, 2}, /*input_data=*/{true, false, true, false, true, false, true, false}, @@ -344,7 +206,7 @@ TEST(UnpackOpTest, BoolThreeDimensionsOutputs) { /*type=*/TensorType_BOOL); } -TEST(UnpackOpTest, BoolVectorToScalar) { +TEST(UnpackOpTestBool, BoolVectorToScalar) { Check(/*axis=*/0, /*input_shape=*/{5}, /*input_data=*/{true, false, true, false, true}, /*exp_output_shape=*/{{}, {}, {}, {}, {}}, From 00a4b961985f5fc258dd055a20e59163db3bd386 Mon Sep 17 00:00:00 2001 From: JaehunRyu Date: Sun, 9 Feb 2020 23:09:04 +0900 Subject: [PATCH 007/492] Fixed FileWriter docs and raise msg --- tensorflow/python/summary/writer/writer.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index e171a9ed2c3..87d1b625449 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -291,7 +291,7 @@ class FileWriter(SummaryToEventTransformer): When constructed with a `tf.compat.v1.Session` parameter, a `FileWriter` instead forms a compatibility layer over new graph-based summaries - (`tf.contrib.summary`) to facilitate the use of new summary writing with + to facilitate the use of new summary writing with pre-existing code that expects a `FileWriter` instance. This class is not thread-safe. @@ -328,15 +328,8 @@ class FileWriter(SummaryToEventTransformer): ``` The `session` argument to the constructor makes the returned `FileWriter` a - compatibility layer over new graph-based summaries (`tf.contrib.summary`). - Crucially, this means the underlying writer resource and events file will - be shared with any other `FileWriter` using the same `session` and `logdir`, - and with any `tf.contrib.summary.SummaryWriter` in this session using the - the same shared resource name (which by default scoped to the logdir). 
If - no such resource exists, one will be created using the remaining arguments - to this constructor, but if one already exists those arguments are ignored. - In either case, ops will be added to `session.graph` to control the - underlying file writer resource. See `tf.contrib.summary` for more details. + compatibility layer over new graph-based summaries. + Args: logdir: A string. Directory where event file will be written. @@ -354,13 +347,13 @@ class FileWriter(SummaryToEventTransformer): @compatibility(eager) `FileWriter` is not compatible with eager execution. To write TensorBoard - summaries under eager execution, use `tf.contrib.summary` instead. + summaries under eager execution, use `tf.compat.v1.disable_eager_execution()` before the code. @end_compatibility """ if context.executing_eagerly(): raise RuntimeError( "tf.summary.FileWriter is not compatible with eager execution. " - "Use tf.contrib.summary instead.") + "Use `tf.compat.v1.disable_eager_execution()` before the code") if session is not None: event_writer = EventFileWriterV2( session, logdir, max_queue, flush_secs, filename_suffix) From 4d9e067fae2971ce6f8a3e90ab9a2e608cc60415 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Sun, 9 Feb 2020 07:44:19 -0800 Subject: [PATCH 008/492] Update writer.py --- tensorflow/python/summary/writer/writer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index 87d1b625449..cce21f86b28 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -328,8 +328,11 @@ class FileWriter(SummaryToEventTransformer): ``` The `session` argument to the constructor makes the returned `FileWriter` a - compatibility layer over new graph-based summaries. - + compatibility layer over new graph-based summaries (`tf.summary`). + Crucially, this means the underlying writer resource and events file will + be shared with any other `FileWriter` using the same `session` and `logdir`. + In either case, ops will be added to `session.graph` to control the + underlying file writer resource. Args: logdir: A string. Directory where event file will be written. @@ -346,14 +349,14 @@ class FileWriter(SummaryToEventTransformer): RuntimeError: If called with eager execution enabled. @compatibility(eager) - `FileWriter` is not compatible with eager execution. To write TensorBoard - summaries under eager execution, use `tf.compat.v1.disable_eager_execution()` before the code. + `v1.summary.FileWriter` is not compatible with eager execution. To write TensorBoard + summaries under eager execution, use `tf.summary.create_file_writer`. @end_compatibility """ if context.executing_eagerly(): raise RuntimeError( - "tf.summary.FileWriter is not compatible with eager execution. " - "Use `tf.compat.v1.disable_eager_execution()` before the code") + "v1.summary.FileWriter is not compatible with eager execution. 
" + "Use `tf.summary.create_file_writer`") if session is not None: event_writer = EventFileWriterV2( session, logdir, max_queue, flush_secs, filename_suffix) From 2c3bd0f697ee828a097f6147c91f97fbad277f70 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Sun, 9 Feb 2020 07:46:17 -0800 Subject: [PATCH 009/492] Update writer.py --- tensorflow/python/summary/writer/writer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index cce21f86b28..1dffd2e5b75 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -350,13 +350,14 @@ class FileWriter(SummaryToEventTransformer): @compatibility(eager) `v1.summary.FileWriter` is not compatible with eager execution. To write TensorBoard - summaries under eager execution, use `tf.summary.create_file_writer`. + summaries under eager execution, use `tf.summary.create_file_writer` or a + `with v1.Graph.as_default():` context. @end_compatibility """ if context.executing_eagerly(): raise RuntimeError( "v1.summary.FileWriter is not compatible with eager execution. " - "Use `tf.summary.create_file_writer`") + "Use `tf.summary.create_file_writer`, or a `with v1.Graph.as_default():` context") if session is not None: event_writer = EventFileWriterV2( session, logdir, max_queue, flush_secs, filename_suffix) From d7e98863ddd36f0de30125e89991a1faaf80e266 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Sun, 9 Feb 2020 08:50:11 -0800 Subject: [PATCH 010/492] Update writer.py --- tensorflow/python/summary/writer/writer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index 1dffd2e5b75..cdced098db8 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -329,9 +329,9 @@ class FileWriter(SummaryToEventTransformer): The `session` argument to the constructor makes the returned `FileWriter` a compatibility layer over new graph-based summaries (`tf.summary`). - Crucially, this means the underlying writer resource and events file will + Crucially, this means the underlying writer resource and events file will be shared with any other `FileWriter` using the same `session` and `logdir`. - In either case, ops will be added to `session.graph` to control the + In either case, ops will be added to `session.graph` to control the underlying file writer resource. Args: @@ -351,13 +351,13 @@ class FileWriter(SummaryToEventTransformer): @compatibility(eager) `v1.summary.FileWriter` is not compatible with eager execution. To write TensorBoard summaries under eager execution, use `tf.summary.create_file_writer` or a - `with v1.Graph.as_default():` context. + `with v1.Graph().as_default():` context. @end_compatibility """ if context.executing_eagerly(): raise RuntimeError( "v1.summary.FileWriter is not compatible with eager execution. 
" - "Use `tf.summary.create_file_writer`, or a `with v1.Graph.as_default():` context") + "Use `tf.summary.create_file_writer`, or a `with v1.Graph().as_default():` context") if session is not None: event_writer = EventFileWriterV2( session, logdir, max_queue, flush_secs, filename_suffix) From 1022952737e2ae810b8b020b6d9f02cb5c55dee3 Mon Sep 17 00:00:00 2001 From: JaehunRyu Date: Tue, 11 Feb 2020 14:28:13 +0900 Subject: [PATCH 011/492] Update text length for pylint error --- configure | 2 +- tensorflow/python/summary/writer/writer.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/configure b/configure index 66b66ba54ed..7a299f3e274 100755 --- a/configure +++ b/configure @@ -4,7 +4,7 @@ set -e set -o pipefail if [ -z "$PYTHON_BIN_PATH" ]; then - PYTHON_BIN_PATH=$(which python || which python3 || true) + PYTHON_BIN_PATH=$( which python3 || true) fi # Set all env variables diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index cdced098db8..5b06820abd9 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -349,15 +349,17 @@ class FileWriter(SummaryToEventTransformer): RuntimeError: If called with eager execution enabled. @compatibility(eager) - `v1.summary.FileWriter` is not compatible with eager execution. To write TensorBoard - summaries under eager execution, use `tf.summary.create_file_writer` or a - `with v1.Graph().as_default():` context. + `v1.summary.FileWriter` is not compatible with eager execution. + To write TensorBoard summaries under eager execution, + use `tf.summary.create_file_writer` or + a `with v1.Graph().as_default():` context. @end_compatibility """ if context.executing_eagerly(): raise RuntimeError( "v1.summary.FileWriter is not compatible with eager execution. " - "Use `tf.summary.create_file_writer`, or a `with v1.Graph().as_default():` context") + "Use `tf.summary.create_file_writer`," + "or a `with v1.Graph().as_default():` context") if session is not None: event_writer = EventFileWriterV2( session, logdir, max_queue, flush_secs, filename_suffix) From d666b1bba534bd9e8c2f2cf29ee8de9e6d705495 Mon Sep 17 00:00:00 2001 From: Judd Date: Tue, 11 Feb 2020 18:32:03 +0800 Subject: [PATCH 012/492] Update wav_to_features.py Remove the leading comma in the generated float array. Withi this leading comma, the generated C code can't be compiled. --- tensorflow/examples/speech_commands/wav_to_features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/examples/speech_commands/wav_to_features.py b/tensorflow/examples/speech_commands/wav_to_features.py index 9ed02b7ab5d..972a37f49ba 100644 --- a/tensorflow/examples/speech_commands/wav_to_features.py +++ b/tensorflow/examples/speech_commands/wav_to_features.py @@ -117,7 +117,7 @@ def wav_to_features(sample_rate, clip_duration_ms, window_size_ms, for value in features.flatten(): if i == 0: f.write('\n ') - f.write(' ,%f' % value) + f.write('%f, ' % value) i = (i + 1) % 10 f.write('\n};\n') From 391b21e9c2bee4164f888c222978d35c666fedde Mon Sep 17 00:00:00 2001 From: "William D. 
Irons" Date: Tue, 11 Feb 2020 22:29:20 +0000 Subject: [PATCH 013/492] Fix gpu test when running the ./configure script Today if you run the ./configure script without setting the variable TF_NEED_CUDA, in the resulting .tf_configure.bazelrc you get: test:v1 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial test:v1 --build_tag_filters=-benchmark-test,-no_oss,-gpu test:v2 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial,-v1only test:v2 --build_tag_filters=-benchmark-test,-no_oss,-gpu,-v1only This is incorrect because -gpu means exclude the gpu test. It should be -no_gpu. Debugging the problem I found when the code was switched from using os.env to using environ_cp, that the method system_specific_test_config was never updated. So unless the environment variable TF_NEED_CUDA was set before running ./configure then answering yes for CUDA would not select the correct test filter for gpus. With this change the test filters are correct: test:v1 --test_tag_filters=-benchmark-test,-no_oss,-no_gpu,-oss_serial test:v1 --build_tag_filters=-benchmark-test,-no_oss,-no_gpu test:v2 --test_tag_filters=-benchmark-test,-no_oss,-no_gpu,-oss_serial,-v1only test:v2 --build_tag_filters=-benchmark-test,-no_oss,-no_gpu,-v1only --- configure.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/configure.py b/configure.py index 4cb68924db4..563c2db23ab 100644 --- a/configure.py +++ b/configure.py @@ -1155,7 +1155,7 @@ def set_trisycl_include_dir(environ_cp): write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) -def system_specific_test_config(env): +def system_specific_test_config(environ_cp): """Add default build and test flags required for TF tests to bazelrc.""" write_to_bazelrc('test --flaky_test_attempts=3') write_to_bazelrc('test --test_size_filters=small,medium') @@ -1171,14 +1171,14 @@ def system_specific_test_config(env): test_only_filters = ['-oss_serial'] if is_windows(): test_and_build_filters.append('-no_windows') - if env.get('TF_NEED_CUDA', None) == '1': + if environ_cp.get('TF_NEED_CUDA', None) == '1': test_and_build_filters += ['-no_windows_gpu', '-no_gpu'] else: test_and_build_filters.append('-gpu') elif is_macos(): test_and_build_filters += ['-gpu', '-nomac', '-no_mac'] elif is_linux(): - if env.get('TF_NEED_CUDA', None) == '1': + if environ_cp.get('TF_NEED_CUDA', None) == '1': test_and_build_filters.append('-no_gpu') write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') else: @@ -1523,7 +1523,7 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) - system_specific_test_config(os.environ) + system_specific_test_config(environ_cp) set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False) if environ_cp.get('TF_CONFIGURE_IOS') == '1': From cb2702fcdee38d1dceaad40bfa91797c5b58a131 Mon Sep 17 00:00:00 2001 From: Jonas Skog Date: Wed, 12 Feb 2020 17:20:43 +0100 Subject: [PATCH 014/492] Micro: Return error if tensor allocation fails. 
--- .../micro/examples/network_tester/network_tester_test.cc | 7 ++++++- tensorflow/lite/micro/micro_allocator.cc | 2 +- tensorflow/lite/micro/micro_interpreter.cc | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc b/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc index cebaae77486..5fa364a0a2e 100644 --- a/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc +++ b/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc @@ -70,7 +70,12 @@ TF_LITE_MICRO_TEST(TestInvoke) { tflite::MicroInterpreter interpreter(model, resolver, tensor_arena, TENSOR_ARENA_SIZE, error_reporter); - interpreter.AllocateTensors(); + + TfLiteStatus allocate_status = interpreter.AllocateTensors(); + if (allocate_status != kTfLiteOk) { + error_reporter->Report("alloc failed\n"); + } + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocate_status); TfLiteTensor* input = interpreter.input(0); memcpy(input->data.uint8, input_data, input->bytes); diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index 72a0e0c3e11..c4d26403553 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -410,7 +410,7 @@ TfLiteStatus MicroAllocator::AllocateNodeAndRegistrations( status = GetRegistrationFromOpCode(opcode, op_resolver, error_reporter_, &(output[i].registration)); if (status != kTfLiteOk) { - error_reporter_->Report("Failed to get registration from op code % d\n ", + error_reporter_->Report("Failed to get registration from op code %d\n ", opcode); return status; } diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index f6f8127f467..14aaa3e1e62 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -140,7 +140,7 @@ TfLiteStatus MicroInterpreter::Invoke() { // Ensure tensors are allocated before the interpreter is invoked to avoid // difficult to debug segfaults. if (!tensors_allocated_) { - AllocateTensors(); + TF_LITE_ENSURE_OK(&context_, AllocateTensors()); } // Init method is not yet implemented. From b06ea087e4596a714c282a91b7618c8fa4e84955 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Thu, 13 Feb 2020 10:35:29 +0000 Subject: [PATCH 015/492] Versioning for operators MIN/MAX and PACK/UNPACK, 16x8. 
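This bumps the reported operator versions so the 16-bit reference kernels added earlier get a distinct version. As a sketch, the rule added to GetBuiltinOperatorVersion in op_version.cc (excerpted from the hunk below; PACK is shown, MAXIMUM and MINIMUM are analogous, and UNPACK returns 4 for the same int16 case):

    case BuiltinOperator_PACK:
      if (op_sig.input_types.at(0) == TensorType_INT8) {
        return 2;  // existing int8 path
      }
      if (op_sig.input_types.at(0) == TensorType_INT16 &&
          op_sig.output_types.at(0) == TensorType_INT16) {
        return 3;  // new 16x8 (int16 activations) path
      }
      return 1;  // float and remaining types keep version 1

register.cc raises the corresponding max_version values (MAXIMUM, MINIMUM and PACK to 3, UNPACK to 4), and toco's op_version.cc maps the new versions to kPendingReleaseOpVersion.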
--- tensorflow/lite/kernels/register.cc | 8 ++++---- tensorflow/lite/toco/tflite/op_version.cc | 4 ++++ .../lite/tools/versioning/op_version.cc | 20 ++++++++++++++++--- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 4435008b653..8c1d42c3317 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -172,10 +172,10 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_PRELU, Register_PRELU()); AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(), /* min_version */ 1, - /* max_version */ 2); + /* max_version */ 3); AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM(), /* min_version */ 1, - /* max_version */ 2); + /* max_version */ 3); AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX(), /* min_version */ 1, /* max_version */ 2); @@ -240,14 +240,14 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); AddBuiltin(BuiltinOperator_PACK, Register_PACK(), /* min_version */ 1, - /* max_version */ 2); + /* max_version */ 3); AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT()); AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(), /* min_version */ 1, - /* max_version */ 3); + /* max_version */ 4); AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV(), /* min_version */ 1, /* max_version */ 2); diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index 2e27c1d8a0f..17a736eea7d 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -98,8 +98,10 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kMaxPool, 2}, "1.14.0"}, {{OperatorType::kMaximum, 1}, "1.14.0"}, {{OperatorType::kMaximum, 2}, "1.14.0"}, + {{OperatorType::kMaximum, 3}, kPendingReleaseOpVersion}, {{OperatorType::kMinimum, 1}, "1.14.0"}, {{OperatorType::kMinimum, 2}, "1.14.0"}, + {{OperatorType::kMinimum, 3}, kPendingReleaseOpVersion}, {{OperatorType::kMul, 1}, "1.5.0"}, {{OperatorType::kMul, 2}, "1.14.0"}, {{OperatorType::kMul, 3}, "1.15.0"}, @@ -160,6 +162,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kExpandDims, 1}, "1.10.0"}, {{OperatorType::kPack, 1}, "1.11.0"}, {{OperatorType::kPack, 2}, "1.14.0"}, + {{OperatorType::kPack, 3}, kPendingReleaseOpVersion}, {{OperatorType::kShape, 1}, "1.10.0"}, {{OperatorType::kSlice, 1}, "1.14.0"}, {{OperatorType::kSlice, 2}, "1.14.0"}, @@ -171,6 +174,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kUnpack, 1}, "1.11.0"}, {{OperatorType::kUnpack, 2}, "1.14.0"}, {{OperatorType::kUnpack, 3}, kPendingReleaseOpVersion}, + {{OperatorType::kUnpack, 4}, kPendingReleaseOpVersion}, {{OperatorType::kLeakyRelu, 1}, "1.13.1"}, {{OperatorType::kLogistic, 1}, "1.14.0"}, {{OperatorType::kLogistic, 2}, "1.14.0"}, diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index ef81d0169f5..9c73bc3b25f 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -227,6 +227,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { if (op_sig.input_types.at(0) == TensorType_BOOL) { return 3; } + if 
(op_sig.input_types.at(0) == TensorType_INT16 && + op_sig.output_types.at(0) == TensorType_INT16) { + return 4; + } return 1; case BuiltinOperator_DEQUANTIZE: @@ -273,6 +277,19 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + case BuiltinOperator_PACK: + if (op_sig.input_types.at(0) == TensorType_INT8) { + return 2; + } + + if (op_sig.input_types.at(0) == TensorType_INT16 && + op_sig.output_types.at(0) == TensorType_INT16) { + return 3; + } + return 1; + case BuiltinOperator_AVERAGE_POOL_2D: case BuiltinOperator_ADD: case BuiltinOperator_SPACE_TO_BATCH_ND: @@ -280,8 +297,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { case BuiltinOperator_BATCH_TO_SPACE_ND: case BuiltinOperator_CONCATENATION: case BuiltinOperator_MAX_POOL_2D: - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: case BuiltinOperator_PAD: case BuiltinOperator_PADV2: case BuiltinOperator_SOFTMAX: @@ -293,7 +308,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { case BuiltinOperator_RELU6: case BuiltinOperator_RESIZE_BILINEAR: case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: - case BuiltinOperator_PACK: case BuiltinOperator_TANH: case BuiltinOperator_LOGISTIC: case BuiltinOperator_LOG_SOFTMAX: From cbbd02a3c32b7e8c5622e737cb4ab2237ce8ea39 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Wed, 19 Feb 2020 17:37:09 +0000 Subject: [PATCH 016/492] Addressed reviewer comments. --- tensorflow/lite/tools/versioning/op_version.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index f2864ffcb1e..78428707b0f 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -293,6 +293,16 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { case BuiltinOperator_MAXIMUM: case BuiltinOperator_MINIMUM: + if (op_sig.input_types.at(0) == TensorType_INT8) { + return 2; + } + + if (op_sig.input_types.at(0) == TensorType_INT16 && + op_sig.output_types.at(0) == TensorType_INT16) { + return 3; + } + return 1; + case BuiltinOperator_PACK: if (op_sig.input_types.at(0) == TensorType_INT8) { return 2; From 94d4e8673e9f8e23794d8e1b9b16b97631c37a55 Mon Sep 17 00:00:00 2001 From: "Kwabena W. Agyeman" Date: Sun, 24 Nov 2019 16:01:23 -0800 Subject: [PATCH 017/492] Update debug log to support redirection --- .../experimental/micro/openmvcam/debug_log.cc | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tensorflow/lite/experimental/micro/openmvcam/debug_log.cc diff --git a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc new file mode 100644 index 00000000000..d0499ab3612 --- /dev/null +++ b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include "tensorflow/lite/experimental/micro/debug_log.h" + +// These are set by openmv py_tf.c code to redirect printing to an error message +// buffer... + +static char *py_tf_putchar_buffer = NULL; +static size_t py_tf_putchar_buffer_len = 0; + +extern "C" void DebugLog(const char* s) { + for (size_t i = 0, j = strlen(s); i < j; i++) { + if (py_tf_putchar_buffer_len) { + *py_tf_putchar_buffer++ = s[i]; + py_tf_putchar_buffer_len--; + } else { + putchar(s[i]); + } + } +} From 2df3cce0f34e0fb5943400a2c549d64310c1a8be Mon Sep 17 00:00:00 2001 From: "Kwabena W. Agyeman" Date: Fri, 3 May 2019 23:15:27 -0700 Subject: [PATCH 018/492] Change printf to puts --- tensorflow/lite/experimental/micro/openmvcam/debug_log.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc index d0499ab3612..c911b5c7b08 100644 --- a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc +++ b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc @@ -24,6 +24,7 @@ static char *py_tf_putchar_buffer = NULL; static size_t py_tf_putchar_buffer_len = 0; extern "C" void DebugLog(const char* s) { +<<<<<<< HEAD for (size_t i = 0, j = strlen(s); i < j; i++) { if (py_tf_putchar_buffer_len) { *py_tf_putchar_buffer++ = s[i]; @@ -32,4 +33,7 @@ extern "C" void DebugLog(const char* s) { putchar(s[i]); } } +======= + puts(s); +>>>>>>> Change printf to puts } From cadc0e18991d7e6f1fc276bfb8d39a3f587a7697 Mon Sep 17 00:00:00 2001 From: "Kwabena W. Agyeman" Date: Mon, 14 Oct 2019 21:31:52 -0700 Subject: [PATCH 019/492] Fix delete operator being included --- tensorflow/lite/micro/memory_planner/greedy_memory_planner.h | 1 + tensorflow/lite/micro/memory_planner/linear_memory_planner.h | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h index f2c77ed94f3..29ae39c9bf1 100644 --- a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h +++ b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h @@ -87,6 +87,7 @@ class GreedyMemoryPlanner : public MemoryPlanner { }; private: + TF_LITE_REMOVE_VIRTUAL_DELETE // Whether a buffer is active in a given time range. bool DoesEntryOverlapInTime(const ListEntry* entry, const int first_time_used, const int last_time_used) const; diff --git a/tensorflow/lite/micro/memory_planner/linear_memory_planner.h b/tensorflow/lite/micro/memory_planner/linear_memory_planner.h index 4d77e778237..b04b2c96788 100644 --- a/tensorflow/lite/micro/memory_planner/linear_memory_planner.h +++ b/tensorflow/lite/micro/memory_planner/linear_memory_planner.h @@ -37,6 +37,7 @@ class LinearMemoryPlanner : public MemoryPlanner { int buffer_index, int* offset) override; private: + TF_LITE_REMOVE_VIRTUAL_DELETE static constexpr int kMaxBufferCount = 1024; size_t buffer_offsets_[kMaxBufferCount]; int current_buffer_count_; From 9903797166938a3f30e9cfb73d669b969f990eed Mon Sep 17 00:00:00 2001 From: "Kwabena W. 
Agyeman" Date: Sun, 24 Nov 2019 16:05:36 -0800 Subject: [PATCH 020/492] Update debug log --- tensorflow/lite/experimental/micro/openmvcam/debug_log.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc index c911b5c7b08..d0499ab3612 100644 --- a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc +++ b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc @@ -24,7 +24,6 @@ static char *py_tf_putchar_buffer = NULL; static size_t py_tf_putchar_buffer_len = 0; extern "C" void DebugLog(const char* s) { -<<<<<<< HEAD for (size_t i = 0, j = strlen(s); i < j; i++) { if (py_tf_putchar_buffer_len) { *py_tf_putchar_buffer++ = s[i]; @@ -33,7 +32,4 @@ extern "C" void DebugLog(const char* s) { putchar(s[i]); } } -======= - puts(s); ->>>>>>> Change printf to puts } From 654854703724eb9428b0b97c044e36200cf63233 Mon Sep 17 00:00:00 2001 From: "Kwabena W. Agyeman" Date: Sun, 24 Nov 2019 16:41:49 -0800 Subject: [PATCH 021/492] Remove static --- tensorflow/lite/experimental/micro/openmvcam/debug_log.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc index d0499ab3612..0caab617c8b 100644 --- a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc +++ b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc @@ -20,8 +20,8 @@ limitations under the License. // These are set by openmv py_tf.c code to redirect printing to an error message // buffer... -static char *py_tf_putchar_buffer = NULL; -static size_t py_tf_putchar_buffer_len = 0; +char *py_tf_putchar_buffer = NULL; +size_t py_tf_putchar_buffer_len = 0; extern "C" void DebugLog(const char* s) { for (size_t i = 0, j = strlen(s); i < j; i++) { From a1dd6f9580caa655d9e5af88cfd0581f87843564 Mon Sep 17 00:00:00 2001 From: "Kwabena W. Agyeman" Date: Thu, 19 Dec 2019 09:56:33 -0800 Subject: [PATCH 022/492] Add openmv cam --- tensorflow/lite/micro/openmvcam/debug_log.cc | 35 ++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tensorflow/lite/micro/openmvcam/debug_log.cc diff --git a/tensorflow/lite/micro/openmvcam/debug_log.cc b/tensorflow/lite/micro/openmvcam/debug_log.cc new file mode 100644 index 00000000000..0caab617c8b --- /dev/null +++ b/tensorflow/lite/micro/openmvcam/debug_log.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include "tensorflow/lite/experimental/micro/debug_log.h" + +// These are set by openmv py_tf.c code to redirect printing to an error message +// buffer... 
+ +char *py_tf_putchar_buffer = NULL; +size_t py_tf_putchar_buffer_len = 0; + +extern "C" void DebugLog(const char* s) { + for (size_t i = 0, j = strlen(s); i < j; i++) { + if (py_tf_putchar_buffer_len) { + *py_tf_putchar_buffer++ = s[i]; + py_tf_putchar_buffer_len--; + } else { + putchar(s[i]); + } + } +} From 66f03f534fcecafcc030d7d3d9c2551735f662fc Mon Sep 17 00:00:00 2001 From: "Kwabena W. Agyeman" Date: Thu, 19 Dec 2019 09:56:54 -0800 Subject: [PATCH 023/492] Add openmv cam --- .../experimental/micro/openmvcam/debug_log.cc | 35 ------------------- 1 file changed, 35 deletions(-) delete mode 100644 tensorflow/lite/experimental/micro/openmvcam/debug_log.cc diff --git a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc b/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc deleted file mode 100644 index 0caab617c8b..00000000000 --- a/tensorflow/lite/experimental/micro/openmvcam/debug_log.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include "tensorflow/lite/experimental/micro/debug_log.h" - -// These are set by openmv py_tf.c code to redirect printing to an error message -// buffer... - -char *py_tf_putchar_buffer = NULL; -size_t py_tf_putchar_buffer_len = 0; - -extern "C" void DebugLog(const char* s) { - for (size_t i = 0, j = strlen(s); i < j; i++) { - if (py_tf_putchar_buffer_len) { - *py_tf_putchar_buffer++ = s[i]; - py_tf_putchar_buffer_len--; - } else { - putchar(s[i]); - } - } -} From 520533a52bd067a5c785e5e021981eb96eddf6dc Mon Sep 17 00:00:00 2001 From: "Kwabena W. Agyeman" Date: Thu, 19 Dec 2019 10:31:56 -0800 Subject: [PATCH 024/492] Fix debug path --- tensorflow/lite/micro/openmvcam/debug_log.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/micro/openmvcam/debug_log.cc b/tensorflow/lite/micro/openmvcam/debug_log.cc index 0caab617c8b..345a2c20226 100644 --- a/tensorflow/lite/micro/openmvcam/debug_log.cc +++ b/tensorflow/lite/micro/openmvcam/debug_log.cc @@ -15,7 +15,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/micro/debug_log.h" +#include "tensorflow/lite/micro/debug_log.h" // These are set by openmv py_tf.c code to redirect printing to an error message // buffer... From b01309c36f90e7cbbca7d7b7cc2d176eed70cc87 Mon Sep 17 00:00:00 2001 From: "Kwabena W. 
Agyeman" Date: Thu, 20 Feb 2020 22:24:48 -0800 Subject: [PATCH 025/492] Remove adding delete operators --- tensorflow/lite/micro/memory_planner/greedy_memory_planner.h | 1 - tensorflow/lite/micro/memory_planner/linear_memory_planner.h | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h index 29ae39c9bf1..f2c77ed94f3 100644 --- a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h +++ b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h @@ -87,7 +87,6 @@ class GreedyMemoryPlanner : public MemoryPlanner { }; private: - TF_LITE_REMOVE_VIRTUAL_DELETE // Whether a buffer is active in a given time range. bool DoesEntryOverlapInTime(const ListEntry* entry, const int first_time_used, const int last_time_used) const; diff --git a/tensorflow/lite/micro/memory_planner/linear_memory_planner.h b/tensorflow/lite/micro/memory_planner/linear_memory_planner.h index b04b2c96788..4d77e778237 100644 --- a/tensorflow/lite/micro/memory_planner/linear_memory_planner.h +++ b/tensorflow/lite/micro/memory_planner/linear_memory_planner.h @@ -37,7 +37,6 @@ class LinearMemoryPlanner : public MemoryPlanner { int buffer_index, int* offset) override; private: - TF_LITE_REMOVE_VIRTUAL_DELETE static constexpr int kMaxBufferCount = 1024; size_t buffer_offsets_[kMaxBufferCount]; int current_buffer_count_; From bf21db9695bf231b668b31e55f96da1e2327f12f Mon Sep 17 00:00:00 2001 From: Jonas Skog Date: Fri, 21 Feb 2020 10:58:05 +0100 Subject: [PATCH 026/492] Clarified error message. --- .../lite/micro/examples/network_tester/network_tester_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc b/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc index aee7353273d..0650222b970 100644 --- a/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc +++ b/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc @@ -73,7 +73,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { TfLiteStatus allocate_status = interpreter.AllocateTensors(); if (allocate_status != kTfLiteOk) { - error_reporter->Report("alloc failed\n"); + TF_LITE_REPORT_ERROR(error_reporter, "Tensor allocation failed\n"); } TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocate_status); From 85690a9dae5e6e72ed52a02270659e5ce9f9ec3d Mon Sep 17 00:00:00 2001 From: JaehunRyu <2013103902@khu.ac.kr> Date: Sat, 22 Feb 2020 18:54:09 +0900 Subject: [PATCH 027/492] Fixex some error --- configure | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configure b/configure index 7a299f3e274..66b66ba54ed 100755 --- a/configure +++ b/configure @@ -4,7 +4,7 @@ set -e set -o pipefail if [ -z "$PYTHON_BIN_PATH" ]; then - PYTHON_BIN_PATH=$( which python3 || true) + PYTHON_BIN_PATH=$(which python || which python3 || true) fi # Set all env variables From b928c1095272048a32acd63dd970586cc0757762 Mon Sep 17 00:00:00 2001 From: Ashutosh Hathidara Date: Sat, 29 Feb 2020 19:56:42 +0530 Subject: [PATCH 028/492] Specifying about representative_datasets in TfLiteConverter --- tensorflow/lite/python/lite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index bda8898d879..25bdd2201c8 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -308,7 +308,8 @@ class TFLiteConverterV2(TFLiteConverterBase): to apply when converting the model. 
E.g. `[Optimize.DEFAULT]` representative_dataset: A representative dataset that can be used to generate input and output samples for the model. The converter can use the - dataset to evaluate different optimizations. + dataset to evaluate different optimizations. Note that this is a necessary + attribute since the conversion optimization depends upon it. experimental_new_converter: Experimental flag, subject to change. Enables MLIR-based conversion instead of TOCO conversion. experimental_new_quantizer: Experimental flag, subject to change. From e946f44aed4c7d90ad5f3cc8bd56e2fc29d47ad6 Mon Sep 17 00:00:00 2001 From: ayushmankumar7 Date: Mon, 2 Mar 2020 12:36:01 +0530 Subject: [PATCH 029/492] module_test solved --- tensorflow/tools/api/tests/module_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/api/tests/module_test.py b/tensorflow/tools/api/tests/module_test.py index 2b3a7dbe31b..5397278f5f3 100644 --- a/tensorflow/tools/api/tests/module_test.py +++ b/tensorflow/tools/api/tests/module_test.py @@ -76,7 +76,7 @@ class ModuleTest(test.TestCase): if hasattr(tf, '_major_api_version') and tf._major_api_version == 2: tf.summary.create_file_writer else: - tf.summary.FileWriter + tf.compat.v1.summary.FileWriter # pylint: enable=pointless-statement From 94f8a9d5d519e6e943e05dc3bae67ab217171183 Mon Sep 17 00:00:00 2001 From: Ir1d Date: Wed, 4 Mar 2020 14:02:41 +0800 Subject: [PATCH 030/492] docs: add examples for random_uniform, random_normal and random_binomial in tf.keras.backend closes #31277 --- tensorflow/python/keras/backend.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index f83ed74c2f8..54b1df156c1 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -5653,6 +5653,13 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): Returns: A tensor with normal distribution of values. + + Example: + + >>> kvar = tf.keras.backend.random_normal((2,3), 0, 1) + >>> kvar + """ if dtype is None: dtype = floatx() @@ -5677,6 +5684,13 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): Returns: A tensor. + + Example: + + >>> kvar = tf.keras.backend.random_uniform((2,3), 0, 1) + >>> kvar + """ if dtype is None: dtype = floatx() @@ -5702,6 +5716,13 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None): Returns: A tensor. + + Example: + + >>> kvar = tf.keras.backend.random_binomial((2,3), 0.5) + >>> kvar + """ if dtype is None: dtype = floatx() From 4b06542ffda4c3cda2fb430b758cd097501e21a5 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Wed, 4 Mar 2020 13:03:56 +0100 Subject: [PATCH 031/492] Ignore control out edges from const node. When we loop throug the input edges, only non-const input edges are added to the list of connections. With this fix we do the same for the output edges, otherwise the EngineInfo structures could be inconsistent if we have multiple TRT segments. 
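For illustration only (not part of this patch), the intent of the change can be sketched with a self-contained toy version of the edge loop. The types and names below are simplified stand-ins, not the real TensorFlow Graph/Node/Edge API: a control out edge whose source node is a Const is skipped, mirroring what GetEngineInfo already does for control input edges.

    // Toy sketch with hypothetical types; only the filtering logic matches the patch.
    #include <iostream>
    #include <string>
    #include <vector>

    struct ToyEdge {
      bool is_control;     // true for a control edge
      std::string src_op;  // op type of the node producing the edge
      std::string src_name;  // name of that node
    };

    int main() {
      std::vector<ToyEdge> out_edges = {
          {true, "Const", "const_ctrl"},   // control edge from a Const: now skipped
          {true, "Relu", "relu_ctrl"},     // control edge from a non-Const: still recorded
          {false, "Const", "const_data"},  // data edge: handled by the existing data path
      };
      std::vector<std::string> control_connections;
      for (const ToyEdge& e : out_edges) {
        if (e.is_control && e.src_op != "Const") {
          control_connections.push_back(e.src_name);
        }
      }
      for (const std::string& name : control_connections) {
        std::cout << name << "\n";  // prints only "relu_ctrl"
      }
      return 0;
    }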
--- tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 6f276546451..cb11ab53d5c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -221,9 +221,11 @@ Status GetEngineInfo(const Graph* g, } if (edge->IsControlEdge()) { // Control output. - info->connections.emplace_back(output_node->name(), output_node->id(), - node_name, node_id, - /*input_edge=*/false); + if (node->type_string() != "Const") { + info->connections.emplace_back(output_node->name(), output_node->id(), + node_name, node_id, + /*input_edge=*/false); + } } else { // Data output. int port = Graph::kControlSlot - 1; From a69af8b11bb5838431da5a59df0e4230e6acb188 Mon Sep 17 00:00:00 2001 From: ngc92 <7938269+ngc92@users.noreply.github.com> Date: Thu, 5 Mar 2020 11:36:51 +0200 Subject: [PATCH 032/492] document `(Tensor, shape)` list possibility for `shapes` argument of `assert_shapes` --- tensorflow/python/ops/check_ops.py | 37 ++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 242c41b2927..df7ab3e61e1 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1634,8 +1634,8 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None, prefix) are both treated as having a single dimension of size one. Args: - shapes: dictionary with (`Tensor` to shape) items. A shape must be an - iterable. + shapes: dictionary with (`Tensor` to shape) items, or a list of + (`Tensor`, shape) tuples. A shape must be an iterable. data: The tensors to print out if the condition is False. Defaults to error message and first few entries of the violating tensor. summarize: Print this many entries of the tensor. @@ -1658,14 +1658,27 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): Example: - ```python - tf.assert_shapes([ - (x, ('N', 'Q')), - (y, ('N', 'D')), - (param, ('Q',)), - (scalar, ()) - ]) - ``` + >>> n = 10 + >>> q = 3 + >>> d = 7 + >>> x = tf.zeros([n,q]) + >>> y = tf.ones([n,d]) + >>> param = tf.Variable([1.0, 2.0, 3.0]) + >>> scalar = 1.0 + >>> tf.debugging.assert_shapes([ + ... (x, ('N', 'Q')), + ... (y, ('N', 'D')), + ... (param, ('Q',)), + ... (scalar, ()), + ... ]) + + >>> tf.debugging.assert_shapes([ + ... (x, ('N', 'D')), + ... (y, ('N', 'D')) + ... ]) + Traceback (most recent call last): + ... + ValueError: ... Example of adding a dependency to an operation: @@ -1693,8 +1706,8 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): prefix) are both treated as having a single dimension of size one. Args: - shapes: dictionary with (`Tensor` to shape) items. A shape must be an - iterable. + shapes: dictionary with (`Tensor` to shape) items, or a list of + (`Tensor`, shape) tuples. A shape must be an iterable. data: The tensors to print out if the condition is False. Defaults to error message and first few entries of the violating tensor. summarize: Print this many entries of the tensor. 
From 0aa22bd95adc4ed318e029ac5b495db7a42bc150 Mon Sep 17 00:00:00 2001 From: Ir1d Date: Thu, 5 Mar 2020 19:56:13 +0800 Subject: [PATCH 033/492] update as per review comment --- tensorflow/python/keras/backend.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 54b1df156c1..86687d4f71f 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -1479,7 +1479,7 @@ def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): Example: - >>> kvar = tf.keras.backend.random_uniform_variable((2,3), 0, 1) + >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3), low=0.0, high=1.0) >>> kvar @@ -1513,7 +1513,7 @@ def random_normal_variable(shape, mean, scale, dtype=None, name=None, Example: - >>> kvar = tf.keras.backend.random_normal_variable((2,3), 0, 1) + >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3), mean=0.0, scale=1.0) >>> kvar @@ -5656,8 +5656,8 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): Example: - >>> kvar = tf.keras.backend.random_normal((2,3), 0, 1) - >>> kvar + >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3), mean=0.0, stddev=1.0) + >>> random_normal_tensor """ @@ -5687,8 +5687,8 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): Example: - >>> kvar = tf.keras.backend.random_uniform((2,3), 0, 1) - >>> kvar + >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3), minval=0.0, maxval=1.0) + >>> random_uniform_tensor """ @@ -5719,8 +5719,8 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None): Example: - >>> kvar = tf.keras.backend.random_binomial((2,3), 0.5) - >>> kvar + >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3), p=0.5) + >>> random_binomial_tensor """ From 5c5acccc90fd8edaea046007d8a082a81b504937 Mon Sep 17 00:00:00 2001 From: Ashutosh Hathidara Date: Thu, 5 Mar 2020 18:21:17 +0530 Subject: [PATCH 034/492] Adding changes of #37188 in this PR --- tensorflow/python/keras/optimizer_v2/nadam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/keras/optimizer_v2/nadam.py b/tensorflow/python/keras/optimizer_v2/nadam.py index a5f5e2dd8a7..ce3fcc17fd3 100644 --- a/tensorflow/python/keras/optimizer_v2/nadam.py +++ b/tensorflow/python/keras/optimizer_v2/nadam.py @@ -78,7 +78,7 @@ class Nadam(optimizer_v2.OptimizerV2): rate for the exponentially weighted infinity norm. epsilon: A small constant for numerical stability. name: Optional name for the operations created when applying gradients. - Defaults to "Adamax". + Defaults to "Nadam". **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. 
`clipnorm` is clip gradients by norm; `clipvalue` is clip gradients by value, `decay` is included for backward compatibility to From 142d271f014953d8f3cb5dac7cfac1e224e84d17 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 5 Mar 2020 05:54:47 -0800 Subject: [PATCH 035/492] Fix ubuntu sanity-check --- tensorflow/python/summary/writer/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index 5b06820abd9..889f71bc669 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -359,7 +359,7 @@ class FileWriter(SummaryToEventTransformer): raise RuntimeError( "v1.summary.FileWriter is not compatible with eager execution. " "Use `tf.summary.create_file_writer`," - "or a `with v1.Graph().as_default():` context") + "or a `with v1.Graph().as_default():` context") if session is not None: event_writer = EventFileWriterV2( session, logdir, max_queue, flush_secs, filename_suffix) From b30f9cf90463a2087660206a45f872fff666491a Mon Sep 17 00:00:00 2001 From: Ir1d Date: Fri, 6 Mar 2020 15:47:32 +0800 Subject: [PATCH 036/492] fix pylint --- tensorflow/python/keras/backend.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 86687d4f71f..45102fa2cfb 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -1479,7 +1479,8 @@ def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): Example: - >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3), low=0.0, high=1.0) + >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3), + low=0.0, high=1.0) >>> kvar @@ -1513,7 +1514,8 @@ def random_normal_variable(shape, mean, scale, dtype=None, name=None, Example: - >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3), mean=0.0, scale=1.0) + >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3), + mean=0.0, scale=1.0) >>> kvar @@ -5656,7 +5658,8 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): Example: - >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3), mean=0.0, stddev=1.0) + >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3), + mean=0.0, stddev=1.0) >>> random_normal_tensor @@ -5687,7 +5690,8 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): Example: - >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3), minval=0.0, maxval=1.0) + >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3), + minval=0.0, maxval=1.0) >>> random_uniform_tensor @@ -5719,7 +5723,8 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None): Example: - >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3), p=0.5) + >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3), + p=0.5) >>> random_binomial_tensor From 3b458633e15e816ecd10cabab934fa16b9200534 Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Fri, 6 Mar 2020 17:04:43 +0800 Subject: [PATCH 037/492] [tflite] NNAPI Pow op allows scalar input --- tensorflow/lite/delegates/nnapi/nnapi_delegate.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index a90b28fcd10..760c24c63f5 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ 
b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -149,6 +149,7 @@ bool IsScalarInputSupported(int builtin_code) { case kTfLiteBuiltinGreaterEqual: case kTfLiteBuiltinLess: case kTfLiteBuiltinLessEqual: + case kTfLiteBuiltinPow: return true; default: return false; From 19b174ea81d1bfa770879863c0f74cec894ec068 Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Fri, 6 Mar 2020 22:31:48 +0200 Subject: [PATCH 038/492] Add `dtype` property to RaggedTensorSpec --- tensorflow/python/ops/ragged/ragged_tensor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index fccdf8fe3c1..d0acb448540 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -2264,6 +2264,11 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): else: return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value) + @property + def dtype(self): + """The `tf.dtypes.DType` specified by this type for the RaggedTensor.""" + return self._dtype + def _serialize(self): return (self._shape, self._dtype, self._ragged_rank, self._row_splits_dtype) From 6f6459880217b48dd0d0baa59fdea663bd311a94 Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Fri, 6 Mar 2020 22:35:10 +0200 Subject: [PATCH 039/492] [tf.data] Add support for any Tensor describable by `tf.TypeSpec` to Dataset.from_generator --- .../data/kernel_tests/from_generator_test.py | 66 ++++++- tensorflow/python/data/ops/dataset_ops.py | 171 ++++++++++++------ tensorflow/python/data/util/structure.py | 16 +- 3 files changed, 183 insertions(+), 70 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py index d320b281136..b020f4b67bb 100644 --- a/tensorflow/python/data/kernel_tests/from_generator_test.py +++ b/tensorflow/python/data/kernel_tests/from_generator_test.py @@ -28,10 +28,16 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_spec from tensorflow.python.platform import test + class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): def _testFromGenerator(self, generator, elem_sequence, num_repeats, @@ -241,7 +247,8 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaisesOpError("The expected type was int64"): + with self.assertRaisesOpError(r"The expected structure was " + r"\(TensorShape\(\[3\]\), tf\.int64, None\)"): self.evaluate(get_next()) self.assertAllEqual([7, 8, 9], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -261,7 +268,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaisesOpError(r"element of shape \(3,\) was expected"): + with self.assertRaisesOpError(r"element of 
TypeSpec\(TensorShape\(\[3\]\), tf\.int64, None\) was expected"): self.evaluate(get_next()) self.assertAllEqual([11, 12, 13], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -283,10 +290,12 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual((1, 2), self.evaluate(get_next())) self.assertEqual((3, 4), self.evaluate(get_next())) with self.assertRaisesOpError( - r"The expected structure was \(tf\.int64, tf\.int64\)"): + r"element of TypeSpec\(\(TensorShape\(None\), tf\.int64, None\), " + "\(TensorShape\(None\), tf\.int64, None\)\)"): self.evaluate(get_next()) with self.assertRaisesOpError( - r"The expected structure was \(tf\.int64, tf\.int64\)"): + r"The expected structure was \(\(TensorShape\(None\), tf\.int64, None\)" + r", \(TensorShape\(None\), tf\.int64, None\)\)"): self.evaluate(get_next()) self.assertEqual((9, 10), self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -405,8 +414,14 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): stateful=True) dummy = constant_op.constant(37) - dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x, - finalize_fn).take(2) + + dataset = dataset_ops._GeneratorDataset( + dummy, + lambda x: x, + lambda x: x, finalize_fn, + tensor_spec.TensorSpec((), dtypes.int32)) + + dataset = dataset.take(2) get_next = self.getNext(dataset) self.assertAllEqual(37, self.evaluate(get_next())) @@ -428,6 +443,45 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([20], self.evaluate(get_next())) + @combinations.generate(test_base.default_test_combinations()) + def testFromGeneratorRaggedTensor(self): + + def generator(): + yield ragged_factory_ops.constant([[1, 2], [3]], + dtype=dtypes.int64, + ragged_rank=1) + + dataset = dataset_ops.Dataset.from_generator( + generator, + output_spec=ragged_tensor.RaggedTensorSpec(shape=(2, None), + dtype=dtypes.int64)) + get_next = self.getNext(dataset) + + ret = get_next() + + self.assertIsInstance(ret, ragged_tensor.RaggedTensor) + self.assertAllEqual([1, 2, 3], ret.values) + + @combinations.generate(test_base.default_test_combinations()) + def testFromGeneratorSparseTensor(self): + + def generator(): + yield sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], + values=constant_op.constant([1, 2], dtype=dtypes.int64), + dense_shape=[3, 4]) + + dataset = dataset_ops.Dataset.from_generator( + generator, + output_spec=sparse_tensor.SparseTensorSpec([3, 4], dtypes.int64)) + + get_next = self.getNext(dataset) + + ret = get_next() + + self.assertIsInstance(ret, sparse_tensor.SparseTensor) + self.assertAllEqual([[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]], + sparse_ops.sparse_tensor_to_dense(ret)) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index c7b2257c510..36ffce1b1c6 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -676,13 +676,39 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): del self._iterators[iterator_id] @staticmethod - def from_generator(generator, output_types, output_shapes=None, args=None): + def from_generator(generator, + output_types=None, + output_shapes=None, + args=None, + output_spec=None): """Creates a `Dataset` whose elements are generated by `generator`. 
The `generator` argument must be a callable object that returns an object that supports the `iter()` protocol (e.g. a generator function). - The elements generated by `generator` must be compatible with the given - `output_types` and (optional) `output_shapes` arguments. + The elements generated by `generator` must be compatible with either the + given `output_types` and (optional) `output_shapes` arguments or with the + given `output_spec` argument whichiver was specified. + + There are three ways to specify the output format: + + * Using only `output_types` argument. In this case the output of the + function will be assumed to consist of `tf.Tensor` objects with the unknown + shapes and with the types defined by `output_types`. + + * Using both `output_types` and `output_shapes` arguments. In this case the + output will be assumed to consist of `tf.Tensor` objects with the shapes + and types defined by these two arguments together. + + * Using `output_spec` argument. In this case the output will be assumed to + consist of objects with the classes, shapes and types defined by + `tf.TypeSpec` objects from `output_spec` argument. + + One of the `output_types` and `output_spec` arguments must be specified. + If used together, `output_spec` will override both `output_types` and + `output_shapes`. + + Use `output_types` argument in simpler cases to benefit from its shorter + form: >>> import itertools >>> @@ -698,7 +724,25 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): >>> list(dataset.take(3).as_numpy_iterator()) [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))] - Note: The current implementation of `Dataset.from_generator()` uses + Use `output_spec` for more complicated cases, e.g. to specify an output + containing `tf.RaggedTensor`, `tf.SparseTensor` or other objects different + from `tf.Tensor`: + + >>> def gen(): + ... ragged_tensor = tf.ragged.constant([[1, 2], [3]], + ... ragged_rank=1, + ... dtype=tf.int64) + ... yield 42, ragged_tensor + >>> + >>> dataset = tf.data.Dataset.from_generator( + ... gen, + ... output_spec=(tf.TensorSpec(shape=(), dtype=tf.int64), + ... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int64))) + >>> + >>> list(dataset.take(1)) + [(, )] + + NOTE: The current implementation of `Dataset.from_generator()` uses `tf.numpy_function` and inherits the same constraints. In particular, it requires the `Dataset`- and `Iterator`-related operations to be placed on a device in the same process as the Python program that called @@ -706,7 +750,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): serialized in a `GraphDef`, and you should not use this method if you need to serialize your model and restore it in a different environment. - Note: If `generator` depends on mutable global variables or other external + NOTE: If `generator` depends on mutable global variables or other external state, be aware that the runtime may invoke `generator` multiple times (in order to support repeating the `Dataset`) and at any time between the call to `Dataset.from_generator()` and the production of the @@ -720,31 +764,52 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): `iter()` protocol. If `args` is not specified, `generator` must take no arguments; otherwise it must take as many arguments as there are values in `args`. - output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element yielded by `generator`. + output_types: (Optional.) 
A nested structure of `tf.DType` objects + corresponding to each component of an element yielded by `generator`. + Ignored if `output_spec` is specified. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element yielded by `generator`. + Ignored if `output_spec` is specified. args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated and passed to `generator` as NumPy-array arguments. + output_spec: (Optional.) A nested structure of `tf.TypeSpec` objects + corresponding to each component of an element yielded by `generator`. Returns: Dataset: A `Dataset`. """ if not callable(generator): raise TypeError("`generator` must be callable.") - if output_shapes is None: - output_shapes = nest.map_structure( + + if output_types is None and output_spec is None: + raise TypeError("Either `output_types` or `output_spec` must be " + "specified.") + + if output_spec is not None: + if not all(isinstance(_, type_spec.TypeSpec) + for _ in nest.flatten(output_spec)): + raise TypeError("All the elements of `output_spec` must be " + "a `tf.TypeSpec` objects.") + + if output_spec is None: + if output_shapes is None: + output_shapes = nest.map_structure( lambda _: tensor_shape.TensorShape(None), output_types) - else: - output_shapes = nest.map_structure_up_to( + else: + output_shapes = nest.map_structure_up_to( output_types, tensor_shape.as_shape, output_shapes) + output_spec = nest.map_structure_up_to( + output_types, + lambda shape, dtype: tensor_spec.TensorSpec(shape, dtype), # pylint: disable=unnecessary-lambda + output_shapes, + output_types) + if args is None: args = () else: args = tuple(ops.convert_n_to_tensor(args, name="args")) - flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)] - flattened_shapes = nest.flatten(output_shapes) + flat_output_types = structure.get_flat_tensor_types(output_spec) generator_state = DatasetV2._GeneratorState(generator) @@ -782,56 +847,35 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): """A `py_func` that will be called to invoke the iterator.""" # `next()` raises `StopIteration` when there are no more # elements remaining to be generated. - values = next(generator_state.get_iterator(iterator_id)) + values = next(generator_state.get_iterator(iterator_id.numpy())) + + def serialize_structure(s): + return nest.map_structure(lambda ts: ts._serialize(), s) # pylint: disable=protected-access - # Use the same _convert function from the py_func() implementation to - # convert the returned values to arrays early, so that we can inspect - # their values. try: - flattened_values = nest.flatten_up_to(output_types, values) + output_dtypes = nest.map_structure(lambda t: t.dtype, output_spec) + values = structure.normalize_element(values, dtypes=output_dtypes) except (TypeError, ValueError): six.reraise(TypeError, TypeError( - "`generator` yielded an element that did not match the expected " - "structure. The expected structure was %s, but the yielded " - "element was %s." % (output_types, values)), sys.exc_info()[2]) - ret_arrays = [] - for ret, dtype in zip(flattened_values, flattened_types): - try: - ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access - ret, dtype=dtype.as_numpy_dtype)) - except (TypeError, ValueError): - six.reraise(TypeError, TypeError( - "`generator` yielded an element that could not be converted to " - "the expected type. The expected type was %s, but the yielded " - "element was %s." 
% (dtype.name, ret)), sys.exc_info()[2]) + "`generator` yielded an element that did not match the expected " + "structure. The expected structure was %s, but the yielded " + "element was %s." % (serialize_structure(output_spec), values)), + sys.exc_info()[2]) - # Additional type and shape checking to ensure that the components - # of the generated element match the `output_types` and `output_shapes` - # arguments. - for (ret_array, expected_dtype, expected_shape) in zip( - ret_arrays, flattened_types, flattened_shapes): - if ret_array.dtype != expected_dtype.as_numpy_dtype: - raise TypeError( - "`generator` yielded an element of type %s where an element " - "of type %s was expected." % (ret_array.dtype, - expected_dtype.as_numpy_dtype)) - if not expected_shape.is_compatible_with(ret_array.shape): - raise ValueError( - "`generator` yielded an element of shape %s where an element " - "of shape %s was expected." % (ret_array.shape, expected_shape)) + values_spec = structure.type_spec_from_value(values) - return ret_arrays + if not structure.are_compatible(values_spec, output_spec): + raise TypeError( + "`generator` yielded an element of TypeSpec%s where an element " + "of TypeSpec%s was expected." + % (serialize_structure(values_spec), + serialize_structure(output_spec))) - flat_values = script_ops.numpy_function(generator_py_func, - [iterator_id_t], flattened_types) + return structure.to_tensor_list(output_spec, values) - # The `py_func()` op drops the inferred shapes, so we add them back in - # here. - if output_shapes is not None: - for ret_t, shape in zip(flat_values, flattened_shapes): - ret_t.set_shape(shape) - - return nest.pack_sequence_as(output_types, flat_values) + return script_ops.eager_py_func(generator_py_func, + inp=[iterator_id_t], + Tout=flat_output_types) def finalize_fn(iterator_id_t): """Releases host-side state for the iterator with ID `iterator_id_t`.""" @@ -857,7 +901,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): # given ID, and raises StopIteration when that iterator contains no # more elements. return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn, - finalize_fn) + finalize_fn, output_spec) # A single-element dataset that, each time it is evaluated, contains a # freshly-generated and unique (for the returned dataset) int64 @@ -2279,9 +2323,10 @@ class DatasetV1(DatasetV2): @staticmethod @functools.wraps(DatasetV2.from_generator) - def from_generator(generator, output_types, output_shapes=None, args=None): + def from_generator(generator, output_types=None, output_shapes=None, + args=None, output_spec=None): return DatasetV1Adapter(DatasetV2.from_generator( - generator, output_types, output_shapes, args)) + generator, output_types, output_shapes, args, output_spec)) @staticmethod @functools.wraps(DatasetV2.range) @@ -3265,7 +3310,8 @@ class StructuredFunctionWrapper(object): class _GeneratorDataset(DatasetSource): """A `Dataset` that generates elements by invoking a function.""" - def __init__(self, init_args, init_func, next_func, finalize_func): + def __init__(self, init_args, init_func, next_func, finalize_func, + output_spec): """Constructs a `_GeneratorDataset`. Args: @@ -3279,6 +3325,8 @@ class _GeneratorDataset(DatasetSource): finalize_func: A TensorFlow function that will be called on the result of `init_func` immediately before a C++ iterator over this dataset is destroyed. The return value is ignored. + output_spec: A nested structure of `tf.TypeSpec` objects describing the + output of `next_func`. 
""" self._init_args = init_args @@ -3298,6 +3346,9 @@ class _GeneratorDataset(DatasetSource): finalize_func, self._transformation_name(), input_structure=self._init_func.output_structure) + + self._output_spec = output_spec + variant_tensor = gen_dataset_ops.generator_dataset( structure.to_tensor_list(self._init_structure, self._init_args) + self._init_func.function.captured_inputs, @@ -3311,7 +3362,7 @@ class _GeneratorDataset(DatasetSource): @property def element_spec(self): - return self._next_func.output_structure + return self._output_spec def _transformation_name(self): return "Dataset.from_generator()" diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 87825005069..ee6151742f6 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -67,7 +67,7 @@ def _RaggedTensorStructure(dtype, shape, ragged_rank): # TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once # it is a subclass of `CompositeTensor`. -def normalize_element(element): +def normalize_element(element, dtypes=None): """Normalizes a nested structure of element components. * Components matching `SparseTensorSpec` are converted to `SparseTensor`. @@ -78,6 +78,10 @@ def normalize_element(element): Args: element: A nested structure of individual components. + dtypes: (Optional.) A nested structure of `tf.DType` objects corresponding + to each component of `element`. If specified, it will be used to set the + exact type of output tensor when converting input components which + are not tensors themselves (e.g. numpy arrays, native python types, etc.) Returns: A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`, @@ -85,17 +89,21 @@ def normalize_element(element): """ components = nest.flatten(element) normalized_components = [] + if dtypes is None: + flattened_dtypes = [None] * len(components) + else: + flattened_dtypes = nest.flatten(dtypes) with ops.name_scope("normalize_element"): # Imported here to avoid circular dependency. from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top - for i, t in enumerate(components): + for i, (t, dtype) in enumerate(zip(components, flattened_dtypes)): try: spec = type_spec_from_value(t, use_fallback=False) except TypeError: # TypeError indicates it was not possible to compute a `TypeSpec` for # the value. As a fallback try converting the value to a tensor. 
normalized_components.append( - ops.convert_to_tensor(t, name="component_%d" % i)) + ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) else: if isinstance(spec, sparse_tensor.SparseTensorSpec): normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) @@ -112,7 +120,7 @@ def normalize_element(element): normalized_components.append(t) else: normalized_components.append( - ops.convert_to_tensor(t, name="component_%d" % i)) + ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) return nest.pack_sequence_as(element, normalized_components) From 17e9b7adf5c333900c1d6b49259b2e774832ed7e Mon Sep 17 00:00:00 2001 From: Ruan Kunliang Date: Sat, 7 Mar 2020 07:11:47 +0800 Subject: [PATCH 040/492] Fix a bug which may cause segmentation fault --- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 3281f97457f..52e4580b1f8 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1586,6 +1586,10 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { const int n = node.attr().at("N").i(); const int start = node.op() == "Concat" ? 1 : 0; const int end = start + n; + if (end > node.input_size()) { + return errors::FailedPrecondition( + "Got attr N=", n, " without enough inputs."); + } // Set up tail pointers to point to the immediate inputs to Concat. for (int input_port = start; input_port < end; ++input_port) { if (IsControlInput(node.input(input_port))) { From 330863b56588a7014292fc47e4e4bd192b34b330 Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Mon, 9 Mar 2020 16:36:56 +0200 Subject: [PATCH 041/492] Fix 'NOTE's in docstrings --- tensorflow/python/data/ops/dataset_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 36ffce1b1c6..895b5305bef 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -742,7 +742,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): >>> list(dataset.take(1)) [(, )] - NOTE: The current implementation of `Dataset.from_generator()` uses + Note: The current implementation of `Dataset.from_generator()` uses `tf.numpy_function` and inherits the same constraints. In particular, it requires the `Dataset`- and `Iterator`-related operations to be placed on a device in the same process as the Python program that called @@ -750,7 +750,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): serialized in a `GraphDef`, and you should not use this method if you need to serialize your model and restore it in a different environment. 
- NOTE: If `generator` depends on mutable global variables or other external + Note: If `generator` depends on mutable global variables or other external state, be aware that the runtime may invoke `generator` multiple times (in order to support repeating the `Dataset`) and at any time between the call to `Dataset.from_generator()` and the production of the From 686a9aef26341146fcf66a6ecaf3a7014a457c7e Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Mon, 9 Mar 2020 19:37:18 +0200 Subject: [PATCH 042/492] Deprecating old arguments, rewriting docstring --- tensorflow/python/data/ops/dataset_ops.py | 60 +++++++---------------- 1 file changed, 18 insertions(+), 42 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 895b5305bef..2332980abb6 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -675,7 +675,12 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): def iterator_completed(self, iterator_id): del self._iterators[iterator_id] + @staticmethod + @deprecation.deprecated_args(None, "Use output_spec instead", + "output_types", "output_shapes") + # TODO(lithuak): Make output_spec a required argument once output_types + # and output_shapes are removed. def from_generator(generator, output_types=None, output_shapes=None, @@ -685,48 +690,15 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): The `generator` argument must be a callable object that returns an object that supports the `iter()` protocol (e.g. a generator function). + The elements generated by `generator` must be compatible with either the - given `output_types` and (optional) `output_shapes` arguments or with the - given `output_spec` argument whichiver was specified. + given `output_spec` argument or with the given `output_types` and (optional) + `output_shapes` arguments whichiver was specified. - There are three ways to specify the output format: - - * Using only `output_types` argument. In this case the output of the - function will be assumed to consist of `tf.Tensor` objects with the unknown - shapes and with the types defined by `output_types`. - - * Using both `output_types` and `output_shapes` arguments. In this case the - output will be assumed to consist of `tf.Tensor` objects with the shapes - and types defined by these two arguments together. - - * Using `output_spec` argument. In this case the output will be assumed to - consist of objects with the classes, shapes and types defined by - `tf.TypeSpec` objects from `output_spec` argument. - - One of the `output_types` and `output_spec` arguments must be specified. - If used together, `output_spec` will override both `output_types` and - `output_shapes`. - - Use `output_types` argument in simpler cases to benefit from its shorter - form: - - >>> import itertools - >>> - >>> def gen(): - ... for i in itertools.count(1): - ... yield (i, [1] * i) - >>> - >>> dataset = tf.data.Dataset.from_generator( - ... gen, - ... (tf.int64, tf.int64), - ... (tf.TensorShape([]), tf.TensorShape([None]))) - >>> - >>> list(dataset.take(3).as_numpy_iterator()) - [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))] - - Use `output_spec` for more complicated cases, e.g. to specify an output - containing `tf.RaggedTensor`, `tf.SparseTensor` or other objects different - from `tf.Tensor`: + The recommended way to call `from_generator` is to use the `output_spec` + argument. 
In this case the output will be assumed to consist of objects with + the classes, shapes and types defined by `tf.TypeSpec` objects from + `output_spec` argument: >>> def gen(): ... ragged_tensor = tf.ragged.constant([[1, 2], [3]], @@ -742,6 +714,12 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): >>> list(dataset.take(1)) [(, )] + There is also a deprecated way to call `from_generator` by either with + `output_types` argument alone or together with `output_shapes` argument. + In this case the output of the function will be assumed to consist of + `tf.Tensor` objects with with the types defined by `output_types` and with + the shapes which are either unknown or defined by `output_shapes`. + Note: The current implementation of `Dataset.from_generator()` uses `tf.numpy_function` and inherits the same constraints. In particular, it requires the `Dataset`- and `Iterator`-related operations to be placed @@ -766,10 +744,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): in `args`. output_types: (Optional.) A nested structure of `tf.DType` objects corresponding to each component of an element yielded by `generator`. - Ignored if `output_spec` is specified. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element yielded by `generator`. - Ignored if `output_spec` is specified. args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated and passed to `generator` as NumPy-array arguments. output_spec: (Optional.) A nested structure of `tf.TypeSpec` objects From 471f5aaa8dec6cf2a785fb48118925fbd1a3ec9b Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Mon, 9 Mar 2020 19:50:39 +0200 Subject: [PATCH 043/492] Verify arguments' correctness --- tensorflow/python/data/ops/dataset_ops.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 2332980abb6..68fb27e535b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -757,15 +757,21 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): if not callable(generator): raise TypeError("`generator` must be callable.") - if output_types is None and output_spec is None: - raise TypeError("Either `output_types` or `output_spec` must be " - "specified.") - if output_spec is not None: + if output_types is not None: + raise TypeError("`output_types` can not be used together with " + "`output_spec`") + if output_shapes is not None: + raise TypeError("`output_shapes` can not be used together with " + "`output_spec`") if not all(isinstance(_, type_spec.TypeSpec) for _ in nest.flatten(output_spec)): raise TypeError("All the elements of `output_spec` must be " "a `tf.TypeSpec` objects.") + else: + if output_types is None and output_shapes is not None: + raise TypeError("`output_shapes` can not be used alone without " + "`output_types`") if output_spec is None: if output_shapes is None: From 864e31180293af676ea6f00e4e7699cf93b1570d Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Mon, 9 Mar 2020 20:02:21 +0200 Subject: [PATCH 044/492] Update goldens --- .../tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt | 4 ++++ tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt | 2 +- .../v1/tensorflow.data.-fixed-length-record-dataset.pbtxt | 2 +- .../api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt | 2 +- 
.../api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt | 2 +- .../golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt | 2 +- .../v1/tensorflow.data.experimental.-random-dataset.pbtxt | 2 +- .../golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt | 2 +- .../tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt | 4 ++++ tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt | 2 +- .../v2/tensorflow.data.-fixed-length-record-dataset.pbtxt | 2 +- .../api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt | 2 +- .../api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt | 2 +- .../golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt | 2 +- .../v2/tensorflow.data.experimental.-random-dataset.pbtxt | 2 +- .../golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt | 2 +- 16 files changed, 22 insertions(+), 14 deletions(-) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt index 2ec5bb46ed1..029d04fee9b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "dtype" + mtype: "" + } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 872d03770ed..496726d6ee1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -63,7 +63,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index a84c5aa3caf..50935b5f60d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index a3862ae2a19..f30d71f08ab 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', 
\'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index baaaf7ea7be..4e9866303b3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt index afdeea5d018..78c011aa6f9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt index 76113c5e01d..689b752b4ab 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt index 1a11026fd19..16fe52e5edf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt index 2ec5bb46ed1..029d04fee9b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "dtype" + mtype: "" + } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index d9414c31e7d..cd920aa68a5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -46,7 +46,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 28efdb6e855..4f75307c9a6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index c9553efb58c..3acc58a243c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -47,7 +47,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 16a878144ae..43994daeff0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, 
defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index d1d2db041e0..3b46075479f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index 18a6b8cbd1b..15905dc367a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index 0cf3d94ba68..4cac2f18484 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" From c054ade8ed6391e46401ac56f22b57bc94ef7618 Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Mon, 9 Mar 2020 21:11:16 +0200 Subject: [PATCH 045/492] Fix iterator_test --- tensorflow/python/data/kernel_tests/iterator_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 36689ed75fb..94b50a7864d 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -946,7 +946,9 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): @def_function.function def fn(): - dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn) + output_spec = tensor_spec.TensorSpec((), dtypes.int64) + dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn, + output_spec) iterator = iter(dataset) next(iterator) From 
23831c8b3941034bfb29098cee7b9175b638b8d8 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 2 Mar 2020 13:41:38 +0000 Subject: [PATCH 046/492] Introduce external delegate provider to benchmark model. Allows TFLite benchmark model to load an external delegate through dlopen. --- tensorflow/lite/tools/benchmark/BUILD | 15 ++ tensorflow/lite/tools/benchmark/README.md | 5 + .../benchmark/external_delegate_provider.cc | 182 ++++++++++++++++++ tensorflow/lite/tools/make/Makefile | 2 +- 4 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 tensorflow/lite/tools/benchmark/external_delegate_provider.cc diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 5a413112e2f..c327f7a3b8f 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -140,6 +140,7 @@ cc_library( ":delegate_provider_hdr", ":gpu_delegate_provider", ":hexagon_delegate_provider", + ":external_delegate_provider", ":logging", ":nnapi_delegate_provider", "@com_google_absl//absl/base:core_headers", @@ -297,6 +298,20 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "external_delegate_provider", + srcs = ["external_delegate_provider.cc"], + copts = tflite_copts(), + linkstatic = True, + visibility = ["//visibility:public"], + deps = [ + ":benchmark_model_lib", + ":delegate_provider_hdr", + ":logging" + ], + alwayslink = 1, +) + cc_library( name = "benchmark_utils", srcs = [ diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md index 286ddf69cab..55debe72484 100644 --- a/tensorflow/lite/tools/benchmark/README.md +++ b/tensorflow/lite/tools/benchmark/README.md @@ -87,6 +87,11 @@ and the following optional parameters: `enable_op_profiling`. When this is set to true the profile of ops on hexagon DSP will be added to the profile table. Note that, the reported data on hexagon is in cycles, not in ms like on cpu. +* `external_delegate_path`: `string` (default="") \ + Path to the external delegate library to use. +* `external_delegate_options`: `string` (default="") \ + A list of options to be passed to the external delegate library. + Options should be in the format of `option1:value1;option2:value2;optionN:valueN` ## To build/install/run diff --git a/tensorflow/lite/tools/benchmark/external_delegate_provider.cc b/tensorflow/lite/tools/benchmark/external_delegate_provider.cc new file mode 100644 index 00000000000..35a532b07cc --- /dev/null +++ b/tensorflow/lite/tools/benchmark/external_delegate_provider.cc @@ -0,0 +1,182 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/tools/benchmark/benchmark_model.h" +#include "tensorflow/lite/tools/benchmark/delegate_provider.h" +#include "tensorflow/lite/tools/benchmark/logging.h" + +#if defined(_WIN32) +#include +#else +#include +#endif + +#include +#include +#include + +namespace tflite { +namespace benchmark { +namespace { +// Library Support construct to handle dynamic library operations +#if defined(_WIN32) +struct LibSupport { + static void* Load(const char* lib) { return LoadLibrary(lib); } + + static void* GetSymbol(void* handle, const char* symbol) { + return (void*)GetProcAddress((HMODULE)handle, symbol); + } + + static int UnLoad(void* handle) { return FreeLibrary((HMODULE)handle); } +}; +#else +struct LibSupport { + static void* Load(const char* lib) { + return dlopen(lib, RTLD_LAZY | RTLD_LOCAL); + } + + static void* GetSymbol(void* handle, const char* symbol) { + return dlsym(handle, symbol); + } + + static int UnLoad(void* handle) { return dlclose(handle); } +}; +#endif + +// Split a given string to a vector of string using a delimiter character +std::vector SplitString(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream ss(str); + while (std::getline(ss, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +// External delegate library construct +struct ExternalLib { + using CreateDelegatePtr = std::add_pointer::type; + using DestroyDelegatePtr = std::add_pointer::type; + + // Open a given delegate library and load the create/destroy symbols + void load(const std::string library) { + if (!is_loaded) { + void* handle = LibSupport::Load(library.c_str()); + if (handle == nullptr) { + TFLITE_LOG(INFO) << "Unable to load external delegate from : " + << library; + } else { + create = reinterpret_cast( + LibSupport::GetSymbol(handle, "tflite_plugin_create_delegate")); + destroy = reinterpret_cast( + LibSupport::GetSymbol(handle, "tflite_plugin_destroy_delegate")); + is_loaded = create && destroy; + } + } + } + + CreateDelegatePtr create{nullptr}; + DestroyDelegatePtr destroy{nullptr}; + bool is_loaded{false}; +}; +} // namespace + +// External delegate provider used to dynamically load delegate libraries +// Note: Assumes the lifetime of the provider exceeds the usage scope of +// the generated delegates. 
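// For illustration: a shared library loaded through this provider is expected
// to export the two entry points looked up below,
// "tflite_plugin_create_delegate" and "tflite_plugin_destroy_delegate".
// A minimal sketch of the expected declarations, with parameter types that are
// an assumption inferred from how the symbols are invoked in
// CreateTfLiteDelegate() below:
//
//   extern "C" TfLiteDelegate* tflite_plugin_create_delegate(
//       char** options_keys, char** options_values, size_t num_options,
//       void (*report_error)(const char*));  // assumed signature
//   extern "C" void tflite_plugin_destroy_delegate(TfLiteDelegate* delegate);
//
// A conforming library defines these, returning a configured TfLiteDelegate*
// from the create function and releasing it again in the destroy function.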
+class ExternalDelegateProvider : public DelegateProvider { + public: + std::vector CreateFlags(BenchmarkParams* params) const final; + + void AddParams(BenchmarkParams* params) const final; + + void LogParams(const BenchmarkParams& params) const final; + + TfLiteDelegatePtr CreateTfLiteDelegate( + const BenchmarkParams& params) const final; + + std::string GetName() const final { return "EXTERNAL"; } + + private: + mutable ExternalLib delegate_lib_; +}; +REGISTER_DELEGATE_PROVIDER(ExternalDelegateProvider); + +std::vector ExternalDelegateProvider::CreateFlags( + BenchmarkParams* params) const { + std::vector flags = { + CreateFlag("external_delegate_path", params, + "The library path for the underlying external."), + CreateFlag( + "external_delegate_options", params, + "Comma-seperated options to be passed to the external delegate")}; + return flags; +} + +void ExternalDelegateProvider::AddParams(BenchmarkParams* params) const { + params->AddParam("external_delegate_path", + BenchmarkParam::Create("")); + params->AddParam("external_delegate_options", + BenchmarkParam::Create("")); +} + +void ExternalDelegateProvider::LogParams(const BenchmarkParams& params) const { + TFLITE_LOG(INFO) << "External delegate path : [" + << params.Get("external_delegate_path") << "]"; + TFLITE_LOG(INFO) << "External delegate options : [" + << params.Get("external_delegate_options") + << "]"; +} + +TfLiteDelegatePtr ExternalDelegateProvider::CreateTfLiteDelegate( + const BenchmarkParams& params) const { + TfLiteDelegatePtr delegate(nullptr, [](TfLiteDelegate*) {}); + std::string lib_path = params.Get("external_delegate_path"); + if (!lib_path.empty()) { + delegate_lib_.load(lib_path); + + if (delegate_lib_.is_loaded) { + // Parse delegate options + const std::vector options = SplitString( + params.Get("external_delegate_options"), ';'); + std::vector keys, values; + for (const auto& option : options) { + auto key_value = SplitString(option, ':'); + if (key_value.size() == 2) { + values.push_back(std::move(key_value[1])); + keys.push_back(std::move(key_value[0])); + } + } + + const size_t num_options = keys.size(); + std::vector ckeys, cvalues; + for (int i = 0; i < num_options; ++i) { + ckeys.push_back(keys[i].c_str()); + cvalues.push_back(values[i].c_str()); + } + + // Create delegate + delegate = + TfLiteDelegatePtr(delegate_lib_.create(ckeys.data(), cvalues.data(), + num_options, nullptr), + delegate_lib_.destroy); + } + } + return delegate; +} +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index b78fb14b785..f22f2cf3b90 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -312,7 +312,7 @@ $(BENCHMARK_BINARY) : $(BENCHMARK_MAIN_OBJ) $(BENCHMARK_LIB) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ -o $(BENCHMARK_BINARY) $(BENCHMARK_MAIN_OBJ) \ - $(LIBFLAGS) $(BENCHMARK_LIB) $(LDFLAGS) $(LIBS) + $(LIBFLAGS) -Wl,--whole-archive $(BENCHMARK_LIB) -Wl,--no-whole-archive $(LDFLAGS) $(LIBS) $(BENCHMARK_PERF_OPTIONS_BINARY) : $(BENCHMARK_PERF_OPTIONS_OBJ) $(BENCHMARK_LIB) @mkdir -p $(dir $@) From 5c9e9ca4253d28154a140d104ca72cb9fdc67ed6 Mon Sep 17 00:00:00 2001 From: Leslie-Fang Date: Wed, 11 Mar 2020 07:42:31 +0800 Subject: [PATCH 047/492] fix the reduce_max with complex64 input core_dump --- tensorflow/compiler/tf2xla/kernels/reduction_ops.cc | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc 
b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 65e158d64fd..15f08537536 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -84,7 +84,18 @@ REGISTER_XLA_OP(Name("Min").CompileTimeConstantInput("reduction_indices"), class MaxOp : public XlaReductionOp { public: explicit MaxOp(OpKernelConstruction* ctx) - : XlaReductionOp(ctx, ctx->input_type(0)) {} + : XlaReductionOp(ctx, ctx->input_type(0)) { + OP_REQUIRES_OK(ctx, TypeCheck(xla_reduction_type_)); + } + + Status TypeCheck(xla::PrimitiveType xla_reduction_type_){ + if(xla_reduction_type_ == xla::C64){ + return errors::InvalidArgument( + "Unsupported type in xla_reduction_type_"); + }else{ + return Status::OK(); + } + } xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return xla::MinValue(builder, xla_reduction_type_); From 4205fac9b9794f97ec26dd2af633f1fa1bc9133c Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Wed, 11 Mar 2020 16:35:36 +0800 Subject: [PATCH 048/492] [tflite] allow passing quanized ReLU to NNAPI --- tensorflow/lite/delegates/nnapi/nnapi_delegate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index a90b28fcd10..1af75ebb594 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -1757,7 +1757,7 @@ bool NNAPIDelegateKernel::Validate( case kTfLiteBuiltinReluN1To1: case kTfLiteBuiltinRelu6: case kTfLiteBuiltinLogistic: { - ExpectOpVersion(version, 1, &val_ctx); + ExpectMaxOpVersion(version, 2, &val_ctx); ExpectIsFloatOrQuant8Operator(context, node, &val_ctx); } break; case kTfLiteBuiltinTanh: { From 614e211481339c2f2917613d9642cb0ebbe4e92c Mon Sep 17 00:00:00 2001 From: Leslie-Fang Date: Wed, 11 Mar 2020 17:41:08 +0800 Subject: [PATCH 049/492] format error message in create op --- tensorflow/compiler/tf2xla/kernels/reduction_ops.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 15f08537536..bad324b1aca 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -88,11 +88,12 @@ class MaxOp : public XlaReductionOp { OP_REQUIRES_OK(ctx, TypeCheck(xla_reduction_type_)); } - Status TypeCheck(xla::PrimitiveType xla_reduction_type_){ - if(xla_reduction_type_ == xla::C64){ - return errors::InvalidArgument( - "Unsupported type in xla_reduction_type_"); - }else{ + Status TypeCheck(xla::PrimitiveType xla_reduction_type_) { + if (xla_reduction_type_ == xla::C64) { + return errors::InvalidArgument( + "Unsupported PrimitiveType in MaxOp: '", + xla::PrimitiveType_Name(xla_reduction_type_), "'"); + } else { return Status::OK(); } } From 88ddf3b7a1e6bb43c0be607863402d79eabb46a2 Mon Sep 17 00:00:00 2001 From: Tomohiro Ubukata Date: Wed, 11 Mar 2020 10:08:05 +0000 Subject: [PATCH 050/492] Fix typo --- .../api_def_DebugNumericSummary.pbtxt | 4 +-- tensorflow/core/framework/tensor.cc | 29 ++++++++++--------- tensorflow/core/kernels/cwise_ops_common.h | 6 ++-- .../core/kernels/spacetobatch_functor.h | 6 ++-- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_DebugNumericSummary.pbtxt b/tensorflow/core/api_def/base_api/api_def_DebugNumericSummary.pbtxt index 565a49ad744..fc5429183e7 100644 --- 
a/tensorflow/core/api_def/base_api/api_def_DebugNumericSummary.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DebugNumericSummary.pbtxt @@ -58,7 +58,7 @@ END Provide a basic summary of numeric value types, range and distribution. output: A double tensor of shape [14 + nDimensions], where nDimensions is the - the number of dimensions of the tensor's shape. The elements of output are: + number of dimensions of the tensor's shape. The elements of output are: [0]: is initialized (1.0) or not (0.0). [1]: total number of elements [2]: NaN element count @@ -68,7 +68,7 @@ output: A double tensor of shape [14 + nDimensions], where nDimensions is the -inf. Otherwise, this is the count of elements > lower_bound and < 0. [5]: zero element count [6]: positive element count (excluding +inf), if upper_bound is the default - -inf. Otherwise, this is the count of elements < upper_bound and > 0. + +inf. Otherwise, this is the count of elements < upper_bound and > 0. [7]: generalized +inf count, elements >= upper_bound. upper_bound is +inf by default. Output elements [1:8] are all zero, if the tensor is uninitialized. diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index a7cc9f59b69..66c97aed874 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -656,15 +656,15 @@ bool Tensor::IsInitialized() const { } void Tensor::CheckType(DataType expected_dtype) const { - CHECK_EQ(dtype(), expected_dtype) - << " " << DataTypeString(expected_dtype) << " expected, got " - << DataTypeString(dtype()); + CHECK_EQ(dtype(), expected_dtype) << " " << DataTypeString(expected_dtype) + << " expected, got " + << DataTypeString(dtype()); } void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const { - CHECK_EQ(dtype(), expected_dtype) - << " " << DataTypeString(expected_dtype) << " expected, got " - << DataTypeString(dtype()); + CHECK_EQ(dtype(), expected_dtype) << " " << DataTypeString(expected_dtype) + << " expected, got " + << DataTypeString(dtype()); CHECK(IsAligned()) << "ptr = " << base(); } @@ -764,9 +764,10 @@ bool Tensor::RefCountIsOne() const { break; \ } -#define CASES(TYPE_ENUM, STMTS) \ - CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \ - , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;) +#define CASES(TYPE_ENUM, STMTS) \ + CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Unexpected type: " \ + << TYPE_ENUM; \ + , LOG(FATAL) << "Type not set";) Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape) : shape_(shape), buf_(nullptr) { @@ -1255,14 +1256,14 @@ bool Tensor::SharesBufferWith(const Tensor& b) const { } string Tensor::DebugString(int num_values) const { - return strings::StrCat("Tensor"); + return strings::StrCat("Tensor"); } string Tensor::DeviceSafeDebugString() const { - return strings::StrCat("Tensor"); + return strings::StrCat("Tensor"); } void Tensor::FillDescription(TensorDescription* description) const { diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h index c8ac4103f91..c0aee43d268 100644 --- a/tensorflow/core/kernels/cwise_ops_common.h +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -346,7 +346,7 @@ void Assign(const D& d, Out out, Rhs rhs) { } // Partial specialization of BinaryFunctor -// for functors with with no error checking. +// for functors with no error checking. 
template struct BinaryFunctor { void operator()(const CPUDevice& d, typename Functor::tout_type out, @@ -405,7 +405,7 @@ struct BinaryFunctor { }; // Partial specialization of BinaryFunctor -// for functors with with no error checking. +// for functors with no error checking. template struct BinaryFunctor { enum { NDIMS = 2 }; @@ -472,7 +472,7 @@ struct BinaryFunctor { typename Functor::func func; if (Functor::use_bcast_optimization && use_bcast_optimization::value) { // Optimize for speed by using Eigen::type2index and avoid - // .broadcast() when we know its a no-op. + // .broadcast() when we know it's a no-op. // // Here, we need to handle 6 cases depending on how many "1" // exist in in0 and in1's shapes (4 numbers in total). It's not diff --git a/tensorflow/core/kernels/spacetobatch_functor.h b/tensorflow/core/kernels/spacetobatch_functor.h index 459f20b0ae1..d197892b130 100644 --- a/tensorflow/core/kernels/spacetobatch_functor.h +++ b/tensorflow/core/kernels/spacetobatch_functor.h @@ -18,11 +18,11 @@ limitations under the License. #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -44,7 +44,7 @@ constexpr int kMaxSpaceToBatchBlockDims = 4; MACRO(2 /**/, ##__VA_ARGS__) \ MACRO(3 /**/, ##__VA_ARGS__) \ MACRO(4 /**/, ##__VA_ARGS__) \ - /**/ +/**/ namespace internal { namespace spacetobatch { @@ -80,7 +80,7 @@ namespace functor { // Functor used by {SpaceToBatch,BatchToSpace}{ND,}Op to do the conversion. // -// If B2S is false, then this performs the space-to-batch conversion. If S2B is +// If B2S is false, then this performs the space-to-batch conversion. If B2S is // true, then this performs the inverse batch-to-space conversion. template struct SpaceToBatchFunctor { From aaa7c3069181c1591e4612a3b1d409f4ffc6c2c1 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Wed, 11 Mar 2020 16:24:23 +0100 Subject: [PATCH 051/492] Introduce ShallKeepControlEdgeFrom function --- .../compiler/tf2tensorrt/convert/convert_graph.cc | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index cb11ab53d5c..c66ec6bffb0 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -103,6 +103,15 @@ std::pair GetFirstValidDeviceId() { return std::make_pair(TfGpuId(-1), PlatformGpuId(-1)); } +// Returns false for const nodes (we intend to drop control edges from those). +bool ShallKeepControlEdgeFrom(const Node* input_node) { + if (!input_node) { + VLOG(2) << "Node pointer is null, this should not happen"; + return false; + } + return input_node->type_string() != "Const"; +} + // Function to get subsegment information structure. Status GetEngineInfo(const Graph* g, const grappler::GraphProperties& graph_properties, @@ -172,7 +181,7 @@ Status GetEngineInfo(const Graph* g, continue; } if (edge->IsControlEdge()) { - if (input_node->type_string() != "Const") { + if (ShallKeepControlEdgeFrom(input_node)) { // Non-Const control input. 
info->connections.emplace_back(input_node->name(), input_node->id(), node_name, node_id, @@ -221,7 +230,7 @@ Status GetEngineInfo(const Graph* g, } if (edge->IsControlEdge()) { // Control output. - if (node->type_string() != "Const") { + if (ShallKeepControlEdgeFrom(node)) { info->connections.emplace_back(output_node->name(), output_node->id(), node_name, node_id, /*input_edge=*/false); From 8fa9fa59cb9e0cf31071c3c2e0c6dd975727bff5 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Wed, 11 Mar 2020 22:09:20 +0100 Subject: [PATCH 052/492] Change log level for error iin ShallKeepControlEdgeFrom --- tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index c66ec6bffb0..b76a3f63b40 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -106,7 +106,7 @@ std::pair GetFirstValidDeviceId() { // Returns false for const nodes (we intend to drop control edges from those). bool ShallKeepControlEdgeFrom(const Node* input_node) { if (!input_node) { - VLOG(2) << "Node pointer is null, this should not happen"; + LOG(FATAL) << "Node pointer is null, this should not happen"; return false; } return input_node->type_string() != "Const"; From 63abc5122b45f09e2d3e712d6281da8de479a512 Mon Sep 17 00:00:00 2001 From: Rajan Singh Date: Wed, 11 Mar 2020 16:18:09 -0700 Subject: [PATCH 053/492] Update func_graph.py Fixing error message. removed "contrib" usage. --- tensorflow/python/framework/func_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index d702771cef3..663f13e0d16 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -937,7 +937,7 @@ def func_graph_from_py_func(name, x = ops.convert_to_tensor_or_composite(x) except (ValueError, TypeError): raise TypeError( - "To be compatible with tf.contrib.eager.defun, Python functions " + "To be compatible with tf.eager.defun, Python functions " "must return zero or more Tensors; in compilation of %s, found " "return value of type %s, which is not a Tensor." 
% (str(python_func), type(x))) From feff4ad8f83533bd7fe500f938057f01843a94ef Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Thu, 12 Mar 2020 18:18:49 +0200 Subject: [PATCH 054/492] Workaround the memory leak in script_ops --- tensorflow/python/data/ops/dataset_ops.py | 5 ++-- tensorflow/python/ops/script_ops.py | 31 ++++++++++++++++++++--- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 68fb27e535b..be460ed4ceb 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -855,9 +855,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): return structure.to_tensor_list(output_spec, values) - return script_ops.eager_py_func(generator_py_func, - inp=[iterator_id_t], - Tout=flat_output_types) + return script_ops.eager_py_func_without_tape_cache( + generator_py_func, inp=[iterator_id_t], Tout=flat_output_types) def finalize_fn(iterator_id_t): """Releases host-side state for the iterator with ID `iterator_id_t`.""" diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index bee85dc4a5b..9403ea11130 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -70,7 +70,7 @@ def _maybe_copy_to_context_device(tensor, device_name): class EagerFunc(object): """A wrapper for a function owned by an EagerPyFunc.""" - def __init__(self, func, Tout, is_grad_func): + def __init__(self, func, Tout, is_grad_func, use_tape_cache=True): """Constructs an EagerFunc. Args: @@ -79,10 +79,14 @@ class EagerFunc(object): None. is_grad_func: Whether this EagerFunc is the gradient of another EagerPyFunc. + use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`. + NOTE(lithuak): see the note for `eager_py_func_without_tape_cache`. + This parameter should be removed once the #35084 issue is fixed. """ self._func = func self._out_dtypes = Tout self._is_grad_func = is_grad_func + self._use_tape_cache = use_tape_cache def _convert(self, value, dtype): """Converts `value` to a tensor of type `dtype`, with error checking. @@ -146,7 +150,8 @@ class EagerFunc(object): else: outputs = _maybe_copy_to_context_device( self._convert(ret, dtype=self._out_dtypes[0]), device_name) - tape_cache[compat.as_bytes(token)] = (tape, args, outputs) + if self._use_tape_cache: + tape_cache[compat.as_bytes(token)] = (tape, args, outputs) return outputs @@ -276,7 +281,8 @@ def _internal_py_func(func, stateful=None, eager=False, is_grad_func=False, - name=None): + name=None, + use_tape_cache=True): """See documentation for py_func and eager_py_func.""" if not callable(func): raise ValueError("Expected func to be callable, got func of type {}".format( @@ -292,7 +298,7 @@ def _internal_py_func(func, Tout = [Tout] if eager: - func = EagerFunc(func, Tout, is_grad_func) + func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache) # Tying the registered function's lifetime with the current default graph is # not reliable. For example, Estimator-based binaries may switch graphs in @@ -457,6 +463,23 @@ def eager_py_func(func, inp, Tout, name=None): return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name) +# NOTE(lithuak): this function is here only as a workaround for github +# issue #35084. It is almost identical to `eager_py_func` with one difference: +# it instructs underlying EagerFunc not to use `tape_cache` to avoid memory +# leak. 
When the issue #35084 is fixed - this function should be removed +# and all the call sites should be changed back to using `eager_py_func`. +def eager_py_func_without_tape_cache(func, inp, Tout, name=None): + if ops.executing_eagerly_outside_functions(): + with ops.device(context.context().host_address_space()): + return _internal_py_func( + func=func, inp=inp, Tout=Tout, eager=True, name=name, + use_tape_cache=False) + + return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, + name=name, use_tape_cache=False) + + + def py_func_common(func, inp, Tout, stateful=True, name=None): """Wraps a python function and uses it as a TensorFlow op. From 1127ae0a91fcee00d2931ef142f0ac2c63bdc7be Mon Sep 17 00:00:00 2001 From: Ashutosh Hathidara Date: Thu, 12 Mar 2020 21:50:25 +0530 Subject: [PATCH 055/492] Resolved description --- tensorflow/lite/python/lite.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 25bdd2201c8..fc73f094549 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -308,8 +308,9 @@ class TFLiteConverterV2(TFLiteConverterBase): to apply when converting the model. E.g. `[Optimize.DEFAULT]` representative_dataset: A representative dataset that can be used to generate input and output samples for the model. The converter can use the - dataset to evaluate different optimizations. Note that this is a necessary - attribute since the conversion optimization depends upon it. + dataset to evaluate different optimizations. Note that this is an optional + attribute but it is necessary if INT8 is the only support builtin ops in + target ops. experimental_new_converter: Experimental flag, subject to change. Enables MLIR-based conversion instead of TOCO conversion. experimental_new_quantizer: Experimental flag, subject to change. From b4bd0ab0bad66694c5535042aa6804b9b4895e12 Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Thu, 12 Mar 2020 21:19:39 +0200 Subject: [PATCH 056/492] Adding _eager_py_func to control the using of `tape_cache` --- tensorflow/python/data/ops/dataset_ops.py | 6 ++- tensorflow/python/ops/script_ops.py | 45 +++++++++++------------ 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index be460ed4ceb..d2cf0a709d4 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -855,8 +855,10 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): return structure.to_tensor_list(output_spec, values) - return script_ops.eager_py_func_without_tape_cache( - generator_py_func, inp=[iterator_id_t], Tout=flat_output_types) + return script_ops._eager_py_func(generator_py_func, + inp=[iterator_id_t], + Tout=flat_output_types, + use_tape_cache=False) # pylint: disable=protected-access def finalize_fn(iterator_id_t): """Releases host-side state for the iterator with ID `iterator_id_t`.""" diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 9403ea11130..61981dbd8a7 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -80,7 +80,7 @@ class EagerFunc(object): is_grad_func: Whether this EagerFunc is the gradient of another EagerPyFunc. use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`. - NOTE(lithuak): see the note for `eager_py_func_without_tape_cache`. 
+ NOTE(lithuak): see the note for `_eager_py_func`. This parameter should be removed once the #35084 issue is fixed. """ self._func = func @@ -375,6 +375,25 @@ def _EagerPyFuncGrad(op, *dy): is_grad_func=True) +# NOTE(lithuak): this function as a layer of indirection was added with one +# specific purpose: as a workaround for github issue #35084. +# It does all the same as `eager_py_func` used to do with one difference: +# it can be used to instruct underlying EagerFunc not to use `tape_cache` +# to avoid memory leak. When the issue #35084 is fixed - this function should +# be removed, its body should be moved back to become the body of +# `eager_py_func` and all the call sites should be reverted to +# using `eager_py_func` without `use_tape_cache` argument of any value. +def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True): + if ops.executing_eagerly_outside_functions(): + with ops.device(context.context().host_address_space()): + return _internal_py_func( + func=func, inp=inp, Tout=Tout, eager=True, name=name, + use_tape_cache=use_tape_cache) + + return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, + name=name, use_tape_cache=use_tape_cache) + + @tf_export("py_function") def eager_py_func(func, inp, Tout, name=None): """Wraps a python function into a TensorFlow op that executes it eagerly. @@ -455,29 +474,7 @@ def eager_py_func(func, inp, Tout, name=None): A list of `Tensor` or a single `Tensor` which `func` computes; an empty list if `func` returns None. """ - if ops.executing_eagerly_outside_functions(): - with ops.device(context.context().host_address_space()): - return _internal_py_func( - func=func, inp=inp, Tout=Tout, eager=True, name=name) - - return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name) - - -# NOTE(lithuak): this function is here only as a workaround for github -# issue #35084. It is almost identical to `eager_py_func` with one difference: -# it instructs underlying EagerFunc not to use `tape_cache` to avoid memory -# leak. When the issue #35084 is fixed - this function should be removed -# and all the call sites should be changed back to using `eager_py_func`. 
-def eager_py_func_without_tape_cache(func, inp, Tout, name=None): - if ops.executing_eagerly_outside_functions(): - with ops.device(context.context().host_address_space()): - return _internal_py_func( - func=func, inp=inp, Tout=Tout, eager=True, name=name, - use_tape_cache=False) - - return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, - name=name, use_tape_cache=False) - + _eager_py_func(func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True) def py_func_common(func, inp, Tout, stateful=True, name=None): From a007162bc2bbe092aa4c295efb4dc8e94e6a004e Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Thu, 12 Mar 2020 23:02:20 +0200 Subject: [PATCH 057/492] Fix the return statement for `eager_py_func` --- tensorflow/python/ops/script_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 61981dbd8a7..954152fe88f 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -386,9 +386,8 @@ def _EagerPyFuncGrad(op, *dy): def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True): if ops.executing_eagerly_outside_functions(): with ops.device(context.context().host_address_space()): - return _internal_py_func( - func=func, inp=inp, Tout=Tout, eager=True, name=name, - use_tape_cache=use_tape_cache) + return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, + name=name, use_tape_cache=use_tape_cache) return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name, use_tape_cache=use_tape_cache) @@ -474,7 +473,8 @@ def eager_py_func(func, inp, Tout, name=None): A list of `Tensor` or a single `Tensor` which `func` computes; an empty list if `func` returns None. """ - _eager_py_func(func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True) + return _eager_py_func(func=func, inp=inp, Tout=Tout, + name=name, use_tape_cache=True) def py_func_common(func, inp, Tout, stateful=True, name=None): From 03144aeb46dd6776c96c90c2effb1ae4062464ca Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 12 Mar 2020 23:59:40 +0000 Subject: [PATCH 058/492] Load external library on stack --- .../benchmark/external_delegate_provider.cc | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/tensorflow/lite/tools/benchmark/external_delegate_provider.cc b/tensorflow/lite/tools/benchmark/external_delegate_provider.cc index 35a532b07cc..9174b4a1f95 100644 --- a/tensorflow/lite/tools/benchmark/external_delegate_provider.cc +++ b/tensorflow/lite/tools/benchmark/external_delegate_provider.cc @@ -73,25 +73,22 @@ struct ExternalLib { using DestroyDelegatePtr = std::add_pointer::type; // Open a given delegate library and load the create/destroy symbols - void load(const std::string library) { - if (!is_loaded) { - void* handle = LibSupport::Load(library.c_str()); - if (handle == nullptr) { - TFLITE_LOG(INFO) << "Unable to load external delegate from : " - << library; - } else { - create = reinterpret_cast( - LibSupport::GetSymbol(handle, "tflite_plugin_create_delegate")); - destroy = reinterpret_cast( - LibSupport::GetSymbol(handle, "tflite_plugin_destroy_delegate")); - is_loaded = create && destroy; - } + bool load(const std::string library) { + void* handle = LibSupport::Load(library.c_str()); + if (handle == nullptr) { + TFLITE_LOG(INFO) << "Unable to load external delegate from : " << library; + } else { + create = reinterpret_cast( + LibSupport::GetSymbol(handle, 
"tflite_plugin_create_delegate")); + destroy = reinterpret_cast( + LibSupport::GetSymbol(handle, "tflite_plugin_destroy_delegate")); + return create && destroy; } + return false; } CreateDelegatePtr create{nullptr}; DestroyDelegatePtr destroy{nullptr}; - bool is_loaded{false}; }; } // namespace @@ -110,9 +107,6 @@ class ExternalDelegateProvider : public DelegateProvider { const BenchmarkParams& params) const final; std::string GetName() const final { return "EXTERNAL"; } - - private: - mutable ExternalLib delegate_lib_; }; REGISTER_DELEGATE_PROVIDER(ExternalDelegateProvider); @@ -147,9 +141,8 @@ TfLiteDelegatePtr ExternalDelegateProvider::CreateTfLiteDelegate( TfLiteDelegatePtr delegate(nullptr, [](TfLiteDelegate*) {}); std::string lib_path = params.Get("external_delegate_path"); if (!lib_path.empty()) { - delegate_lib_.load(lib_path); - - if (delegate_lib_.is_loaded) { + ExternalLib delegate_lib; + if (delegate_lib.load(lib_path)) { // Parse delegate options const std::vector options = SplitString( params.Get("external_delegate_options"), ';'); @@ -171,9 +164,9 @@ TfLiteDelegatePtr ExternalDelegateProvider::CreateTfLiteDelegate( // Create delegate delegate = - TfLiteDelegatePtr(delegate_lib_.create(ckeys.data(), cvalues.data(), - num_options, nullptr), - delegate_lib_.destroy); + TfLiteDelegatePtr(delegate_lib.create(ckeys.data(), cvalues.data(), + num_options, nullptr), + delegate_lib.destroy); } } return delegate; From 52169bdc586d483eaf155a16eee87be62ab4b6f2 Mon Sep 17 00:00:00 2001 From: Ashutosh Hathidara Date: Fri, 13 Mar 2020 14:15:44 +0530 Subject: [PATCH 059/492] Added code for preprocessing of YUB images before feeding yub_to_rgb --- tensorflow/python/ops/image_ops_impl.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 72682e6fee4..81acda07168 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -3293,6 +3293,31 @@ def yuv_to_rgb(images): The output is only well defined if the Y value in images are in [0,1], U and V value are in [-0.5,0.5]. + As per the above description, you need to scale your YUV images if their + pixel values are not in the required range. Below given example illustrates + preprocessing of each channel of images before feeding them to `yub_to_rgb`. + + ```python + yub_images = tf.random.uniform(shape=[100, 64, 64, 3], maxval=255) + last_dimension_axis = len(yub_images.shape) - 1 + yub_tensor_images = tf.truediv( + tf.subtract( + yub_images, + tf.reduce_min(yub_images) + ), + tf.subtract( + tf.reduce_max(yub_images), + tf.reduce_min(yub_images) + ) + ) + y, u, v = tf.split(yub_tensor_images, 3, axis=last_dimension_axis) + target_uv_min, target_uv_max = -0.5, 0.5 + u = u * (target_uv_max - target_uv_min) + target_uv_min + v = v * (target_uv_max - target_uv_min) + target_uv_min + preprocessed_yub_images = tf.concat([y, u, v], axis=last_dimension_axis) + rgb_tensor_images = tf.image.yuv_to_rgb(preprocessed_yub_images) + ``` + Args: images: 2-D or higher rank. Image data to convert. Last dimension must be size 3. 
From 49de661e024d6068d3158f2f622aec1c468c0ee3 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 13 Mar 2020 12:55:32 +0000 Subject: [PATCH 060/492] Fix sanity check issue --- tensorflow/lite/tools/benchmark/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index c327f7a3b8f..a9274747aef 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -307,7 +307,7 @@ cc_library( deps = [ ":benchmark_model_lib", ":delegate_provider_hdr", - ":logging" + ":logging", ], alwayslink = 1, ) From 22f9d7f431fe9acb0fa4a2c357151c026926b246 Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Fri, 13 Mar 2020 18:52:35 +0200 Subject: [PATCH 061/492] Fixed tests in kernel_tests/bucket_by_sequence_length_test --- .../kernel_tests/bucket_by_sequence_length_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py index 0dd7ae1f083..3a1dd00e7de 100644 --- a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py @@ -48,7 +48,7 @@ def _format_record(array, sparse): return { "values": array, "indices": [[i] for i in range(len(array))], - "dense_shape": (len(array),) + "dense_shape": [len(array),] } return array @@ -402,13 +402,15 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase, bucket_size = 10 def _build_dataset(): - input_data = [range(i+1) for i in range(min_len, max_len)] + input_data = [list(range(i+1)) for i in range(min_len, max_len)] def generator_fn(): for record in input_data: yield _format_record(record, sparse=True) + dataset = dataset_ops.Dataset.from_generator( generator=generator_fn, output_types=_get_record_type(sparse=True)) + dataset = dataset.map(_to_sparse_tensor) return dataset From b76202fbf029b6d53c25c85d9916072aae430373 Mon Sep 17 00:00:00 2001 From: Ilya Persky Date: Fri, 13 Mar 2020 19:38:46 +0200 Subject: [PATCH 062/492] Fixed pyling issues --- .../data/kernel_tests/from_generator_test.py | 9 ++--- tensorflow/python/data/ops/dataset_ops.py | 34 +++++++++---------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py index b020f4b67bb..c8f54ec1d8f 100644 --- a/tensorflow/python/data/kernel_tests/from_generator_test.py +++ b/tensorflow/python/data/kernel_tests/from_generator_test.py @@ -268,7 +268,8 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaisesOpError(r"element of TypeSpec\(TensorShape\(\[3\]\), tf\.int64, None\) was expected"): + with self.assertRaisesOpError(r"element of TypeSpec\(TensorShape\(\[3\]\), " + r"tf\.int64, None\) was expected"): self.evaluate(get_next()) self.assertAllEqual([11, 12, 13], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -447,9 +448,9 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): def testFromGeneratorRaggedTensor(self): def generator(): - yield ragged_factory_ops.constant([[1, 2], [3]], - dtype=dtypes.int64, - ragged_rank=1) + yield 
ragged_factory_ops.constant([[1, 2], [3]], + dtype=dtypes.int64, + ragged_rank=1) dataset = dataset_ops.Dataset.from_generator( generator, diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 43fe0509bd7..dee5909d27c 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -410,8 +410,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): def element_spec(self): """The type specification of an element of this dataset. - >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) - >>> dataset.element_spec + >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).element_spec TensorSpec(shape=(), dtype=tf.int32, name=None) Returns: @@ -713,7 +712,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): ... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int64))) >>> >>> list(dataset.take(1)) - [(, )] + [(, + )] There is also a deprecated way to call `from_generator` by either with `output_types` argument alone or together with `output_shapes` argument. @@ -777,15 +777,15 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): if output_spec is None: if output_shapes is None: output_shapes = nest.map_structure( - lambda _: tensor_shape.TensorShape(None), output_types) + lambda _: tensor_shape.TensorShape(None), output_types) else: output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) + output_types, tensor_shape.as_shape, output_shapes) output_spec = nest.map_structure_up_to( - output_types, - lambda shape, dtype: tensor_spec.TensorSpec(shape, dtype), # pylint: disable=unnecessary-lambda - output_shapes, - output_types) + output_types, + lambda shape, dtype: tensor_spec.TensorSpec(shape, dtype), # pylint: disable=unnecessary-lambda + output_shapes, + output_types) if args is None: args = () @@ -840,19 +840,19 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): values = structure.normalize_element(values, dtypes=output_dtypes) except (TypeError, ValueError): six.reraise(TypeError, TypeError( - "`generator` yielded an element that did not match the expected " - "structure. The expected structure was %s, but the yielded " - "element was %s." % (serialize_structure(output_spec), values)), + "`generator` yielded an element that did not match the expected " + "structure. The expected structure was %s, but the yielded " + "element was %s." % (serialize_structure(output_spec), values)), sys.exc_info()[2]) values_spec = structure.type_spec_from_value(values) if not structure.are_compatible(values_spec, output_spec): raise TypeError( - "`generator` yielded an element of TypeSpec%s where an element " - "of TypeSpec%s was expected." - % (serialize_structure(values_spec), - serialize_structure(output_spec))) + "`generator` yielded an element of TypeSpec%s where an element " + "of TypeSpec%s was expected." 
+ % (serialize_structure(values_spec), + serialize_structure(output_spec))) return structure.to_tensor_list(output_spec, values) @@ -2310,7 +2310,7 @@ class DatasetV1(DatasetV2): def from_generator(generator, output_types=None, output_shapes=None, args=None, output_spec=None): return DatasetV1Adapter(DatasetV2.from_generator( - generator, output_types, output_shapes, args, output_spec)) + generator, output_types, output_shapes, args, output_spec)) @staticmethod @functools.wraps(DatasetV2.range) From 0e9211350b926529343ac6715383659aa31c1644 Mon Sep 17 00:00:00 2001 From: Guy David Date: Sat, 14 Mar 2020 00:02:06 +0200 Subject: [PATCH 063/492] fixed to match coding conventions --- tensorflow/lite/delegates/gpu/common/model_builder.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index bda3e1bcd0e..c43b996145e 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1415,8 +1415,7 @@ class MulOperationParser : public TFLiteOperationParser { } RETURN_IF_ERROR(ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader)); - } - else { + } else { // The runtime input tensor must be bound to 1st input and the constant // input tensor must be bound to 2nd input. int runtime_tensor = 0; From 7c0a37e2ad123dfeb409c682a1cab37630678642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cjaketae=E2=80=9D?= Date: Sat, 14 Mar 2020 08:31:05 +0900 Subject: [PATCH 064/492] Improve preprocessing text docs --- tensorflow/python/keras/preprocessing/text.py | 70 +++++++++++++++++-- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/keras/preprocessing/text.py b/tensorflow/python/keras/preprocessing/text.py index 603308e3738..e9cfe677b5c 100644 --- a/tensorflow/python/keras/preprocessing/text.py +++ b/tensorflow/python/keras/preprocessing/text.py @@ -23,16 +23,69 @@ from keras_preprocessing import text from tensorflow.python.util.tf_export import keras_export -text_to_word_sequence = text.text_to_word_sequence -one_hot = text.one_hot hashing_trick = text.hashing_trick Tokenizer = text.Tokenizer -keras_export( - 'keras.preprocessing.text.text_to_word_sequence')(text_to_word_sequence) -keras_export('keras.preprocessing.text.one_hot')(one_hot) -keras_export('keras.preprocessing.text.hashing_trick')(hashing_trick) -keras_export('keras.preprocessing.text.Tokenizer')(Tokenizer) + +@keras_export('keras.preprocessing.text.text_to_word_sequence') +def text_to_word_sequence(text, + filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', + lower=True, split=" "): + """Converts a text to a sequence of words (or tokens). + + This function transforms a string of text into a list of words + while ignoring `filters` which include punctuations by default. + + >>> text = 'This is a sample sentence.' + >>> tf.keras.preprocessing.text.text_to_word_sequence(text) + ['this', 'is', 'a', 'sample', 'sentence'] + + Arguments: + text: Input text (string). + filters: list (or concatenation) of characters to filter out, such as + punctuation. Default: `'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n'`, + includes basic punctuation, tabs, and newlines. + lower: boolean. Whether to convert the input to lowercase. + split: str. Separator for word splitting. + + Returns: + A list of words (or tokens). 
+ """ + return text.text_to_word_sequence( + text, filters=filters, lower=lower, split=split) + + +@keras_export('tf.keras.preprocessing.text.one_hot') +def one_hot(text, n, + filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', + lower=True, + split=' '): + """One-hot encodes a text into a list of word indexes of size `n`. + + This function receives as input a string of text and returns a + list of encoded integers each corresponding to a word (or token) + in the given input string. + + >>> text = 'This is a sample sentence.' + >>> tf.keras.preprocessing.text.one_hot(text, 20) + [4, 18, 1, 15, 17] + + Arguments: + text: Input text (string). + n: int. Size of vocabulary. + filters: list (or concatenation) of characters to filter out, such as + punctuation. Default: ``!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n``, + includes basic punctuation, tabs, and newlines. + lower: boolean. Whether to set the text to lowercase. + split: str. Separator for word splitting. + + Returns: + List of integers in `[1, n]`. Each integer encodes a word + (unicity non-guaranteed). + """ + return text.one_hot( + text, n, filters=filters, lower=lower, split=split) + # text.tokenizer_from_json is only available if keras_preprocessing >= 1.1.0 try: @@ -41,3 +94,6 @@ try: tokenizer_from_json) except AttributeError: pass + +keras_export('keras.preprocessing.text.hashing_trick')(hashing_trick) +keras_export('keras.preprocessing.text.Tokenizer')(Tokenizer) From 5ae1f6d934008c7a5c6f094202ae738574e4487e Mon Sep 17 00:00:00 2001 From: Fabio Di Domenico Date: Mon, 16 Mar 2020 14:48:02 +0200 Subject: [PATCH 065/492] Expose ability to enable NNApi in C api --- tensorflow/lite/c/c_api.cc | 7 +++++++ tensorflow/lite/c/c_api.h | 4 ++++ tensorflow/lite/c/c_api_internal.h | 2 ++ tensorflow/lite/c/c_api_test.cc | 1 + 4 files changed, 14 insertions(+) diff --git a/tensorflow/lite/c/c_api.cc b/tensorflow/lite/c/c_api.cc index 8fd2ec0d51a..a2f050fbada 100644 --- a/tensorflow/lite/c/c_api.cc +++ b/tensorflow/lite/c/c_api.cc @@ -79,6 +79,11 @@ void TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions* options, options->num_threads = num_threads; } +void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options, + bool enable) { + options->useNNAPI = enable; +} + void TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions* options, TfLiteDelegate* delegate) { options->delegates.push_back(delegate); @@ -123,6 +128,8 @@ TfLiteInterpreter* TfLiteInterpreterCreate( } if (optional_options) { + interpreter->UseNNAPI(optional_options->useNNAPI); + if (optional_options->num_threads != TfLiteInterpreterOptions::kDefaultNumThreads) { interpreter->SetNumThreads(optional_options->num_threads); diff --git a/tensorflow/lite/c/c_api.h b/tensorflow/lite/c/c_api.h index 754fc3b8bbd..8b49cbb5411 100644 --- a/tensorflow/lite/c/c_api.h +++ b/tensorflow/lite/c/c_api.h @@ -120,6 +120,10 @@ TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsDelete( TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetNumThreads( TfLiteInterpreterOptions* options, int32_t num_threads); +// Enable or disable the NN API for the interpreter (true to enable). +TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI( + TfLiteInterpreterOptions* options, bool enable); + // Adds a delegate to be applied during `TfLiteInterpreter` creation. 
// // If delegate application fails, interpreter creation will also fail with an diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h index 973d822fce4..ce07f16c33d 100644 --- a/tensorflow/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -49,6 +49,8 @@ struct TfLiteInterpreterOptions { void* error_reporter_user_data = nullptr; std::vector delegates; + + bool useNNAPI = false; }; struct TfLiteInterpreter { diff --git a/tensorflow/lite/c/c_api_test.cc b/tensorflow/lite/c/c_api_test.cc index 1de35cc9dc7..59c60044d45 100644 --- a/tensorflow/lite/c/c_api_test.cc +++ b/tensorflow/lite/c/c_api_test.cc @@ -38,6 +38,7 @@ TEST(CApiSimple, Smoke) { TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); ASSERT_NE(options, nullptr); TfLiteInterpreterOptionsSetNumThreads(options, 2); + TfLiteInterpreterOptionsSetUseNNAPI(options, true); TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); ASSERT_NE(interpreter, nullptr); From 6c356d87f1e2b1354275bab9200c5be8dac2def6 Mon Sep 17 00:00:00 2001 From: angusluo Date: Tue, 17 Mar 2020 11:58:35 +0800 Subject: [PATCH 066/492] concat to tf.concat for consistency Change concat to tf.concat for consistency in this doc and easier to copy and experiment with the code. --- tensorflow/python/ops/array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 8c45161c450..cbb5db77801 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1538,14 +1538,14 @@ def concat(values, axis, name="concat"): >>> t1 = [[1, 2, 3], [4, 5, 6]] >>> t2 = [[7, 8, 9], [10, 11, 12]] - >>> concat([t1, t2], 0) + >>> tf.concat([t1, t2], 0) - >>> concat([t1, t2], 1) + >>> tf.concat([t1, t2], 1) From 51d76d6f722241b9e36e4e4b769412c1b5f9ed19 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 16 Mar 2020 18:49:54 +0900 Subject: [PATCH 067/492] minor spelling tweaks --- tensorflow/lite/delegates/flex/kernel_test.cc | 2 +- tensorflow/lite/delegates/gpu/README.md | 2 +- .../lite/delegates/gpu/cl/cl_command_queue.h | 4 ++-- tensorflow/lite/delegates/gpu/cl/cl_program.h | 2 +- tensorflow/lite/delegates/gpu/cl/gl_interop.h | 2 +- .../lite/delegates/gpu/cl/inference_context.h | 6 ++--- .../gpu/cl/kernels/fully_connected.cc | 2 +- .../delegates/gpu/cl/kernels/max_unpooling.cc | 8 +++---- .../delegates/gpu/cl/kernels/strided_slice.cc | 6 ++--- .../lite/delegates/gpu/cl/opencl_wrapper.h | 2 +- tensorflow/lite/delegates/gpu/cl/precision.h | 2 +- tensorflow/lite/delegates/gpu/cl/tensor.cc | 2 +- .../delegates/gpu/common/memory_management.h | 4 ++-- .../greedy_by_size_assignment.cc | 4 ++-- .../greedy_by_size_assignment.h | 4 ++-- .../gpu/common/memory_management/internal.h | 2 +- .../gpu/common/testing/interpreter_utils.h | 4 ++-- .../gpu/common/workgroup_selection.cc | 2 +- tensorflow/lite/delegates/gpu/delegate.h | 2 +- .../gpu/gl/compiler/variable_accessor.h | 2 +- tensorflow/lite/delegates/gpu/gl/gl_errors.cc | 2 +- tensorflow/lite/delegates/gpu/gl/gl_sync.h | 2 +- .../lite/delegates/gpu/gl/kernels/add_test.cc | 2 +- .../gl/workgroups/ideal_workgroup_picker.cc | 4 ++-- .../delegates/gpu/metal/compiled_model.cc | 22 +++++++++---------- .../delegates/gpu/metal/inference_context.h | 6 ++--- .../delegates/gpu/metal/kernels/add_test.mm | 2 +- .../delegates/nnapi/acceleration_test_list.cc | 4 ++-- .../lite/delegates/nnapi/nnapi_delegate.cc | 2 +- 
.../lite/delegates/nnapi/nnapi_delegate.h | 2 +- .../delegates/nnapi/nnapi_delegate_test.cc | 8 +++---- 31 files changed, 60 insertions(+), 60 deletions(-) diff --git a/tensorflow/lite/delegates/flex/kernel_test.cc b/tensorflow/lite/delegates/flex/kernel_test.cc index 5b3a6d16470..380dbfb4f03 100644 --- a/tensorflow/lite/delegates/flex/kernel_test.cc +++ b/tensorflow/lite/delegates/flex/kernel_test.cc @@ -38,7 +38,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate, } // There is no easy way to pass a parameter into the TfLiteDelegate's -// 'prepare' function, so we keep a global map for testing purpused. +// 'prepare' function, so we keep a global map for testing purposed. // To avoid collisions use: GetPrepareFunction<__LINE__>(). std::map>* GetGlobalOpLists() { static auto* op_list = new std::map>; diff --git a/tensorflow/lite/delegates/gpu/README.md b/tensorflow/lite/delegates/gpu/README.md index ee21ba27b95..552e1cdbec6 100644 --- a/tensorflow/lite/delegates/gpu/README.md +++ b/tensorflow/lite/delegates/gpu/README.md @@ -113,7 +113,7 @@ const TfLiteGpuDelegateOptionsV2 kDefaultOptions = TfLiteGpuDelegateOptionsV2Default(); ``` -Similar for `NewTfLiteMetalDelgate()`: +Similar for `NewTfLiteMetalDelegate()`: ```c++ const TfLiteMetalDelegateOptions kDefaultOptions = { diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h index 18609c8309f..84ffeca67eb 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h @@ -124,9 +124,9 @@ class ProfilingCommandQueue : public CLCommandQueue { double GetQueueExecutionTimeMs() const; // Difference from GetQueueExecutionTimeMs is that this number doesn't include - // time between kernels(kernels launchs or preparing) on GPU. Usually, this + // time between kernels(kernels launches or preparing) on GPU. Usually, this // time should be 5-10% better than GetQueueExecutionTimeMs, because 5-10% - // spend on something else(maybe kernels launchs or preparing) + // spend on something else(maybe kernels launches or preparing) double GetSumOfEventsTimeMs() const; // This label will be used for all subsequent dispatches. diff --git a/tensorflow/lite/delegates/gpu/cl/cl_program.h b/tensorflow/lite/delegates/gpu/cl/cl_program.h index 997c31343af..b6deb3beb95 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_program.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_program.h @@ -64,7 +64,7 @@ class CLProgram { // Return the cl_device_id associated with the program object. // This can be the device associated with context on which the program object - // has been created or can be device that was specified when a progam object + // has been created or can be device that was specified when a program object // was created using clCreateProgramWithBinary. cl_device_id GetDeviceId() const { return device_id_; } diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.h b/tensorflow/lite/delegates/gpu/cl/gl_interop.h index 74c9553016b..597bee857c6 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.h +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.h @@ -46,7 +46,7 @@ Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, bool IsEglSyncFromClEventSupported(); // Creates CL event from EGL sync. -// Created event could only be comsumed by AcquiredGlObject::Acquire call as +// Created event could only be consumed by AcquiredGlObject::Acquire call as // a 'wait_event'. 
Status CreateClEventFromEglSync(cl_context context, const EglSync& egl_sync, CLEvent* event); diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.h b/tensorflow/lite/delegates/gpu/cl/inference_context.h index b8a0b7741b6..40b20e8806a 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.h +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.h @@ -47,7 +47,7 @@ struct CLNode { // for every operation. std::vector ranges; - // Mostly for debug purposess. + // Mostly for debug purposes. std::string name; CLNode() = default; @@ -129,8 +129,8 @@ class InferenceContext { CalculationsPrecision precision_; TensorStorageType storage_type_; - // Directly mapped nodes from graph, but some of them "inactiv" due - // to fusion (inactiv = fused). + // Directly mapped nodes from graph, but some of them "inactive" due + // to fusion (inactive = fused). // Memory is allocated only once, in ConvertOperations, and is not modified // anywhere. std::vector nodes_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc index c536e759210..44a3e97554c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc @@ -29,7 +29,7 @@ namespace { // vec mat mult) on 4 parts to create more threads // tid.y thread process every 4-th element in vec vec dot // Good results for ~1024 x 1024 sizes, for other can be written more -// otimized shaders +// optimized shaders std::string GetFullyConnectedKernelCode( const OperationDef& op_def, const LinearStorage& biases, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc index 98f0918e15f..194daee5f1e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc @@ -25,7 +25,7 @@ namespace gpu { namespace cl { namespace { -std::string GetMaxUnoolingKernelCode( +std::string GetMaxUnpoolingKernelCode( const OperationDef& op_def, const CLDevice& device, const std::vector& linked_operations) { TensorCodeGenerator src("src_data", @@ -102,7 +102,7 @@ std::string GetMaxUnoolingKernelCode( return c; } -std::string GetMaxUnooling3DKernelCode( +std::string GetMaxUnpooling3DKernelCode( const OperationDef& op_def, const CLDevice& device, const std::vector& linked_operations) { TensorCodeGenerator src( @@ -219,7 +219,7 @@ MaxUnpooling& MaxUnpooling::operator=(MaxUnpooling&& kernel) { } Status MaxUnpooling::Compile(const CreationContext& creation_context) { - const auto code = GetMaxUnoolingKernelCode( + const auto code = GetMaxUnpoolingKernelCode( definition_, *creation_context.device, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, @@ -292,7 +292,7 @@ MaxUnpooling3D& MaxUnpooling3D::operator=(MaxUnpooling3D&& kernel) { } Status MaxUnpooling3D::Compile(const CreationContext& creation_context) { - const auto code = GetMaxUnooling3DKernelCode( + const auto code = GetMaxUnpooling3DKernelCode( definition_, *creation_context.device, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc index 38f217dcd18..4f5cf9b26c7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc +++ 
b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc @@ -95,7 +95,7 @@ std::string GetStridedSliceCode( return c; } -bool Is4Alighed(const SliceAttributes& attr) { +bool Is4Aligned(const SliceAttributes& attr) { return attr.strides.c == 1 && attr.starts.c % 4 == 0; } @@ -129,7 +129,7 @@ int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height, offset.z = src_channels + attr.ends.c; } } - if (Is4Alighed(attr)) { + if (Is4Aligned(attr)) { offset.z /= 4; } if (attr.strides.b > 0) { @@ -167,7 +167,7 @@ StridedSlice& StridedSlice::operator=(StridedSlice&& operation) { } Status StridedSlice::Compile(const CreationContext& creation_context) { - const auto code = GetStridedSliceCode(definition_, Is4Alighed(attributes_), + const auto code = GetStridedSliceCode(definition_, Is4Aligned(attributes_), linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h index daf7f76773b..16ae24437a3 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h @@ -640,7 +640,7 @@ extern PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR; extern PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR; extern PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR; -// For convinient image creation +// For convenient image creation // It uses clCreateImage if it available (clCreateImage available since cl 1.2) // otherwise it will use legacy clCreateImage2D cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags, diff --git a/tensorflow/lite/delegates/gpu/cl/precision.h b/tensorflow/lite/delegates/gpu/cl/precision.h index e5bf480802b..f25db33673d 100644 --- a/tensorflow/lite/delegates/gpu/cl/precision.h +++ b/tensorflow/lite/delegates/gpu/cl/precision.h @@ -30,7 +30,7 @@ enum class CalculationsPrecision { F32, F32_F16, F16 }; // F32_F16 - as F16, but some operations (Convolution, // DepthWiseConvolution, FullyConnected, ConvolutionTransposed) // have accumulator in F32 and usually it calculates 4 mads in F16, sum them, -// than converts this partial sum to F32 and add to acumulator. +// than converts this partial sum to F32 and add to accumulator. 
DataType DeduceDataTypeFromPrecision(CalculationsPrecision precision); diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc index 610ba407eb9..e9de22c6dc0 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc @@ -475,7 +475,7 @@ Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, case TensorStorageType::SINGLE_TEXTURE_2D: { if (slices != 1) { return InvalidArgumentError(absl::StrCat( - "SINGLE_TEXTURE_2D support only cnannels in range [1-4], but ", + "SINGLE_TEXTURE_2D support only channels in range [1-4], but ", shape.c, "was provided")); } cl_image_desc desc; diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.h b/tensorflow/lite/delegates/gpu/common/memory_management.h index 652bb4b6e78..e45c361d955 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management.h @@ -82,7 +82,7 @@ enum class MemoryStrategy { Status BestGreedy(const std::vector>& usage_records, ObjectsAssignment* assignment); -// Calculates the assignement of shared objects to given tensors, including +// Calculates the assignment of shared objects to given tensors, including // objects' sizes. Below there are specializations for different types, that // support more memory strategies. // If reallocation_graph is provided, assignment of shared objects support @@ -130,7 +130,7 @@ Status AssignObjectsToTensors( MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); -// Calculates the assignement of tensors to offsets, considering those tensors +// Calculates the assignment of tensors to offsets, considering those tensors // are going to be allocated in one continuous memory block. Status AssignOffsetsToTensors( const std::vector>& usage_records, diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc index 1234326b4ea..bf56c6d92dd 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc @@ -67,7 +67,7 @@ Status GreedyBySizeAssignment( assignment->offsets.resize(num_tensors); assignment->total_size = 0; - // Ordered records are to be sorted by size of corrseponding tensor. + // Ordered records are to be sorted by size of corresponding tensor. std::vector> ordered_records; for (size_t i = 0; i < num_tensors; ++i) { ordered_records.emplace_back(&usage_records[i], i); @@ -133,7 +133,7 @@ Status GreedyBySizeAssignment( // - We have tensor usage records of all intermideate tensors as an input. Each // record consists of tensor size, first and last tasks, that use it. Let's call // [first_task..last_task] a tensor usage interval; -// - Distance between two usage intervals is the absoulte difference between +// - Distance between two usage intervals is the absolute difference between // closest tasks in their intervals. If two usage intervals don't intersect, // than the distance between them is positive; // - Calculate positional maximums vector, e.g. 
the vector of lower bounds on diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h index 2cb8ceee0e1..fb875fd0920 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h @@ -36,7 +36,7 @@ namespace gpu { // gap; // - If such a gap has been found, current tensor should be allocated into this // gap. Otherwise we can allocate it after the rightmost tensor, which usage -// interval intersects with usage inteval of current tensor. So we assign +// interval intersects with usage interval of current tensor. So we assign // corresponding offset to current tensor and the tensor becomes assigned. Status GreedyBySizeAssignment( const std::vector>& usage_records, @@ -47,7 +47,7 @@ Status GreedyBySizeAssignment( // - We have tensor usage records of all intermideate tensors as an input. Each // record consists of tensor size, first and last tasks, that use it. Let's call // [first_task..last_task] a tensor usage interval; -// - Distance between two usage intervals is the absoulte difference between +// - Distance between two usage intervals is the absolute difference between // closest tasks in their intervals. If two usage intervals don't intersect, // than the distance between them is positive; // - Calculate positional maximums vector, e.g. the vector of lower bounds on diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/internal.h b/tensorflow/lite/delegates/gpu/common/memory_management/internal.h index c9e93a721f4..702fd2992cc 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/internal.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/internal.h @@ -46,7 +46,7 @@ bool CompareBySize(const TensorUsageWithIndex& first, const TensorUsageWithIndex& second); // TaskProfile is a vector with information about all intermediate tensors, that -// should exist in memory during the executon of the task. Elements of the +// should exist in memory during the execution of the task. Elements of the // vector must be sorted in non-increasing order of corresponding tensors sizes. using TaskProfile = std::vector>; diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h index 8ce9d2bfb20..a38a5d1363a 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h +++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h @@ -29,7 +29,7 @@ namespace gpu { namespace testing { // Runs Tensorflow Lite model using Tensorflow Lite with a delegate and -// an appropriate operations resolver. If delegate is nullptr, infererence will +// an appropriate operations resolver. If delegate is nullptr, inference will // be done only on CPU. Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, TfLiteDelegate* delegate, @@ -38,7 +38,7 @@ Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, std::vector* outputs); // Runs Tensorflow Lite model using Tensorflow Lite with a delegate and -// builtin operations resolver. If delegate is nullptr, infererence will +// builtin operations resolver. If delegate is nullptr, inference will // be done only on CPU. 
Status InterpreterInvoke(const ::tflite::Model* model, TfLiteDelegate* delegate, const std::vector& inputs, diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc index 47d9f8c3060..d6d22aa6a62 100644 --- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc @@ -126,7 +126,7 @@ std::vector GetDivisorsForRange(int number, int range) { std::vector GetPossibleSizes(int number, WorkGroupSizeAlignment z_alignment) { if (z_alignment == WorkGroupSizeAlignment::PRECISE) { - // we will use for potential sizes, sizes that cover grid preciselly + // we will use for potential sizes, sizes that cover grid precisely // work group size * k (k is integer) == grid_size return GetDivisors(number); } else { diff --git a/tensorflow/lite/delegates/gpu/delegate.h b/tensorflow/lite/delegates/gpu/delegate.h index 65c52310f2f..29bececf39b 100644 --- a/tensorflow/lite/delegates/gpu/delegate.h +++ b/tensorflow/lite/delegates/gpu/delegate.h @@ -79,7 +79,7 @@ typedef struct { // each time inference engine needs to make a decision, it uses // ordered priorities to do so. // For example: - // MAX_PRECISION at priority1 would not allow to decrease presision, + // MAX_PRECISION at priority1 would not allow to decrease precision, // but moving it to priority2 or priority3 would result in F16 calculation. // // Priority is defined in TfLiteGpuInferencePriority. diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h index ac807ff9a98..c9946a00395 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h @@ -60,7 +60,7 @@ class VariableAccessor : public InlineRewrite { // Returns const variables that need to be inlined in the a shader's code. std::string GetConstDeclarations() const; - // Returns shared varaible declarations that need to be inlined. + // Returns shared variable declarations that need to be inlined. std::string GetSharedVariableDeclarations() const; // Returns uniform parameter declarations that need to be inlined. diff --git a/tensorflow/lite/delegates/gpu/gl/gl_errors.cc b/tensorflow/lite/delegates/gpu/gl/gl_errors.cc index 2c29127839d..1a40e38ea9c 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_errors.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_errors.cc @@ -131,7 +131,7 @@ Status GetEglError() { case EGL_CONTEXT_LOST: return InternalError( "A power management event has occurred. The application must destroy " - "all contexts and reinitialise OpenGL ES state and objects to " + "all contexts and reinitialize OpenGL ES state and objects to " "continue rendering."); } return UnknownError("EGL error: " + std::to_string(error)); diff --git a/tensorflow/lite/delegates/gpu/gl/gl_sync.h b/tensorflow/lite/delegates/gpu/gl/gl_sync.h index 4f89e01abed..dadb4b1192f 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_sync.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_sync.h @@ -77,7 +77,7 @@ class GlSync { // Waits until GPU is done with processing. Status GlSyncWait(); -// Waits until all comands are flushed and then performs active waiting by +// Waits until all commands are flushed and then performs active waiting by // spinning a thread and checking sync status. It leads to shorter wait time // (up to tens of ms) but consumes more CPU. 
Status GlActiveSyncWait(); diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/add_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/add_test.cc index e3be6205b17..f4c81841b9f 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/add_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/add_test.cc @@ -74,7 +74,7 @@ TEST(AddTest, InputTensorAndScalar) { Pointwise(FloatNear(1e-6), {-1.9, 0.3, 0.8, 0.9, 1.2, 2.1})); } -TEST(AddTest, InputTensorWithConstandBroadcast) { +TEST(AddTest, InputTensorWithConstantBroadcast) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc index 65636fe6467..b67cc36c903 100644 --- a/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc @@ -33,8 +33,8 @@ namespace { // (b/117291356). // Describes the ideal convolution for the specific operation case -// Case here means specific "kernel + strides" conbination for specific -// operatoins type, not sizes of input and output tensors, they can be any. +// Case here means specific "kernel + strides" combination for specific +// operations type, not sizes of input and output tensors, they can be any. struct IdealByCase { bool ParamsAccepted(OperationType in_op_type, HW in_kernel, HW in_strides) const { diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc index 716376b735b..9608aaddeb4 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc @@ -129,7 +129,7 @@ uint32_t BufferUseCount(ValueId id, } // Examines if the second operation can be linked to the first one. Linking may -// be skipped in the situation when conflic may happen: if first operation's +// be skipped in the situation when conflict may happen: if first operation's // output is used by more than 1 other operation. bool CanFuseOperations(const ComputeTaskDescriptorPtr first, const ComputeTaskDescriptorPtr second, @@ -444,9 +444,9 @@ ComputeTaskDescriptorPtr NonLinkableStub(int operation_id, ValueId input_id, } ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { - auto fused_desciptor = std::make_shared(); + auto fused_descriptor = std::make_shared(); // The id of fused descriptor is the id of the first descriptor in the list. - fused_desciptor->id = chain.front()->id; + fused_descriptor->id = chain.front()->id; FusionSequence sequence; if (chain.front()->is_linkable) { // The first task is linkable so it contains only linkable code. Insert @@ -503,7 +503,7 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { buffer.declaration + name + "[[buffer(" + index + ")]],\n"; call_arguments += ", buffer" + index; input_index++; - fused_desciptor->input_buffers.push_back({buffer.id, ""}); + fused_descriptor->input_buffers.push_back({buffer.id, ""}); } } // We have an output id that is the input for the next task. 
@@ -517,7 +517,7 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { buffer.declaration + name + "[[buffer(" + index + ")]],\n"; call_arguments += ", buffer" + index; immutable_index++; - fused_desciptor->immutable_buffers.push_back(buffer); + fused_descriptor->immutable_buffers.push_back(buffer); } for (auto buffer : desc->uniform_buffers) { @@ -527,7 +527,7 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { buffer.declaration + name + "[[buffer(" + index + ")]],\n"; call_arguments += ", buffer" + index; uniform_index++; - fused_desciptor->uniform_buffers.push_back({"", buffer.data_function}); + fused_descriptor->uniform_buffers.push_back({"", buffer.data_function}); } if (desc->is_linkable) { @@ -539,7 +539,7 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { } ComputeTaskDescriptorPtr non_linkable = sequence.front(); - fused_desciptor->shader_source = + fused_descriptor->shader_source = absl::Substitute(non_linkable->shader_source, function_code, buffer_declarations, call_code); std::vector alias; @@ -547,13 +547,13 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { for (int i = 0; i < chain.size() - 1; i++) { alias.push_back(chain[i]->output_buffer.id); } - fused_desciptor->output_buffer = { + fused_descriptor->output_buffer = { fused_id, "", non_linkable->output_buffer.dimensions_function, alias}; - fused_desciptor->resize_function = non_linkable->resize_function; + fused_descriptor->resize_function = non_linkable->resize_function; for (const auto& desc : sequence) { - fused_desciptor->description += desc->description + "_"; + fused_descriptor->description += desc->description + "_"; } - return fused_desciptor; + return fused_descriptor; } } // namespace diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h index 536b87f780c..8569a4ed009 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.h +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h @@ -35,7 +35,7 @@ limitations under the License. /// 2. Model compilation. Global list of ComputeTaskDescriptors is transformed /// into the sorted list of sets of descriptors. A set can be transformed /// later into a single GPU task. -/// 3. GPU compute tasks generation. Shader code generation happes here. +/// 3. GPU compute tasks generation. Shader code generation happens here. /// 4. Intermediate resource allocation. /// Inference. @interface TFLInferenceContext : NSObject @@ -72,11 +72,11 @@ limitations under the License. /// Inserts all GPU compute tasks into the command encoder. /// @param inputOutputBuffers Must be created and passed into the method with pairs ID:buffer /// @param encoderBlock User-defined block to take control over command encoder. Can be nil. -/// The block can be used, for example, for fine-graned benchmarking where end encoding +/// The block can be used, for example, for fine-grained benchmarking where end encoding /// is performed and command buffer is committed with completion block. A new command /// buffer must be created and new command encoder must be returned by the block. /// The block is called after every dispatch encoding. -/// @discussion No GPU sychronization functions are used inside. All GPU resources must be created +/// @discussion No GPU synchronization functions are used inside. All GPU resources must be created /// with the same device which has been used in compileModelWithDevice() method. 
- (void)encodeWithEncoder:(id)commandEncoder inputOutputBuffers:(const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm index 4cd675f9032..10481b2a867 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm @@ -90,7 +90,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } -- (void)testInputTensorWithConstandBroadcast { +- (void)testInputTensorWithConstantBroadcast { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index ea41ea01d81..14f48d3ffed 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -24,7 +24,7 @@ const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig = # # The test_id is test_suite_name / test_name, this differs from the # name used by the build because of the / separator instead of . -# Parametrised tests names are composed by the base test name / test / ordinal +# Parameterized tests names are composed by the base test name / test / ordinal # the ordinal is the position in the list of parameters generated by the # cardinal product of all the different parameter sets @@ -39,7 +39,7 @@ const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig = ## Test Arguments # -# The test can be parametrised with the minimum Android SDK version +# The test can be parameterized with the minimum Android SDK version # to apply the acceleration validation for. # If omitted will use 27 diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index a90b28fcd10..341d79adb48 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -155,7 +155,7 @@ bool IsScalarInputSupported(int builtin_code) { } } -// Check if the operation requires explict conversion from int8 to uint8 values. +// Check if the operation requires explicit conversion from int8 to uint8 values. bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code, const TfLiteNode* node) { const int input_id = node->inputs->data[0]; diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h index 2bc43620d96..fe777ea99aa 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h @@ -172,7 +172,7 @@ class StatefulNnApiDelegate : public TfLiteDelegate { bool disallow_nnapi_cpu; // Tensor to ANeuralNetworksMemory mapping. std::vector tensor_memory_map; - // Constains a non zero value if any NNAPI method call + // Contains a non zero value if any NNAPI method call // operation returned a non zero result code. 
int nnapi_errno; // Cache of kernels already built in StatefulNnApiDelegate::DoPrepare diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc index 919c1ddcc2b..ea9111c4567 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -4811,17 +4811,17 @@ class PadV2OpConstModel : public PadOpModel { }; // Test case where paddings is a non-const tensor. -template -class PadV2OpDynamicModel : public PadOpModel { +template +class PadV2OpDynamicModel : public PadOpModel { public: PadV2OpDynamicModel(const TensorData& input, std::initializer_list paddings_shape, - RegularInputOuput constant_values, + RegularInputOutput constant_values, const TensorData& output) { this->input_ = this->AddInput(input); this->paddings_ = this->AddInput(TensorType_INT32); this->constant_values_ = this->AddConstInput( - GetTensorType(), {constant_values}, {1}); + GetTensorType(), {constant_values}, {1}); this->output_ = this->AddOutput(output); this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options, From b97023504c53efab507c20ed8af5f6430c475834 Mon Sep 17 00:00:00 2001 From: Fabio Di Domenico Date: Tue, 17 Mar 2020 09:55:26 +0200 Subject: [PATCH 068/492] Moved to experimental api --- tensorflow/lite/c/c_api.cc | 5 ----- tensorflow/lite/c/c_api.h | 4 ---- tensorflow/lite/c/c_api_experimental.cc | 5 +++++ tensorflow/lite/c/c_api_experimental.h | 4 ++++ tensorflow/lite/c/c_api_experimental_test.cc | 1 + tensorflow/lite/c/c_api_test.cc | 1 - 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/c/c_api.cc b/tensorflow/lite/c/c_api.cc index a2f050fbada..831dcc10286 100644 --- a/tensorflow/lite/c/c_api.cc +++ b/tensorflow/lite/c/c_api.cc @@ -79,11 +79,6 @@ void TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions* options, options->num_threads = num_threads; } -void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options, - bool enable) { - options->useNNAPI = enable; -} - void TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions* options, TfLiteDelegate* delegate) { options->delegates.push_back(delegate); diff --git a/tensorflow/lite/c/c_api.h b/tensorflow/lite/c/c_api.h index 8b49cbb5411..754fc3b8bbd 100644 --- a/tensorflow/lite/c/c_api.h +++ b/tensorflow/lite/c/c_api.h @@ -120,10 +120,6 @@ TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsDelete( TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetNumThreads( TfLiteInterpreterOptions* options, int32_t num_threads); -// Enable or disable the NN API for the interpreter (true to enable). -TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI( - TfLiteInterpreterOptions* options, bool enable); - // Adds a delegate to be applied during `TfLiteInterpreter` creation. 
// // If delegate application fails, interpreter creation will also fail with an diff --git a/tensorflow/lite/c/c_api_experimental.cc b/tensorflow/lite/c/c_api_experimental.cc index 4a3354e5f55..e934d7fede9 100644 --- a/tensorflow/lite/c/c_api_experimental.cc +++ b/tensorflow/lite/c/c_api_experimental.cc @@ -50,6 +50,11 @@ void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options, options->op_resolver.AddCustom(name, registration, min_version, max_version); } +void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options, + bool enable) { + options->useNNAPI = enable; +} + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/c/c_api_experimental.h b/tensorflow/lite/c/c_api_experimental.h index a647e32b479..0398c385874 100644 --- a/tensorflow/lite/c/c_api_experimental.h +++ b/tensorflow/lite/c/c_api_experimental.h @@ -49,6 +49,10 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp( const TfLiteRegistration* registration, int32_t min_version, int32_t max_version); +// Enable or disable the NN API for the interpreter (true to enable). +TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI( + TfLiteInterpreterOptions* options, bool enable); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/c/c_api_experimental_test.cc b/tensorflow/lite/c/c_api_experimental_test.cc index 6de8236d5e7..71a08b5af26 100644 --- a/tensorflow/lite/c/c_api_experimental_test.cc +++ b/tensorflow/lite/c/c_api_experimental_test.cc @@ -41,6 +41,7 @@ TEST(CApiExperimentalTest, Smoke) { TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); TfLiteInterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd, GetDummyRegistration(), 1, 1); + TfLiteInterpreterOptionsSetUseNNAPI(options, true); TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); ASSERT_NE(interpreter, nullptr); diff --git a/tensorflow/lite/c/c_api_test.cc b/tensorflow/lite/c/c_api_test.cc index 59c60044d45..1de35cc9dc7 100644 --- a/tensorflow/lite/c/c_api_test.cc +++ b/tensorflow/lite/c/c_api_test.cc @@ -38,7 +38,6 @@ TEST(CApiSimple, Smoke) { TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); ASSERT_NE(options, nullptr); TfLiteInterpreterOptionsSetNumThreads(options, 2); - TfLiteInterpreterOptionsSetUseNNAPI(options, true); TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); ASSERT_NE(interpreter, nullptr); From 86dcc9732d97cc7ff60e1317b5107ad7a11cdabb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 09:30:23 -0700 Subject: [PATCH 069/492] Add uint64 support to DynamicPartition and FloorMod on CPU. PiperOrigin-RevId: 301389554 Change-Id: I5414bc9f35e7f42a52c63767aa23f2508dcb3e37 --- tensorflow/core/kernels/cwise_op_floor_mod.cc | 3 ++- tensorflow/core/kernels/dynamic_partition_op.cc | 2 ++ tensorflow/core/ops/math_ops.cc | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_floor_mod.cc b/tensorflow/core/kernels/cwise_op_floor_mod.cc index 481fc3b8989..3305f54bcca 100644 --- a/tensorflow/core/kernels/cwise_op_floor_mod.cc +++ b/tensorflow/core/kernels/cwise_op_floor_mod.cc @@ -16,7 +16,8 @@ limitations under the License. 
#include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER2(BinaryOp, CPU, "FloorMod", functor::safe_floor_mod, int32, int64); +REGISTER3(BinaryOp, CPU, "FloorMod", functor::safe_floor_mod, int32, int64, + uint64); REGISTER2(BinaryOp, CPU, "FloorMod", functor::floor_fmod, float, double); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/dynamic_partition_op.cc b/tensorflow/core/kernels/dynamic_partition_op.cc index 95af19c4c48..90ed71dccce 100644 --- a/tensorflow/core/kernels/dynamic_partition_op.cc +++ b/tensorflow/core/kernels/dynamic_partition_op.cc @@ -164,6 +164,8 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared { DynamicPartitionOp) TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION); +// For partitioning fingerprints. +TF_CALL_uint64(REGISTER_DYNAMIC_PARTITION); #undef REGISTER_DYNAMIC_PARTITION } // namespace tensorflow diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 4ab1d3e68d0..e441c73cc57 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -587,7 +587,7 @@ REGISTER_OP("FloorMod") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {int32, int64, bfloat16, half, float, double}") + .Attr("T: {int32, int64, uint64, bfloat16, half, float, double}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); REGISTER_OP("TruncateMod") From b62d3ccac58b432452f6730aed6a37c605491ac2 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 17 Mar 2020 09:46:39 -0700 Subject: [PATCH 070/492] Start XLA HLO to TensorFlow ops legalization Legalize from XLA HLO op to TensorFlow op some simple ops. Started by just flipping some of the patterns around (reasonably arbitrarily chosen binary and unary ops). There is no expectation that TF -> HLO -> TF would given a graph close to the original. Useful for interop and some staging work. Adding a pass but not as part of any pipeline. Currently this live in tensorflow/ directory as that matches the current convention here where the legalization pass lives with the target dialect, Will move to `Conversion` to match the new layout proposed [upstream](https://llvm.discourse.group/t/rfc-canonical-file-paths-to-dialects/621). 
PiperOrigin-RevId: 301392804 Change-Id: Ie7b030019dfc47aa74dab502113628de5329f808 --- tensorflow/compiler/mlir/BUILD | 1 + tensorflow/compiler/mlir/tensorflow/BUILD | 38 ++ .../mlir/tensorflow/tests/legalize_hlo.mlir | 552 ++++++++++++++++++ .../tensorflow/transforms/legalize_hlo.cc | 70 +++ .../transforms/legalize_hlo_patterns.td | 72 +++ 5 files changed, 733 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 2ed1c274f75..7ad8a80695d 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -71,6 +71,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/xla:lhlo", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index e2aae0ec52e..3bed4e753e0 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -167,6 +167,44 @@ gentbl( ], ) +gentbl( + name = "hlo_legalize_tf_inc_gen", + tbl_outs = [ + ("-gen-rewriters", "transforms/generated_legalize_hlo.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/legalize_hlo_patterns.td", + td_srcs = [ + "//tensorflow/compiler/mlir/xla:hlo_ops_td_files", + "@llvm-project//llvm:support", + "@llvm-project//mlir:StdOpsTdFiles", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + ], +) + +cc_library( + name = "tf_legalize_hlo", + srcs = [ + "transforms/generated_legalize_hlo.inc", + "transforms/legalize_hlo.cc", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/core:framework", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "tensorflow_types", srcs = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir new file mode 100644 index 00000000000..c1b53debd7c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -0,0 +1,552 @@ +// RUN: tf-opt -tf-legalize-hlo %s | FileCheck %s --dump-input-on-failure + +//===----------------------------------------------------------------------===// +// Binary op legalizations. 
+//===----------------------------------------------------------------------===// + +func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { +%0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32> +%1 = xla_hlo.add %0, %arg0 : tensor<2xi32> +return %1 : tensor<2xi32> +} +// CHECK-LABEL: func @add( +// CHECK-SAME: [[VAL_0:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_1:%.*]] = "tf.AddV2"([[VAL_0]], [[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: [[VAL_2:%.*]] = "tf.AddV2"([[VAL_1]], [[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_2]] : tensor<2xi32> + +func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { +%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +return %0 : tensor<1x2xi32> +} +// CHECK-LABEL: func @broadcast_add( +// CHECK-SAME: [[VAL_3:%.*]]: tensor<1xi32>, [[VAL_4:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_5:%.*]] = "tf.AddV2"([[VAL_3]], [[VAL_4]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_5]] : tensor<1x2xi32> + +func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { +%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> +return %0 : tensor<4x4x4x4xi32> +} +// CHECK-LABEL: func @broadcast_multi_dim_add( +// CHECK-SAME: [[VAL_6:%.*]]: tensor<4x1x1xi32>, [[VAL_7:%.*]]: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { +// CHECK: [[VAL_8:%.*]] = "tf.AddV2"([[VAL_6]], [[VAL_7]]) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> +// CHECK: return [[VAL_8]] : tensor<4x4x4x4xi32> + +func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { +%0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> +return %0 : tensor<2xi32> +} +// CHECK-LABEL: func @div( +// CHECK-SAME: [[VAL_9:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_10:%.*]] = "tf.RealDiv"([[VAL_9]], [[VAL_9]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_10]] : tensor<2xi32> + +func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { +%0 = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +return %0 : tensor<1x2xi32> +} +// CHECK-LABEL: func @broadcast_div( +// CHECK-SAME: [[VAL_11:%.*]]: tensor<1xi32>, [[VAL_12:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_13:%.*]] = "tf.RealDiv"([[VAL_11]], [[VAL_12]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_13]] : tensor<1x2xi32> + +func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { +%0 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> +return %0 : tensor<4xi32> +} +// CHECK-LABEL: func @shift_left( +// CHECK-SAME: [[VAL_14:%.*]]: tensor<4xi32>, [[VAL_15:%.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: [[VAL_16:%.*]] = "tf.LeftShift"([[VAL_14]], [[VAL_15]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return [[VAL_16]] : tensor<4xi32> + +func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { +%0 = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor +return %0 : tensor +} +// CHECK-LABEL: func @div_dynamic( +// CHECK-SAME: [[VAL_17:%.*]]: tensor, [[VAL_18:%.*]]: tensor) -> tensor { +// CHECK: 
[[VAL_19:%.*]] = "tf.RealDiv"([[VAL_17]], [[VAL_18]]) : (tensor, tensor) -> tensor +// CHECK: return [[VAL_19]] : tensor + +func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { +%0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor +return %0 : tensor +} +// CHECK-LABEL: func @div_unranked( +// CHECK-SAME: [[VAL_20:%.*]]: tensor<*xi32>, [[VAL_21:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_22:%.*]] = "tf.Div"([[VAL_20]], [[VAL_21]]) : (tensor<*xi32>, tensor) -> tensor +// CHECK: return [[VAL_22]] : tensor + +func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +%0 = xla_hlo.max %arg0, %arg1 : tensor<4xf32> +return %0 : tensor<4xf32> +} +// CHECK-LABEL: func @maximum( +// CHECK-SAME: [[VAL_23:%.*]]: tensor<4xf32>, [[VAL_24:%.*]]: tensor<4xf32>) -> tensor<4xf32> { +// CHECK: [[VAL_25:%.*]] = "tf.Maximum"([[VAL_23]], [[VAL_24]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return [[VAL_25]] : tensor<4xf32> + +func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +%0 = xla_hlo.min %arg0, %arg1 : tensor<4xf32> +return %0 : tensor<4xf32> +} +// CHECK-LABEL: func @minimum( +// CHECK-SAME: [[VAL_26:%.*]]: tensor<4xf32>, [[VAL_27:%.*]]: tensor<4xf32>) -> tensor<4xf32> { +// CHECK: [[VAL_28:%.*]] = "tf.Minimum"([[VAL_26]], [[VAL_27]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return [[VAL_28]] : tensor<4xf32> + +func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { +%0 = xla_hlo.mul %arg0, %arg0 : tensor<2xi32> +return %0 : tensor<2xi32> +} +// CHECK-LABEL: func @mul( +// CHECK-SAME: [[VAL_29:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_30:%.*]] = "tf.Mul"([[VAL_29]], [[VAL_29]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_30]] : tensor<2xi32> + +func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { +%0 = "xla_hlo.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +return %0 : tensor<1x2xi32> +} +// CHECK-LABEL: func @broadcast_mul( +// CHECK-SAME: [[VAL_31:%.*]]: tensor<1xi32>, [[VAL_32:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_33:%.*]] = "tf.Mul"([[VAL_31]], [[VAL_32]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_33]] : tensor<1x2xi32> + +func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { +%0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> +return %0 : tensor<2xi32> +} +// CHECK-LABEL: func @real_div( +// CHECK-SAME: [[VAL_34:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_35:%.*]] = "tf.RealDiv"([[VAL_34]], [[VAL_34]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_35]] : tensor<2xi32> + +func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { +%0 = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +return %0 : tensor<1x2xi32> +} +// CHECK-LABEL: func @broadcast_real_div( +// CHECK-SAME: [[VAL_36:%.*]]: tensor<1xi32>, [[VAL_37:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_38:%.*]] = "tf.RealDiv"([[VAL_36]], [[VAL_37]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_38]] : tensor<1x2xi32> + +func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { +%0 = xla_hlo.sub %arg0, %arg0 : tensor<2xi32> +return %0 : tensor<2xi32> +} +// CHECK-LABEL: func @sub( +// CHECK-SAME: [[VAL_39:%.*]]: tensor<2xi32>) -> 
tensor<2xi32> { +// CHECK: [[VAL_40:%.*]] = "tf.Sub"([[VAL_39]], [[VAL_39]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_40]] : tensor<2xi32> + +func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { +%0 = "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +return %0 : tensor<1x2xi32> +} +// CHECK-LABEL: func @broadcast_sub( +// CHECK-SAME: [[VAL_41:%.*]]: tensor<1xi32>, [[VAL_42:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_43:%.*]] = "tf.Sub"([[VAL_41]], [[VAL_42]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_43]] : tensor<1x2xi32> + +func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { +%0 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> +return %0 : tensor<4xi32> +} +// CHECK-LABEL: func @shift_right( +// CHECK-SAME: [[VAL_44:%.*]]: tensor<4xi32>, [[VAL_45:%.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: [[VAL_46:%.*]] = "tf.RightShift"([[VAL_44]], [[VAL_45]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return [[VAL_46]] : tensor<4xi32> + +func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { +%0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +return %0 : tensor<2x4xi32> +} +// CHECK-LABEL: func @broadcast_shift_right( +// CHECK-SAME: [[VAL_47:%.*]]: tensor<4xi32>, [[VAL_48:%.*]]: tensor<2x4xi32>) -> tensor<2x4xi32> { +// CHECK: [[VAL_49:%.*]] = "tf.RightShift"([[VAL_47]], [[VAL_48]]) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +// CHECK: return [[VAL_49]] : tensor<2x4xi32> + +func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { +%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> +return %0 : tensor<4xui8> +} +// CHECK-LABEL: func @shift_right_unsigned( +// CHECK-SAME: [[VAL_50:%.*]]: tensor<4xui8>, [[VAL_51:%.*]]: tensor<4xui8>) -> tensor<4xui8> { +// CHECK: [[VAL_52:%.*]] = "tf.RightShift"([[VAL_50]], [[VAL_51]]) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> +// CHECK: return [[VAL_52]] : tensor<4xui8> + +func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { +%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> +return %0 : tensor<2x4xui8> +} +// CHECK-LABEL: func @broadcast_shift_right_unsigned( +// CHECK-SAME: [[VAL_53:%.*]]: tensor<4xui8>, [[VAL_54:%.*]]: tensor<2x4xui8>) -> tensor<2x4xui8> { +// CHECK: [[VAL_55:%.*]] = "tf.RightShift"([[VAL_53]], [[VAL_54]]) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> +// CHECK: return [[VAL_55]] : tensor<2x4xui8> + +//===----------------------------------------------------------------------===// +// Unary op legalizations. 
+//===----------------------------------------------------------------------===// + +func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} +// CHECK-LABEL: func @abs( +// CHECK-SAME: [[VAL_0:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_1:%.*]] = "tf.Abs"([[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_1]] : tensor<2xf32> + +func @abs_dynamic(%arg0: tensor) -> tensor { + %0 = "xla_hlo.abs"(%arg0) : (tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @abs_dynamic( +// CHECK-SAME: [[VAL_2:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_3:%.*]] = "tf.Abs"([[VAL_2]]) : (tensor) -> tensor +// CHECK: return [[VAL_3]] : tensor + +func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} +// CHECK-LABEL: func @abs_unranked( +// CHECK-SAME: [[VAL_4:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_5:%.*]] = "tf.Abs"([[VAL_4]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_5]] : tensor<*xf32> + +func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} +// CHECK-LABEL: func @ceil( +// CHECK-SAME: [[VAL_6:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_7:%.*]] = "tf.Ceil"([[VAL_6]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_7]] : tensor<2xf32> + +func @ceil_dynamic(%arg0: tensor) -> tensor { + %0 = "xla_hlo.ceil"(%arg0) : (tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @ceil_dynamic( +// CHECK-SAME: [[VAL_8:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_9:%.*]] = "tf.Ceil"([[VAL_8]]) : (tensor) -> tensor +// CHECK: return [[VAL_9]] : tensor + +func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} +// CHECK-LABEL: func @ceil_unranked( +// CHECK-SAME: [[VAL_10:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_11:%.*]] = "tf.Ceil"([[VAL_10]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_11]] : tensor<*xf32> + +func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} +// CHECK-LABEL: func @cos( +// CHECK-SAME: [[VAL_12:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_13:%.*]] = "tf.Cos"([[VAL_12]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_13]] : tensor<2xf32> + +func @cos_dynamic(%arg0: tensor) -> tensor { + %0 = "xla_hlo.cos"(%arg0) : (tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @cos_dynamic( +// CHECK-SAME: [[VAL_14:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_15:%.*]] = "tf.Cos"([[VAL_14]]) : (tensor) -> tensor +// CHECK: return [[VAL_15]] : tensor + +func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} +// CHECK-LABEL: func @cos_unranked( +// CHECK-SAME: [[VAL_16:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_17:%.*]] = "tf.Cos"([[VAL_16]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_17]] : tensor<*xf32> + +func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} +// CHECK-LABEL: func @exp( +// CHECK-SAME: [[VAL_18:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// 
+// CHECK: return [[VAL_19]] : tensor<2xf32>
+
+func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @exp_dynamic(
+// CHECK-SAME: [[VAL_20:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_21:%.*]] = "tf.Exp"([[VAL_20]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_21]] : tensor<?xf32>
+
+func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @exp_unranked(
+// CHECK-SAME: [[VAL_22:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_23:%.*]] = "tf.Exp"([[VAL_22]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_23]] : tensor<*xf32>
+
+func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+// CHECK-LABEL: func @floor(
+// CHECK-SAME: [[VAL_24:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: [[VAL_25:%.*]] = "tf.Floor"([[VAL_24]]) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: return [[VAL_25]] : tensor<2xf32>
+
+func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @floor_dynamic(
+// CHECK-SAME: [[VAL_26:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_27:%.*]] = "tf.Floor"([[VAL_26]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_27]] : tensor<?xf32>
+
+func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @floor_unranked(
+// CHECK-SAME: [[VAL_28:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_29:%.*]] = "tf.Floor"([[VAL_28]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_29]] : tensor<*xf32>
+
+func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
+  %0 = "xla_hlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
+  return %0 : tensor<2xi1>
+}
+// CHECK-LABEL: func @is_finite(
+// CHECK-SAME: [[VAL_30:%.*]]: tensor<2xf32>) -> tensor<2xi1> {
+// CHECK: [[VAL_31:%.*]] = "tf.IsFinite"([[VAL_30]]) : (tensor<2xf32>) -> tensor<2xi1>
+// CHECK: return [[VAL_31]] : tensor<2xi1>
+
+func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
+  %0 = "xla_hlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
+  return %0 : tensor<?xi1>
+}
+// CHECK-LABEL: func @is_finite_dynamic(
+// CHECK-SAME: [[VAL_32:%.*]]: tensor<?xf32>) -> tensor<?xi1> {
+// CHECK: [[VAL_33:%.*]] = "tf.IsFinite"([[VAL_32]]) : (tensor<?xf32>) -> tensor<?xi1>
+// CHECK: return [[VAL_33]] : tensor<?xi1>
+
+func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> {
+  %0 = "xla_hlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
+// CHECK-LABEL: func @is_finite_unranked(
+// CHECK-SAME: [[VAL_34:%.*]]: tensor<*xf32>) -> tensor<*xi1> {
+// CHECK: [[VAL_35:%.*]] = "tf.IsFinite"([[VAL_34]]) : (tensor<*xf32>) -> tensor<*xi1>
+// CHECK: return [[VAL_35]] : tensor<*xi1>
+
+func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+// CHECK-LABEL: func @log(
+// CHECK-SAME: [[VAL_36:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: [[VAL_37:%.*]] = "tf.Log"([[VAL_36]]) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: return [[VAL_37]] : tensor<2xf32>
+
+func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @log_dynamic(
+// CHECK-SAME: [[VAL_38:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_39:%.*]] = "tf.Log"([[VAL_38]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_39]] : tensor<?xf32>
+
+func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @log_unranked(
+// CHECK-SAME: [[VAL_40:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_41:%.*]] = "tf.Log"([[VAL_40]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_41]] : tensor<*xf32>
+
+func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+// CHECK-LABEL: func @log1p(
+// CHECK-SAME: [[VAL_42:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: [[VAL_43:%.*]] = "tf.Log1p"([[VAL_42]]) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: return [[VAL_43]] : tensor<2xf32>
+
+func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @log1p_dynamic(
+// CHECK-SAME: [[VAL_44:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_45:%.*]] = "tf.Log1p"([[VAL_44]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_45]] : tensor<?xf32>
+
+func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @log1p_unranked(
+// CHECK-SAME: [[VAL_46:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_47:%.*]] = "tf.Log1p"([[VAL_46]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_47]] : tensor<*xf32>
+
+func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> {
+  %0 = "xla_hlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
+// CHECK-LABEL: func @not_op_unranked(
+// CHECK-SAME: [[VAL_48:%.*]]: tensor<*xi1>) -> tensor<*xi1> {
+// CHECK: [[VAL_49:%.*]] = "tf.LogicalNot"([[VAL_48]]) : (tensor<*xi1>) -> tensor<*xi1>
+// CHECK: return [[VAL_49]] : tensor<*xi1>
+
+func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+// CHECK-LABEL: func @neg(
+// CHECK-SAME: [[VAL_50:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: [[VAL_51:%.*]] = "tf.Neg"([[VAL_50]]) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: return [[VAL_51]] : tensor<2xf32>
+
+func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @neg_dynamic(
+// CHECK-SAME: [[VAL_52:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_53:%.*]] = "tf.Neg"([[VAL_52]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_53]] : tensor<?xf32>
+
+func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @neg_unranked(
+// CHECK-SAME: [[VAL_54:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_55:%.*]] = "tf.Neg"([[VAL_54]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_55]] : tensor<*xf32>
+
+func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+// CHECK-LABEL: func @sin(
+// CHECK-SAME: [[VAL_56:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: [[VAL_57:%.*]] = "tf.Sin"([[VAL_56]]) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: return [[VAL_57]] : tensor<2xf32>
+
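+// As with the ops above, the *_dynamic and *_unranked variants check that the
+// same lowering pattern also applies to tensors with unknown dimensions and to
+// unranked tensors.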
+func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @sin_dynamic(
+// CHECK-SAME: [[VAL_58:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_59:%.*]] = "tf.Sin"([[VAL_58]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_59]] : tensor<?xf32>
+
+func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @sin_unranked(
+// CHECK-SAME: [[VAL_60:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_61:%.*]] = "tf.Sin"([[VAL_60]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_61]] : tensor<*xf32>
+
+func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+// CHECK-LABEL: func @rsqrt(
+// CHECK-SAME: [[VAL_62:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: [[VAL_63:%.*]] = "tf.Rsqrt"([[VAL_62]]) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: return [[VAL_63]] : tensor<2xf32>
+
+func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @rsqrt_dynamic(
+// CHECK-SAME: [[VAL_64:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_65:%.*]] = "tf.Rsqrt"([[VAL_64]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_65]] : tensor<?xf32>
+
+func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @rsqrt_unranked(
+// CHECK-SAME: [[VAL_66:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_67:%.*]] = "tf.Rsqrt"([[VAL_66]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_67]] : tensor<*xf32>
+
+func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+// CHECK-LABEL: func @sqrt(
+// CHECK-SAME: [[VAL_68:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: [[VAL_69:%.*]] = "tf.Sqrt"([[VAL_68]]) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: return [[VAL_69]] : tensor<2xf32>
+
+func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @sqrt_dynamic(
+// CHECK-SAME: [[VAL_70:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_71:%.*]] = "tf.Sqrt"([[VAL_70]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_71]] : tensor<?xf32>
+
+func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @sqrt_unranked(
+// CHECK-SAME: [[VAL_72:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_73:%.*]] = "tf.Sqrt"([[VAL_72]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_73]] : tensor<*xf32>
+
+func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+// CHECK-LABEL: func @tanh(
+// CHECK-SAME: [[VAL_74:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: [[VAL_75:%.*]] = "tf.Tanh"([[VAL_74]]) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: return [[VAL_75]] : tensor<2xf32>
+
+func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @tanh_dynamic(
+// CHECK-SAME: [[VAL_76:%.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: [[VAL_77:%.*]] = "tf.Tanh"([[VAL_76]]) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: return [[VAL_77]] : tensor<?xf32>
+
+func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @tanh_unranked(
+// CHECK-SAME: [[VAL_78:%.*]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAL_79:%.*]] = "tf.Tanh"([[VAL_78]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAL_79]] : tensor<*xf32>
+
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
new file mode 100644
index 00000000000..281efe98d2e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
@@ -0,0 +1,70 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements logic for legalizing HLO to TensorFlow.
+
+#include <memory>
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
+#include "mlir/IR/MLIRContext.h" // TF:llvm-project
+#include "mlir/IR/Operation.h" // TF:llvm-project
+#include "mlir/IR/PatternMatch.h" // TF:llvm-project
+#include "mlir/Pass/Pass.h" // TF:llvm-project
+#include "mlir/Support/LogicalResult.h" // TF:llvm-project
+#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace TF {
+namespace {
+
+class LegalizeHloToTf : public FunctionPass<LegalizeHloToTf> {
+ public:
+  LegalizeHloToTf() = default;
+  LegalizeHloToTf(const LegalizeHloToTf &) {}
+
+  /// Performs the legalization to the TF dialect.
+  void runOnFunction() override;
+};
+
+#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"
+
+/// Performs the lowering from the HLO dialect to the TF dialect.
+void LegalizeHloToTf::runOnFunction() {
+  MLIRContext &context = getContext();
+
+  // Add legalization patterns to the list.
+  OwningRewritePatternList patterns;
+  populateWithGenerated(&context, &patterns);
+
+  ConversionTarget target(context);
+  target.addLegalDialect<TensorFlowDialect>();
+  target.addLegalOp();
+  if (failed(applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+static PassRegistration<LegalizeHloToTf> pass(
+    "tf-legalize-hlo", "Legalize from HLO to the TF dialect");
+
+} // end namespace
+
+std::unique_ptr> CreateLegalizeHloToTfPass() {
+  return std::make_unique<LegalizeHloToTf>();
+}
+
+} // end namespace TF
+} // end namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
new file mode 100644
index 00000000000..bc4dd24f498
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
@@ -0,0 +1,72 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is the legalization pattern definition file for HLO to TF.
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
+include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
+
+def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>;
+
+//===----------------------------------------------------------------------===//
+// Binary op patterns.
+//===----------------------------------------------------------------------===//
+
+class DirectBinaryPat<Op FromOp, Op ToOp>
+  : Pat<(FromOp $l, $r, $_), (ToOp $l, $r)>;
+
+foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
+                         [HLO_DivOp, TF_DivOp],
+                         [HLO_ShiftLeftOp, TF_LeftShiftOp],
+                         [HLO_MaxOp, TF_MaximumOp],
+                         [HLO_MinOp, TF_MinimumOp],
+                         [HLO_MulOp, TF_MulOp],
+                         [HLO_PowOp, TF_PowOp],
+                         [HLO_DivOp, TF_RealDivOp],
+                         [HLO_SubOp, TF_SubOp]] in
+  def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
+
+def LowerRightShiftSigned :
+  Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
+      [(SignedIntTensor $r)]>;
+
+def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r)>;
+
+//===----------------------------------------------------------------------===//
+// Unary op patterns.
+//===----------------------------------------------------------------------===//
+
+foreach Mapping = [
+                   [HLO_AbsOp, TF_AbsOp],
+                   [HLO_CeilOp, TF_CeilOp],
+                   [HLO_CosOp, TF_CosOp],
+                   [HLO_ExpOp, TF_ExpOp],
+                   [HLO_FloorOp, TF_FloorOp],
+                   [HLO_ImagOp, TF_ImagOp],
+                   [HLO_IsFiniteOp, TF_IsFiniteOp],
+                   [HLO_LogOp, TF_LogOp],
+                   [HLO_Log1pOp, TF_Log1pOp],
+                   [HLO_NotOp, TF_LogicalNotOp],
+                   [HLO_NegOp, TF_NegOp],
+                   [HLO_RealOp, TF_RealOp],
+                   [HLO_RsqrtOp, TF_RsqrtOp],
+                   [HLO_SinOp, TF_SinOp],
+                   [HLO_SqrtOp, TF_SqrtOp],
+                   [HLO_TanhOp, TF_TanhOp],
+                  ] in {
+  def : Pat<(Mapping[0] $input), (Mapping[1] $input)>;
+}

From c3c2104933f9f2db69b443bdcec7ad9fb68505f8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 17 Mar 2020 09:47:08 -0700
Subject: [PATCH 071/492] Storing kRead kinds doesn't have any effect, but might take a lot of RAM.

PiperOrigin-RevId: 301392884
Change-Id: I3b866f93546b16559f0a5551bbe525138844744c
---
 .../compiler/jit/resource_operation_safety_analysis.cc | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
index fc2f69e2ad3..0cdd6474177 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -299,7 +299,11 @@ Status ComputeIncompatibleResourceOperationPairs(
      result->push_back({incoming_op.first, n->id()});
    }

-    resource_op_set->Add({n->id(), *op_kind});
+    // Some graphs might have a lot of 'kRead' kinds, but they are always safe
+    // for incoming ops, so not storing them might save a lot of memory.
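+    // For example, a graph containing thousands of ReadVariableOps would
+    // otherwise record every read here, growing resource_op_set without ever
+    // contributing an incompatible pair to the result.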
+ if (op_kind != XlaResourceOpKind::kRead) { + resource_op_set->Add({n->id(), *op_kind}); + } } if (vlog) { From 9f0ff44f9fe5be0460bb1e2a31799ada815963a0 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 17 Mar 2020 09:53:03 -0700 Subject: [PATCH 072/492] Update tests under keras.utils to use combinations. Change all test_util.run_all_in_graph_and_eager_modes to combination. PiperOrigin-RevId: 301393990 Change-Id: I7084404a9a256a11804bd474d1383f9c36de7305 --- tensorflow/python/keras/tests/BUILD | 1 + tensorflow/python/keras/tests/model_subclassing_test.py | 8 +++++--- tensorflow/python/keras/utils/BUILD | 2 ++ tensorflow/python/keras/utils/metrics_utils_test.py | 7 ++++--- tensorflow/python/keras/utils/tf_utils_test.py | 8 +++++--- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD index fe339380d01..94f5624bd4e 100644 --- a/tensorflow/python/keras/tests/BUILD +++ b/tensorflow/python/keras/tests/BUILD @@ -102,6 +102,7 @@ tf_py_test( ":model_subclassing_test_util", "//tensorflow/python:client_testlib", "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/keras/tests/model_subclassing_test.py b/tensorflow/python/keras/tests/model_subclassing_test.py index e903bd89717..761f720cea5 100644 --- a/tensorflow/python/keras/tests/model_subclassing_test.py +++ b/tensorflow/python/keras/tests/model_subclassing_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import copy import os +from absl.testing import parameterized import numpy as np from tensorflow.python import keras @@ -29,6 +30,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util +from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.tests import model_subclassing_test_util as model_util @@ -606,8 +608,8 @@ class GraphSpecificModelSubclassingTests(test.TestCase): _ = model.evaluate([x1, x2], [y1, y2], verbose=0) -@test_util.run_all_in_graph_and_eager_modes -class CustomCallSignatureTests(test.TestCase): +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) +class CustomCallSignatureTests(test.TestCase, parameterized.TestCase): def test_no_inputs_in_signature(self): model = model_util.CustomCallModel() @@ -669,7 +671,7 @@ class CustomCallSignatureTests(test.TestCase): arg = array_ops.ones([1]) model(arg, a=3) if not context.executing_eagerly(): - self.assertEqual(len(model.inputs), 1) + self.assertLen(model.inputs, 1) @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index 5056efbd021..681dec5932e 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -244,6 +244,7 @@ tf_py_test( ":tf_utils", "//tensorflow/python:client_testlib", "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", ], ) @@ -370,6 +371,7 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python/eager:context", "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", 
"@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/keras/utils/metrics_utils_test.py b/tensorflow/python/keras/utils/metrics_utils_test.py index a27f2b6af26..38467a63c1a 100644 --- a/tensorflow/python/keras/utils/metrics_utils_test.py +++ b/tensorflow/python/keras/utils/metrics_utils_test.py @@ -23,6 +23,7 @@ from absl.testing import parameterized from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.keras import combinations from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.ops import script_ops from tensorflow.python.ops.ragged import ragged_factory_ops @@ -30,7 +31,7 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import googletest -@test_util.run_all_in_graph_and_eager_modes +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): @parameterized.parameters([ @@ -249,8 +250,8 @@ class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y]) -@test_util.run_all_in_graph_and_eager_modes -class FilterTopKTest(test_util.TensorFlowTestCase): +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) +class FilterTopKTest(test_util.TensorFlowTestCase, parameterized.TestCase): def test_one_dimensional(self): x = constant_op.constant([.3, .1, .2, -.5, 42.]) diff --git a/tensorflow/python/keras/utils/tf_utils_test.py b/tensorflow/python/keras/utils/tf_utils_test.py index 2f87af2ef06..afbc0ba23a2 100644 --- a/tensorflow/python/keras/utils/tf_utils_test.py +++ b/tensorflow/python/keras/utils/tf_utils_test.py @@ -18,18 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util +from tensorflow.python.keras import combinations from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import variables from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class TestIsSymbolicTensor(test.TestCase): +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) +class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase): def test_default_behavior(self): if context.executing_eagerly(): From 8d01f78f829fe26447dcbf4a3d741e8f8329f969 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 10:03:20 -0700 Subject: [PATCH 073/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301396010 Change-Id: I607103f55e1b09ccdb1351d8234f889dd7eee38c --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 6456f104ad3..52a9bf9551b 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. 
Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. 
Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From f7896058b2c332bef81ed5860567b71c4f2ce10e Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 17 Mar 2020 10:06:33 -0700 Subject: [PATCH 074/492] Update tests under keras.saving to use combinations. Change all test_util.run_all_in_graph_and_eager_modes to combination. PiperOrigin-RevId: 301396996 Change-Id: I1d79695f819bb289a428b3fd97965841a873bda9 --- tensorflow/python/keras/saving/BUILD | 4 + .../python/keras/saving/hdf5_format_test.py | 176 +++++++++--------- tensorflow/python/keras/saving/save_test.py | 38 ++-- .../saving/saved_model/saved_model_test.py | 46 ++--- .../python/keras/saving/saving_utils_test.py | 11 +- 5 files changed, 144 insertions(+), 131 deletions(-) diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index eda5776c7a6..eda4df9b742 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -92,6 +92,7 @@ tf_py_test( deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -106,6 +107,7 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python/feature_column:feature_column_v2", "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -142,6 +144,7 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python/distribute:mirrored_strategy", "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -156,6 +159,7 @@ tf_py_test( deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/keras/saving/hdf5_format_test.py b/tensorflow/python/keras/saving/hdf5_format_test.py index b0e28640861..22ffcc5ed02 100644 --- a/tensorflow/python/keras/saving/hdf5_format_test.py +++ b/tensorflow/python/keras/saving/hdf5_format_test.py @@ -30,7 +30,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util +from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import optimizers from tensorflow.python.keras import testing_utils @@ -51,10 +51,10 @@ except ImportError: h5py = None +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): @keras_parameterized.run_with_all_saved_model_formats - @test_util.run_in_graph_and_eager_modes def test_weight_loading(self): temp_dir = self.get_temp_dir() self.addCleanup(shutil.rmtree, temp_dir) @@ -83,7 +83,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): y = model.predict(x) self.assertAllClose(ref_y, y) - @test_util.run_in_graph_and_eager_modes def test_weight_preprocessing(self): input_dim = 3 output_dim = 3 @@ -210,7 +209,6 @@ class TestWeightSavingAndLoading(test.TestCase, 
parameterized.TestCase): for (x, y) in zip(weights1, weights2) ] - @test_util.run_in_graph_and_eager_modes def test_sequential_weight_loading(self): if h5py is None: return @@ -243,7 +241,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): self.assertAllClose(y, ref_y) @keras_parameterized.run_with_all_saved_model_formats - @test_util.run_in_graph_and_eager_modes def test_nested_model_weight_loading(self): save_format = testing_utils.get_save_format() temp_dir = self.get_temp_dir() @@ -282,7 +279,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): self.assertAllClose(y, ref_y) - @test_util.run_in_graph_and_eager_modes def test_sequential_weight_loading_group_name_with_incorrect_length(self): if h5py is None: return @@ -314,16 +310,16 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): model.compile(loss=keras.losses.MSE, optimizer='rmsprop', metrics=[keras.metrics.categorical_accuracy]) - with self.assertRaisesRegexp(ValueError, - r'Layer #0 \(named \"d1\"\) expects 1 ' - r'weight\(s\), but the saved weights have 2 ' - r'element\(s\)\.'): - hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named \"d1\"\) expects 1 ' + r'weight\(s\), but the saved weights have 2 ' + r'element\(s\)\.'): + hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers) - hdf5_format.load_weights_from_hdf5_group_by_name( - f_model, model.layers, skip_mismatch=True) - self.assertAllClose(keras.backend.get_value(ref_model.layers[1].kernel), - keras.backend.get_value(model.layers[1].kernel)) + hdf5_format.load_weights_from_hdf5_group_by_name( + f_model, model.layers, skip_mismatch=True) + self.assertAllClose(keras.backend.get_value(ref_model.layers[1].kernel), + keras.backend.get_value(model.layers[1].kernel)) def test_sequential_weight_loading_group_name_with_incorrect_shape(self): if h5py is None: @@ -779,7 +775,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase): self.assertRegexpMatches( h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$') - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_functional_model_with_custom_loss_and_metric(self): def _make_model(): inputs = keras.Input(shape=(4,)) @@ -818,7 +814,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase): evaluation_results['sparse_categorical_crossentropy'] + evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_save_uncompiled_model_with_optimizer(self): with self.cached_session() as session: saved_model_dir = self._save_model_dir() @@ -901,6 +897,7 @@ class _make_subclassed_built(_make_subclassed): # pylint: disable=invalid-name self.build((None, input_size)) +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase): """Tests saving a whole model that contains other models.""" @@ -913,7 +910,6 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase): ('subclassed', _make_subclassed), ('subclassed_built', _make_subclassed_built), ]) - @test_util.run_in_graph_and_eager_modes def test_functional(self, model_fn): """Tests serializing a model that uses a nested model to share weights.""" if h5py is None: @@ -926,22 +922,23 @@ class 
TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase): outputs = keras.layers.add([base_model(inputs[0]), base_model(inputs[1])]) return keras.Model(inputs=inputs, outputs=outputs) - x = (np.random.normal(size=(16, 4)).astype(np.float32), - np.random.normal(size=(16, 4)).astype(np.float32)) - model = _make_model() - predictions = model(x) - # Save and reload. - model_path = os.path.join(self.get_temp_dir(), 'model.h5') - model.save(model_path) - del model - loaded_model = keras.models.load_model( - model_path, - custom_objects={ - '_make_subclassed': _make_subclassed, - '_make_subclassed_built': _make_subclassed_built, - }, - compile=False) - self.assertAllClose(loaded_model(x), predictions, 1e-9) + with self.cached_session(): + x = (np.random.normal(size=(16, 4)).astype(np.float32), + np.random.normal(size=(16, 4)).astype(np.float32)) + model = _make_model() + predictions = model(x) + # Save and reload. + model_path = os.path.join(self.get_temp_dir(), 'model.h5') + model.save(model_path) + del model + loaded_model = keras.models.load_model( + model_path, + custom_objects={ + '_make_subclassed': _make_subclassed, + '_make_subclassed_built': _make_subclassed_built, + }, + compile=False) + self.assertAllClose(loaded_model(x), predictions, 1e-9) class SubclassedModel(training.Model): @@ -955,7 +952,7 @@ class SubclassedModel(training.Model): return self.b_layer(self.x_layer(a)) -class TestWeightSavingAndLoadingTFFormat(test.TestCase): +class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase): def test_keras_optimizer_warning(self): graph = ops.Graph() @@ -974,7 +971,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): str(mock_log.call_args), 'Keras optimizer') - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_tensorflow_format_overwrite(self): with self.cached_session() as session: model = SubclassedModel() @@ -1025,12 +1022,12 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): model.save_weights(prefix, save_format='tensorflow') op_count = len(graph.get_operations()) model.save_weights(prefix, save_format='tensorflow') - self.assertEqual(len(graph.get_operations()), op_count) + self.assertLen(graph.get_operations(), op_count) model.load_weights(prefix) op_count = len(graph.get_operations()) model.load_weights(prefix) - self.assertEqual(len(graph.get_operations()), op_count) + self.assertLen(graph.get_operations(), op_count) def _weight_loading_test_template(self, make_model_fn): with self.cached_session(): @@ -1079,7 +1076,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): load_model.train_on_batch(train_x, train_y) self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x))) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_weight_loading_graph_model(self): def _make_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -1089,7 +1086,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self._weight_loading_test_template(_make_graph_model) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_weight_loading_subclassed_model(self): self._weight_loading_test_template(SubclassedModel) @@ -1127,7 +1124,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): y = self.evaluate(model(x)) self.assertAllClose(ref_y, y) - @test_util.run_in_graph_and_eager_modes + 
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_weight_loading_graph_model_added_layer(self): def _save_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -1144,7 +1141,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self._new_layer_weight_loading_test_template( _save_graph_model, _restore_graph_model) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_weight_loading_graph_model_added_no_weight_layer(self): def _save_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -1161,7 +1158,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self._new_layer_weight_loading_test_template( _save_graph_model, _restore_graph_model) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_weight_loading_subclassed_model_added_layer(self): class SubclassedModelRestore(training.Model): @@ -1178,7 +1175,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self._new_layer_weight_loading_test_template( SubclassedModel, SubclassedModelRestore) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_incompatible_checkpoint(self): save_path = trackable.Checkpoint().save( os.path.join(self.get_temp_dir(), 'ckpt')) @@ -1191,57 +1188,62 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): AssertionError, 'Nothing except the root object matched'): m.load_weights(save_path) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_directory_passed(self): - m = keras.Model() - v = m.add_weight(name='v', shape=[]) - self.evaluate(v.assign(42.)) - prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'ckpt/') - m.save_weights(prefix) - self.evaluate(v.assign(2.)) - m.load_weights(prefix) - self.assertEqual(42., self.evaluate(v)) + with self.cached_session(): + m = keras.Model() + v = m.add_weight(name='v', shape=[]) + self.evaluate(v.assign(42.)) + prefix = os.path.join(self.get_temp_dir(), + '{}'.format(ops.uid()), 'ckpt/') + m.save_weights(prefix) + self.evaluate(v.assign(2.)) + m.load_weights(prefix) + self.assertEqual(42., self.evaluate(v)) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_relative_path(self): - m = keras.Model() - v = m.add_weight(name='v', shape=[]) - os.chdir(self.get_temp_dir()) + with self.cached_session(): + m = keras.Model() + v = m.add_weight(name='v', shape=[]) + os.chdir(self.get_temp_dir()) - prefix = 'ackpt' - self.evaluate(v.assign(42.)) - m.save_weights(prefix) - self.assertTrue(file_io.file_exists('ackpt.index')) - self.evaluate(v.assign(1.)) - m.load_weights(prefix) - self.assertEqual(42., self.evaluate(v)) + prefix = 'ackpt' + self.evaluate(v.assign(42.)) + m.save_weights(prefix) + self.assertTrue(file_io.file_exists('ackpt.index')) + self.evaluate(v.assign(1.)) + m.load_weights(prefix) + self.assertEqual(42., self.evaluate(v)) - prefix = 'subdir/ackpt' - self.evaluate(v.assign(43.)) - m.save_weights(prefix) - self.assertTrue(file_io.file_exists('subdir/ackpt.index')) - self.evaluate(v.assign(2.)) - m.load_weights(prefix) - self.assertEqual(43., self.evaluate(v)) + prefix = 'subdir/ackpt' + self.evaluate(v.assign(43.)) + m.save_weights(prefix) + self.assertTrue(file_io.file_exists('subdir/ackpt.index')) + self.evaluate(v.assign(2.)) + m.load_weights(prefix) + 
self.assertEqual(43., self.evaluate(v)) - prefix = 'ackpt/' - self.evaluate(v.assign(44.)) - m.save_weights(prefix) - self.assertTrue(file_io.file_exists('ackpt/.index')) - self.evaluate(v.assign(3.)) - m.load_weights(prefix) - self.assertEqual(44., self.evaluate(v)) + prefix = 'ackpt/' + self.evaluate(v.assign(44.)) + m.save_weights(prefix) + self.assertTrue(file_io.file_exists('ackpt/.index')) + self.evaluate(v.assign(3.)) + m.load_weights(prefix) + self.assertEqual(44., self.evaluate(v)) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_nonexistent_prefix_directory(self): - m = keras.Model() - v = m.add_weight(name='v', shape=[]) - self.evaluate(v.assign(42.)) - prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'bckpt') - m.save_weights(prefix) - self.evaluate(v.assign(2.)) - m.load_weights(prefix) - self.assertEqual(42., self.evaluate(v)) + with self.cached_session(): + m = keras.Model() + v = m.add_weight(name='v', shape=[]) + self.evaluate(v.assign(42.)) + prefix = os.path.join(self.get_temp_dir(), + '{}'.format(ops.uid()), 'bckpt') + m.save_weights(prefix) + self.evaluate(v.assign(2.)) + m.load_weights(prefix) + self.assertEqual(42., self.evaluate(v)) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index 965a1b88cc7..a2bc687c3b8 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import sys +from absl.testing import parameterized import numpy as np from tensorflow.python import keras @@ -28,6 +29,7 @@ from tensorflow.python.eager import context from tensorflow.python.feature_column import feature_column_lib from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.keras import combinations from tensorflow.python.keras import testing_utils from tensorflow.python.keras.saving import model_config from tensorflow.python.keras.saving import save @@ -43,7 +45,7 @@ except ImportError: h5py = None -class TestSaveModel(test.TestCase): +class TestSaveModel(test.TestCase, parameterized.TestCase): def setUp(self): super(TestSaveModel, self).setUp() @@ -99,7 +101,7 @@ class TestSaveModel(test.TestCase): save.save_model(self.model, path, save_format='tf') save.load_model(path) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_saving_with_dense_features(self): cols = [ feature_column_lib.numeric_column('a'), @@ -128,13 +130,14 @@ class TestSaveModel(test.TestCase): inputs_a = np.arange(10).reshape(10, 1) inputs_b = np.arange(10).reshape(10, 1).astype('str') - # Initialize tables for V1 lookup. - if not context.executing_eagerly(): - self.evaluate(lookup_ops.tables_initializer()) + with self.cached_session(): + # Initialize tables for V1 lookup. 
+ if not context.executing_eagerly(): + self.evaluate(lookup_ops.tables_initializer()) - self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10) + self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_saving_with_sequence_features(self): cols = [ feature_column_lib.sequence_numeric_column('a'), @@ -182,17 +185,18 @@ class TestSaveModel(test.TestCase): inputs_b = sparse_tensor.SparseTensor(indices_b, values_b, (batch_size, timesteps, 1)) - # Initialize tables for V1 lookup. - if not context.executing_eagerly(): - self.evaluate(lookup_ops.tables_initializer()) + with self.cached_session(): + # Initialize tables for V1 lookup. + if not context.executing_eagerly(): + self.evaluate(lookup_ops.tables_initializer()) - self.assertLen( - loaded_model.predict({ - 'a': inputs_a, - 'b': inputs_b - }, steps=1), batch_size) + self.assertLen( + loaded_model.predict({ + 'a': inputs_a, + 'b': inputs_b + }, steps=1), batch_size) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_saving_h5_for_rnn_layers(self): # See https://github.com/tensorflow/tensorflow/issues/35731 for details. inputs = keras.Input([10, 91], name='train_input') @@ -213,7 +217,7 @@ class TestSaveModel(test.TestCase): rnn_layers[1].kernel.name) self.assertIn('rnn_cell1', rnn_layers[1].kernel.name) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_saving_optimizer_weights(self): class MyModel(keras.Model): diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 76e87d91553..f56d55b18d5 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -44,7 +44,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec -from tensorflow.python.framework import test_util +from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import regularizers from tensorflow.python.keras import testing_utils @@ -700,7 +700,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): self.assertAllClose(model(input_arr), loaded(input_arr)) -class TestLayerCallTracing(test.TestCase): +class TestLayerCallTracing(test.TestCase, parameterized.TestCase): def test_functions_have_same_trace(self): @@ -773,7 +773,7 @@ class TestLayerCallTracing(test.TestCase): assert_num_traces(LayerWithChildLayer, training_keyword=False) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_maintains_losses(self): layer = LayerWithLoss() layer(np.ones((2, 3))) @@ -786,7 +786,7 @@ class TestLayerCallTracing(test.TestCase): self.assertAllEqual(previous_losses, layer.losses) -@test_util.run_all_in_graph_and_eager_modes +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) class MetricTest(test.TestCase, parameterized.TestCase): def _save_model_dir(self, dirname='saved_model'): @@ -870,28 +870,30 @@ class MetricTest(test.TestCase, parameterized.TestCase): # while returning nothing. 
super(CustomMetric, self).update_state(*args) - metric = CustomMetric() - save_dir = self._save_model_dir('first_save') + with self.cached_session(): + metric = CustomMetric() + save_dir = self._save_model_dir('first_save') - if requires_build: - metric(*self.generate_inputs(num_tensor_args)) # pylint: disable=not-callable + if requires_build: + metric(*self.generate_inputs(num_tensor_args)) # pylint: disable=not-callable - self.evaluate([v.initializer for v in metric.variables]) + self.evaluate([v.initializer for v in metric.variables]) - with self.assertRaisesRegexp(ValueError, 'Unable to restore custom object'): - self._test_metric_save_and_load(metric, save_dir, num_tensor_args) - with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}): - loaded = self._test_metric_save_and_load( - metric, - save_dir, - num_tensor_args, - test_sample_weight=False) + with self.assertRaisesRegexp(ValueError, + 'Unable to restore custom object'): + self._test_metric_save_and_load(metric, save_dir, num_tensor_args) + with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}): + loaded = self._test_metric_save_and_load( + metric, + save_dir, + num_tensor_args, + test_sample_weight=False) - self._test_metric_save_and_load( - loaded, - self._save_model_dir('second_save'), - num_tensor_args, - test_sample_weight=False) + self._test_metric_save_and_load( + loaded, + self._save_model_dir('second_save'), + num_tensor_args, + test_sample_weight=False) def test_custom_metric_wrapped_call(self): diff --git a/tensorflow/python/keras/saving/saving_utils_test.py b/tensorflow/python/keras/saving/saving_utils_test.py index a9df7502412..bc0ea6edf11 100644 --- a/tensorflow/python/keras/saving/saving_utils_test.py +++ b/tensorflow/python/keras/saving/saving_utils_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.keras import backend as K +from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import sequential @@ -62,7 +63,7 @@ class TraceModelCallTest(keras_parameterized.TestCase): self.assertAllClose(expected, actual) @keras_parameterized.run_with_all_model_types - @test_util.run_in_graph_and_eager_modes + @keras_parameterized.run_all_keras_modes def test_trace_model_outputs(self): input_dim = 5 if testing_utils.get_model_type() == 'functional' else None model = testing_utils.get_small_mlp(10, 3, input_dim) @@ -155,7 +156,7 @@ class TraceModelCallTest(keras_parameterized.TestCase): expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]} self._assert_all_close(expected_outputs, signature_outputs) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_trace_features_layer(self): columns = [feature_column_lib.numeric_column('x')] model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)]) @@ -176,7 +177,7 @@ class TraceModelCallTest(keras_parameterized.TestCase): self.assertAllClose({'output_1': [[1., 2.]]}, fn({'x': [[1.]], 'y': [[2.]]})) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_specify_input_signature(self): model = testing_utils.get_small_sequential_mlp(10, 3, None) inputs = array_ops.ones((8, 5)) @@ -193,7 +194,7 @@ class 
TraceModelCallTest(keras_parameterized.TestCase): expected_outputs = {'output_1': model(inputs)} self._assert_all_close(expected_outputs, signature_outputs) - @test_util.run_in_graph_and_eager_modes + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_subclassed_model_with_input_signature(self): class Model(keras.Model): @@ -218,7 +219,7 @@ class TraceModelCallTest(keras_parameterized.TestCase): self._assert_all_close(expected_outputs, signature_outputs) @keras_parameterized.run_with_all_model_types - @test_util.run_in_graph_and_eager_modes + @keras_parameterized.run_all_keras_modes def test_model_with_fixed_input_dim(self): """Ensure that the batch_dim is removed when saving. From 6094289d90e69533fae5964ea221e57a7a78570e Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Tue, 17 Mar 2020 10:15:52 -0700 Subject: [PATCH 075/492] [TF lite conversion logging] Implement a simple sanitizer to prune error message returned from MLIR. PiperOrigin-RevId: 301399261 Change-Id: Ie3d9fdc18c07e5cc780e4f3b63f1968e2136dfc3 --- .../mlir/lite/flatbuffer_translate.cc | 2 +- .../tests/mlir2flatbuffer/disable_custom.mlir | 14 +++++++++++ .../lite/toco/logging/conversion_log_util.cc | 21 ++++++++++++++++ .../lite/toco/logging/conversion_log_util.h | 4 +++ .../toco/logging/conversion_log_util_test.cc | 25 +++++++++++++++++++ .../lite/toco/python/toco_python_api.cc | 3 ++- 6 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 9e9330e2c96..a5831559546 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -1397,7 +1397,7 @@ Optional Translator::TranslateInternal() { err += "Ops that need custom implementation (enabled via setting the " "-emit-custom-ops flag): " + - failed_custom_ops_list; + failed_custom_ops_list + "."; auto& failed_region = named_regions[first_failed_func]; return failed_region.second->getParentOp()->emitError() diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir new file mode 100644 index 00000000000..046fe6ac9ef --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir @@ -0,0 +1,14 @@ +// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s + +// CHECK: error: 'tf.MyCustomOp' op is neither a custom op nor a flex op +// CHECK: error: failed while converting: 'main' +// CHECK: Ops that need custom implementation (enabled via setting the -emit-custom-ops flag): MyCustomOp. 
+ +func @main(tensor<4xf32>) -> tensor<4xf32> { +^bb0(%arg0: tensor<4xf32>): + %0 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = "tf.MyCustomOp"(%1, %0) {name = "MyCustomOp"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %3 = "tfl.exp"(%2) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32> + return %3 : tensor<4xf32> +} diff --git a/tensorflow/lite/toco/logging/conversion_log_util.cc b/tensorflow/lite/toco/logging/conversion_log_util.cc index 8c3b27de173..c23c305c750 100644 --- a/tensorflow/lite/toco/logging/conversion_log_util.cc +++ b/tensorflow/lite/toco/logging/conversion_log_util.cc @@ -202,6 +202,27 @@ string GetModelHash(const Model& model) { return ""; } +// This function scans through the error message string, extracts the part about +// missing ops and prunes away all other information in the error info. +string SanitizeErrorMessage(const string& error_message) { + const string s1 = "Ops that can be supported by the flex runtime"; + const string s2 = "Ops that need custom implementation"; + string pruned_message; + size_t pos = error_message.find(s1); + if (pos != string::npos) { + // Find the terminate point for flex op list. + auto end = error_message.find(".", pos); + pruned_message.append(error_message.substr(pos, end - pos + 1)); + } + pos = error_message.find(s2); + if (pos != string::npos) { + // Find the terminate point for custom op list. + auto end = error_message.find(".", pos); + pruned_message.append(error_message.substr(pos, end - pos + 1)); + } + return pruned_message; +} + void PopulateConversionLog(const Model& model, TocoConversionLog* log) { // Get the list of ops after conversion. const std::vector op_names = GetOperatorNames(model); diff --git a/tensorflow/lite/toco/logging/conversion_log_util.h b/tensorflow/lite/toco/logging/conversion_log_util.h index 9ed688085b6..2237615adbb 100644 --- a/tensorflow/lite/toco/logging/conversion_log_util.h +++ b/tensorflow/lite/toco/logging/conversion_log_util.h @@ -23,6 +23,10 @@ limitations under the License. namespace toco { +// This function scans through the error message string, extracts the part about +// missing ops and prunes away all other information in the error info. +string SanitizeErrorMessage(const string& error_message); + // Populates the TocoConversionLog proto after analyzing the model. void PopulateConversionLog(const Model& model, TocoConversionLog* log); diff --git a/tensorflow/lite/toco/logging/conversion_log_util_test.cc b/tensorflow/lite/toco/logging/conversion_log_util_test.cc index ac53471cec3..c4960715f25 100644 --- a/tensorflow/lite/toco/logging/conversion_log_util_test.cc +++ b/tensorflow/lite/toco/logging/conversion_log_util_test.cc @@ -224,5 +224,30 @@ TEST(ConversionLogUtilTest, TestGetOpSignatures) { "MyAwesomeCustomOp::VERSION:1")); } +TEST(ConversionLogUtilTest, TestSanitizeErrorMessage) { + const string error = + "error: failed while converting: 'main': Ops that can be supported by " + "the flex runtime (enabled via setting the -emit-select-tf-ops flag): " + "ResizeNearestNeighbor,ResizeNearestNeighbor. 
Ops that need custom " + "implementation (enabled via setting the -emit-custom-ops flag): " + "CombinedNonMaxSuppression.\nTraceback (most recent call last): File " + "/usr/local/bin/toco_from_protos, line 8, in "; + const string pruned_error = + "Ops that can be supported by " + "the flex runtime (enabled via setting the -emit-select-tf-ops flag): " + "ResizeNearestNeighbor,ResizeNearestNeighbor.Ops that need custom " + "implementation (enabled via setting the -emit-custom-ops flag): " + "CombinedNonMaxSuppression."; + EXPECT_EQ(SanitizeErrorMessage(error), pruned_error); +} + +TEST(ConversionLogUtilTest, TestSanitizeErrorMessageNoMatching) { + const string error = + "error: failed while converting: 'main': Traceback (most recent call " + "last): File " + "/usr/local/bin/toco_from_protos, line 8, in "; + EXPECT_EQ(SanitizeErrorMessage(error), ""); +} + } // namespace } // namespace toco diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index 01000f590c1..31de4cfc726 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -67,7 +67,8 @@ void PopulateConversionLogHelper(const toco::ModelFlags& model_flags, // Dump post-conversion toco logs. TocoConversionLog toco_log_after; PopulateConversionLog(*flatbuffer_model, &toco_log_after); - toco_log_after.set_toco_err_logs(error_message); + // Make sure we sanitize the error message. + toco_log_after.set_toco_err_logs(SanitizeErrorMessage(error_message)); std::ofstream ostream_after(toco_flags->conversion_summary_dir() + "/toco_log_after.pb"); toco_log_after.SerializeToOstream(&ostream_after); From a05dc8341c63e18da36c192c859f6c493dd7efb1 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 17 Mar 2020 10:26:24 -0700 Subject: [PATCH 076/492] Export `timeseries.dataset_from_array` to the public API. PiperOrigin-RevId: 301401580 Change-Id: Ibd516e412b58ebf51f8c428924beca1f39484352 --- .../python/keras/preprocessing/timeseries.py | 66 +++++++++++++------ .../keras/preprocessing/timeseries_test.py | 66 ++++++++++--------- 2 files changed, 81 insertions(+), 51 deletions(-) diff --git a/tensorflow/python/keras/preprocessing/timeseries.py b/tensorflow/python/keras/preprocessing/timeseries.py index ca41f1952e3..7f14542fdac 100644 --- a/tensorflow/python/keras/preprocessing/timeseries.py +++ b/tensorflow/python/keras/preprocessing/timeseries.py @@ -23,42 +23,46 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import keras_export -def timeseries_dataset( +@keras_export('keras.preprocessing.timeseries.dataset_from_array', v1=[]) +def dataset_from_array( data, targets, sequence_length, - sampling_rate=1, sequence_stride=1, + sampling_rate=1, batch_size=128, shuffle=False, seed=None, start_index=None, end_index=None): - """Utility function for generating batches of temporal data. + """Creates a dataset of sliding windows over a timeseries provided as array. This function takes in a sequence of data-points gathered at equal intervals, along with time series parameters such as - spacing between two sequence, length of history, etc., to produce batches for - training/validation. + length of the sequences/windows, spacing between two sequence/windows, etc., + to produce batches of timeseries inputs and targets. 
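For a concrete feel of the windowing described above, here is a minimal sketch; it assumes the `tf.keras.preprocessing.timeseries.dataset_from_array` export added in this change, eager TF 2.x, and purely illustrative array sizes:

```python
import numpy as np
import tensorflow as tf

data = np.arange(20)  # 20 consecutive timesteps
ds = tf.keras.preprocessing.timeseries.dataset_from_array(
    data, targets=None, sequence_length=5, sequence_stride=3, batch_size=2)

for batch in ds.take(1):
  # Windows start at indices 0, 3, 6, ... so the first batch is expected to be
  # [[0 1 2 3 4]
  #  [3 4 5 6 7]]
  print(batch.numpy())
```

When `targets` is passed as well, each window starting at index `i` is paired with `targets[i]`, as Example 2 further down spells out.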
Arguments: - data: Indexable generator (such as a list or a Numpy array) + data: Numpy array or eager tensor containing consecutive data points (timesteps). Axis 0 is expected to be the time dimension. targets: Targets corresponding to timesteps in `data`. - It should have same length as `data`. + It should have same length as `data`. `targets[i]` should be the target + corresponding to the window that starts at index `i` + (see example 2 below). Pass None if you don't have target data (in this case the dataset will only yield the input data). sequence_length: Length of the output sequences (in number of timesteps). + sequence_stride: Period between successive output sequences. + For stride `s`, output samples would + start at index `data[i]`, `data[i + s]`, `data[i + 2 * s]`, etc. sampling_rate: Period between successive individual timesteps within sequences. For rate `r`, timesteps `data[i], data[i + r], ... data[i + sequence_length]` are used for create a sample sequence. - sequence_stride: Period between successive output sequences. - For stride `s`, output samples would - start at index `data[i]`, `data[i + s]`, `data[i + 2 * s]`, etc. batch_size: Number of timeseries samples in each batch (except maybe the last one). shuffle: Whether to shuffle output samples, @@ -73,11 +77,11 @@ def timeseries_dataset( This is useful to reserve part of the data for test or validation. Returns: - A tf.data.Dataset instance. If `targets` was pass, the dataset yields + A tf.data.Dataset instance. If `targets` was passed, the dataset yields tuple `(batch_of_sequences, batch_of_targets)`. If not, the dataset yields only `batch_of_sequences`. - Example: + Example 1: Consider indices `[0, 1, ... 99]`. With `sequence_length=10, sampling_rate=2, sequence_stride=3`, `shuffle=False`, the dataset will yield batches of sequences @@ -94,6 +98,22 @@ def timeseries_dataset( In this case the last 3 data points are discarded since no full sequence can be generated to include them (the next sequence would have started at index 81, and thus its last step would have gone over 99). + + Example 2: temporal regression. Consider an array `data` of scalar + values, of shape `(steps,)`. 
To generate a dataset that uses the past 10 + timesteps to predict the next timestep, you would use: + + ```python + input_data = data[:-10] + targets = data[10:] + dataset = tf.keras.preprocessing.timeseries.dataset_from_array( + input_data, targets, sequence_length=10) + for batch in dataset: + inputs, targets = batch + assert np.array_equal(inputs[0], data[:10]) # First sequence: steps [0-9] + assert np.array_equal(targets[0], data[10]) # Corresponding target: step 10 + break + ``` """ # Validate the shape of data and targets if targets is not None and len(targets) != len(data): @@ -152,19 +172,25 @@ def timeseries_dataset( sequence_length = math_ops.cast(sequence_length, dtype=index_dtype) sampling_rate = math_ops.cast(sampling_rate, dtype=index_dtype) + positions_ds = dataset_ops.Dataset.from_tensors(start_positions).repeat() + # For each initial window position, generates indices of the window elements indices = dataset_ops.Dataset.zip( - (dataset_ops.Dataset.range(len(start_positions)), - dataset_ops.Dataset.from_tensors(start_positions).repeat())).map( - lambda i, positions: math_ops.range( # pylint: disable=g-long-lambda - positions[i], - positions[i] + sequence_length * sampling_rate, - sampling_rate), - num_parallel_calls=dataset_ops.AUTOTUNE) + (dataset_ops.Dataset.range(len(start_positions)), positions_ds)).map( + lambda i, positions: math_ops.range( # pylint: disable=g-long-lambda + positions[i], + positions[i] + sequence_length * sampling_rate, + sampling_rate), + num_parallel_calls=dataset_ops.AUTOTUNE) dataset = sequences_from_indices(data, indices, start_index, end_index) if targets is not None: - target_ds = sequences_from_indices(targets, indices, start_index, end_index) + indices = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(len(start_positions)), positions_ds)).map( + lambda i, positions: positions[i], + num_parallel_calls=dataset_ops.AUTOTUNE) + target_ds = sequences_from_indices( + targets, indices, start_index, end_index) dataset = dataset_ops.Dataset.zip((dataset, target_ds)) if shuffle: # Shuffle locally at each iteration diff --git a/tensorflow/python/keras/preprocessing/timeseries_test.py b/tensorflow/python/keras/preprocessing/timeseries_test.py index ab1640191bf..4dcb0277ee0 100644 --- a/tensorflow/python/keras/preprocessing/timeseries_test.py +++ b/tensorflow/python/keras/preprocessing/timeseries_test.py @@ -31,27 +31,29 @@ class TimeseriesDatasetTest(test.TestCase): # Test ordering, targets, sequence length, batch size data = np.arange(100) targets = data * 2 - dataset = timeseries.timeseries_dataset( + dataset = timeseries.dataset_from_array( data, targets, sequence_length=9, batch_size=5) # Expect 19 batches for i, batch in enumerate(dataset): self.assertLen(batch, 2) + inputs, targets = batch if i < 18: - self.assertEqual(batch[0].shape, (5, 9)) + self.assertEqual(inputs.shape, (5, 9)) if i == 18: # Last batch: size 2 - self.assertEqual(batch[0].shape, (2, 9)) + self.assertEqual(inputs.shape, (2, 9)) # Check target values - self.assertAllClose(batch[0] * 2, batch[1]) - for j in range(min(5, len(batch[0]))): + self.assertAllClose(targets, inputs[:, 0] * 2) + for j in range(min(5, len(inputs))): # Check each sample in the batch - self.assertAllClose(batch[0][j], np.arange(i * 5 + j, i * 5 + j + 9)) + self.assertAllClose(inputs[j], np.arange(i * 5 + j, i * 5 + j + 9)) def test_no_targets(self): data = np.arange(50) - dataset = timeseries.timeseries_dataset( + dataset = timeseries.dataset_from_array( data, None, sequence_length=10, batch_size=5) # 
Expect 9 batches + i = None for i, batch in enumerate(dataset): if i < 8: self.assertEqual(batch.shape, (5, 10)) @@ -60,23 +62,24 @@ class TimeseriesDatasetTest(test.TestCase): for j in range(min(5, len(batch))): # Check each sample in the batch self.assertAllClose(batch[j], np.arange(i * 5 + j, i * 5 + j + 10)) + self.assertEqual(i, 8) def test_shuffle(self): # Test cross-epoch random order and seed determinism data = np.arange(10) targets = data * 2 - dataset = timeseries.timeseries_dataset( + dataset = timeseries.dataset_from_array( data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123) first_seq = None for x, y in dataset.take(1): self.assertNotAllClose(x, np.arange(0, 5)) - self.assertAllClose(x * 2, y) + self.assertAllClose(x[:, 0] * 2, y) first_seq = x # Check that a new iteration with the same dataset yields different results for x, _ in dataset.take(1): self.assertNotAllClose(x, first_seq) # Check determism with same seed - dataset = timeseries.timeseries_dataset( + dataset = timeseries.dataset_from_array( data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123) for x, _ in dataset.take(1): self.assertAllClose(x, first_seq) @@ -84,48 +87,49 @@ class TimeseriesDatasetTest(test.TestCase): def test_sampling_rate(self): data = np.arange(100) targets = data * 2 - dataset = timeseries.timeseries_dataset( + dataset = timeseries.dataset_from_array( data, targets, sequence_length=9, batch_size=5, sampling_rate=2) for i, batch in enumerate(dataset): self.assertLen(batch, 2) + inputs, targets = batch if i < 16: - self.assertEqual(batch[0].shape, (5, 9)) + self.assertEqual(inputs.shape, (5, 9)) if i == 16: # Last batch: size 3 - self.assertEqual(batch[0].shape, (3, 9)) + self.assertEqual(inputs.shape, (3, 9)) # Check target values - self.assertAllClose(batch[0] * 2, batch[1]) - for j in range(min(5, len(batch[0]))): + self.assertAllClose(inputs[:, 0] * 2, targets) + for j in range(min(5, len(inputs))): # Check each sample in the batch start_index = i * 5 + j end_index = start_index + 9 * 2 - self.assertAllClose(batch[0][j], - np.arange(start_index, end_index, 2)) + self.assertAllClose(inputs[j], np.arange(start_index, end_index, 2)) def test_sequence_stride(self): data = np.arange(100) targets = data * 2 - dataset = timeseries.timeseries_dataset( + dataset = timeseries.dataset_from_array( data, targets, sequence_length=9, batch_size=5, sequence_stride=3) for i, batch in enumerate(dataset): self.assertLen(batch, 2) + inputs, targets = batch if i < 6: - self.assertEqual(batch[0].shape, (5, 9)) + self.assertEqual(inputs.shape, (5, 9)) if i == 6: # Last batch: size 1 - self.assertEqual(batch[0].shape, (1, 9)) + self.assertEqual(inputs.shape, (1, 9)) # Check target values - self.assertAllClose(batch[0] * 2, batch[1]) - for j in range(min(5, len(batch[0]))): + self.assertAllClose(inputs[:, 0] * 2, targets) + for j in range(min(5, len(inputs))): # Check each sample in the batch start_index = i * 5 * 3 + j * 3 end_index = start_index + 9 - self.assertAllClose(batch[0][j], + self.assertAllClose(inputs[j], np.arange(start_index, end_index)) def test_start_and_end_index(self): data = np.arange(100) - dataset = timeseries.timeseries_dataset( + dataset = timeseries.dataset_from_array( data, None, sequence_length=9, batch_size=5, sequence_stride=3, sampling_rate=2, start_index=10, end_index=90) @@ -137,23 +141,23 @@ class TimeseriesDatasetTest(test.TestCase): # bad targets with self.assertRaisesRegex(ValueError, 'data and targets to have the same number'): - _ = 
timeseries.timeseries_dataset(np.arange(10), np.arange(9), 3) + _ = timeseries.dataset_from_array(np.arange(10), np.arange(9), 3) # bad start index with self.assertRaisesRegex(ValueError, 'start_index must be '): - _ = timeseries.timeseries_dataset(np.arange(10), None, 3, start_index=-1) + _ = timeseries.dataset_from_array(np.arange(10), None, 3, start_index=-1) with self.assertRaisesRegex(ValueError, 'start_index must be '): - _ = timeseries.timeseries_dataset(np.arange(10), None, 3, start_index=11) + _ = timeseries.dataset_from_array(np.arange(10), None, 3, start_index=11) # bad end index with self.assertRaisesRegex(ValueError, 'end_index must be '): - _ = timeseries.timeseries_dataset(np.arange(10), None, 3, end_index=-1) + _ = timeseries.dataset_from_array(np.arange(10), None, 3, end_index=-1) with self.assertRaisesRegex(ValueError, 'end_index must be '): - _ = timeseries.timeseries_dataset(np.arange(10), None, 3, end_index=11) + _ = timeseries.dataset_from_array(np.arange(10), None, 3, end_index=11) # bad sampling_rate with self.assertRaisesRegex(ValueError, 'sampling_rate must be '): - _ = timeseries.timeseries_dataset(np.arange(10), None, 3, sampling_rate=0) + _ = timeseries.dataset_from_array(np.arange(10), None, 3, sampling_rate=0) # bad sequence stride with self.assertRaisesRegex(ValueError, 'sequence_stride must be '): - _ = timeseries.timeseries_dataset( + _ = timeseries.dataset_from_array( np.arange(10), None, 3, sequence_stride=0) From ba5e03c88b82ad0a7a0cefb3b50cb4d71656b636 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Tue, 17 Mar 2020 10:27:30 -0700 Subject: [PATCH 077/492] Create Variables to track mini-batches seen in Model.fit / evaluate / predict. Use these counters in the TensorBoard Callback. PiperOrigin-RevId: 301401826 Change-Id: I2975eb5ab24bc5c32539817337238d1a7b0c2258 --- tensorflow/python/keras/callbacks.py | 448 +++++++----------- tensorflow/python/keras/callbacks_test.py | 6 +- tensorflow/python/keras/callbacks_v1.py | 29 +- tensorflow/python/keras/engine/training.py | 63 ++- tensorflow/python/keras/engine/training_v1.py | 3 + .../keras/tests/model_subclassing_test.py | 15 + .../python/keras/utils/version_utils.py | 22 + ...orflow.keras.callbacks.-tensor-board.pbtxt | 2 + ...orflow.keras.callbacks.-tensor-board.pbtxt | 1 + 9 files changed, 305 insertions(+), 284 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index bb9e61d01a2..9177d89c67b 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -35,21 +35,19 @@ import six from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import distributed_file_utils from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras.distribute import multi_worker_training_state as training_state from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from 
tensorflow.python.ops import summary_ops_v2 -from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.training import checkpoint_management @@ -1614,7 +1612,7 @@ class LearningRateScheduler(Callback): @keras_export('keras.callbacks.TensorBoard', v1=[]) -class TensorBoard(Callback): +class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): # pylint: disable=line-too-long """Enable visualizations for TensorBoard. @@ -1676,11 +1674,10 @@ class TensorBoard(Callback): batches. Note that writing too frequently to TensorBoard can slow down your training. profile_batch: Profile the batch(es) to sample compute characteristics. - profile_batch must be a non-negative integer or a comma separated string - of pair of positive integers. A pair of positive integers signify a - range of batches to profile. By default, it will profile the second - batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow - eager mode. + profile_batch must be a non-negative integer or a tuple of integers. + A pair of positive integers signify a range of batches to profile. + By default, it will profile the second batch. Set profile_batch=0 + to disable profiling. Must run in TensorFlow eager mode. embeddings_freq: frequency (in epochs) at which embedding layers will be visualized. If set to 0, embeddings won't be visualized. embeddings_metadata: a dictionary which maps layer name to a file name in @@ -1713,30 +1710,18 @@ class TensorBoard(Callback): self.histogram_freq = histogram_freq self.write_graph = write_graph self.write_images = write_images - if update_freq == 'batch': - self.update_freq = 1 - else: - self.update_freq = update_freq + self.update_freq = 1 if update_freq == 'batch' else update_freq self.embeddings_freq = embeddings_freq self.embeddings_metadata = embeddings_metadata + self._init_profile_batch(profile_batch) + self._epoch = 0 - self._samples_seen = 0 - self._samples_seen_at_last_write = 0 - self._current_batch = 0 - - # A collection of file writers currently in use, to be closed when - # training ends for this callback. Writers are keyed by the - # directory name under the root logdir: e.g., "train" or - # "validation". - self._train_run_name = 'train' - self._validation_run_name = 'validation' + # Lazily initialized in order to avoid creating event files when + # not needed. self._writers = {} - self._start_batch, self._stop_batch = self._init_profile_batch( - profile_batch) - if self._start_batch > 0: - profiler.warmup() # Improve the profiling accuracy. - # True when a trace is running. - self._is_tracing = False + + # Used to restore any existing `SummaryWriter` after training ends. + self._prev_summary_state = [] def _validate_kwargs(self, kwargs): """Handle arguments were supported in V1.""" @@ -1768,37 +1753,56 @@ class TensorBoard(Callback): def set_model(self, model): """Sets Keras model and writes graph if specified.""" self.model = model + self._log_write_dir = self._get_log_write_dir() - # In case this callback is used via native Keras, _get_distribution_strategy does not exist. - if hasattr(self.model, '_get_distribution_strategy'): - # TensorBoard callback involves writing a summary file in a - # possibly distributed settings. 
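As a rough usage sketch of the reworked callback (a hedged example; the log directory, data sizes, and `update_freq` value are placeholders):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile('sgd', 'mse')

tb = tf.keras.callbacks.TensorBoard(log_dir='/tmp/logs', update_freq=10)
model.fit(np.zeros((64, 4)), np.zeros((64, 1)),
          validation_data=(np.zeros((8, 4)), np.zeros((8, 1))),
          epochs=2, callbacks=[tb], verbose=0)

# Epoch-level scalars land as 'epoch_<metric>' under /tmp/logs/train and
# /tmp/logs/validation (the 'val_' prefix is stripped). With an integer
# update_freq, per-batch scalars such as 'batch_loss' are recorded every 10
# mini-batches through the writer pushed in on_train_begin.
```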
- self._log_write_dir = distributed_file_utils.write_dirpath( - self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access - else: - self._log_write_dir = self.log_dir + self._train_dir = os.path.join(self._log_write_dir, 'train') + self._train_step = self.model._train_counter # pylint: disable=protected-access - with context.eager_mode(): - self._close_writers() - if self.write_graph: - with self._get_writer(self._train_run_name).as_default(): - with summary_ops_v2.always_record_summaries(): - if not model.run_eagerly: - summary_ops_v2.graph(K.get_graph(), step=0) + self._val_dir = os.path.join(self._log_write_dir, 'validation') + self._val_step = self.model._test_counter # pylint: disable=protected-access - summary_writable = ( - self.model._is_graph_network or # pylint: disable=protected-access - self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access - if summary_writable: - summary_ops_v2.keras_model('keras', self.model, step=0) + self._writers = {} # Resets writers. + if self.write_graph: + self._write_keras_model_graph() if self.embeddings_freq: self._configure_embeddings() - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - self._prev_summary_recording = summary_state.is_recording - self._prev_summary_writer = summary_state.writer - self._prev_summary_step = summary_state.step + @property + def _train_writer(self): + if 'train' not in self._writers: + self._writers['train'] = summary_ops_v2.create_file_writer_v2( + self._train_dir) + return self._writers['train'] + + @property + def _val_writer(self): + if 'val' not in self._writers: + self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir) + return self._writers['val'] + + def _get_log_write_dir(self): + """For multi-worker, only chief should write, others write to '/tmp'.""" + return distributed_file_utils.write_dirpath(self.log_dir, + self.model.distribute_strategy) + + def _delete_tmp_write_dir(self): + """Deletes tmp write directories for multi-worker.""" + distributed_file_utils.remove_temp_dirpath(self.log_dir, + self.model.distribute_strategy) + + def _write_keras_model_graph(self): + """Writes Keras graph networks to TensorBoard.""" + with self._train_writer.as_default(): + with summary_ops_v2.always_record_summaries(): + if not self.model.run_eagerly: + summary_ops_v2.graph(K.get_graph(), step=0) + + summary_writable = ( + self.model._is_graph_network or # pylint: disable=protected-access + self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access + if summary_writable: + summary_ops_v2.keras_model('keras', self.model, step=0) def _configure_embeddings(self): """Configure the Projector for embeddings.""" @@ -1839,74 +1843,44 @@ class TensorBoard(Callback): writer = DummyWriter(self._log_write_dir) projector.visualize_embeddings(writer, config) - def _close_writers(self): - """Close all remaining open file writers owned by this callback. - - If there are no such file writers, this is a no-op. - """ - with context.eager_mode(): - for writer in six.itervalues(self._writers): - writer.close() - self._writers.clear() - - def _get_writer(self, writer_name): - """Get a summary writer for the given subdirectory under the logdir. - - A writer will be created if it does not yet exist. - - Arguments: - writer_name: The name of the directory for which to create or - retrieve a writer. Should be either `self._train_run_name` or - `self._validation_run_name`. - - Returns: - A `SummaryWriter` object. 
- """ - if writer_name not in self._writers: - path = os.path.join(self._log_write_dir, writer_name) - writer = summary_ops_v2.create_file_writer_v2(path) - self._writers[writer_name] = writer - return self._writers[writer_name] - - def _set_default_writer(self, writer_name): + def _push_writer(self, writer, step): """Sets the default writer for custom batch-level summaries.""" if self.update_freq == 'epoch': - # Writer is only used for custom summaries, which are written - # batch-by-batch. return - step = self._total_batches_seen[writer_name] + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + self._prev_summary_state.append({ + 'is_recording': summary_state.is_recording, + 'writer': summary_state.writer, + 'step': summary_state.step + }) - def _should_record(): - return math_ops.equal(step % self.update_freq, 0) + if self.update_freq == 'epoch': + should_record = False + writer = None + else: + should_record = lambda: math_ops.equal(step % self.update_freq, 0) + + summary_state.is_recording = should_record + summary_state.writer = writer + # TODO(b/151339474): Fix deadlock when not using .value() here. + summary_ops_v2.set_step(step.value()) + + def _pop_writer(self): + """Pops the current writer.""" + if self.update_freq == 'epoch': + return + + prev_state = self._prev_summary_state.pop() summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - summary_state.is_recording = _should_record - summary_state.writer = self._get_writer(writer_name) - summary_ops_v2.set_step(step) + summary_state.is_recording = prev_state['is_recording'] + summary_state.writer = prev_state['writer'] + summary_ops_v2.set_step(prev_state['step']) - def _init_batch_steps(self): - """Create the total batch counters.""" - if ops.executing_eagerly_outside_functions(): - # Variables are needed for the `step` value of custom tf.summaries - # to be updated inside a tf.function. - self._total_batches_seen = { - self._train_run_name: variables.Variable(0, dtype='int64'), - self._validation_run_name: variables.Variable(0, dtype='int64') - } - else: - # Custom tf.summaries are not supported in legacy graph mode. - self._total_batches_seen = { - self._train_run_name: 0, - self._validation_run_name: 0 - } - - def _increment_step(self, writer_name): - step = self._total_batches_seen[writer_name] - if isinstance(step, variables.Variable): - step.assign_add(1) - else: - self._total_batches_seen[writer_name] += 1 + def _close_writers(self): + for writer in self._writers.values(): + writer.close() def _init_profile_batch(self, profile_batch): """Validate profile_batch value and set the range of batches to profile. @@ -1926,75 +1900,79 @@ class TensorBoard(Callback): """ profile_batch_error_message = ( - 'profile_batch must be a non-negative integer or a comma separated ' - 'string of pair of positive integers. A pair of positive integers ' - 'signify a range of batches to profile.') - try: - profile_range = [int(i) for i in str(profile_batch).split(',')] - except ValueError: - raise ValueError(profile_batch_error_message) - if len(profile_range) == 1: # single batch - start_batch, stop_batch = profile_range[0], profile_range[0] - if start_batch < 0: - raise ValueError(profile_batch_error_message) - elif len(profile_range) == 2: # (start_batch, stop_batch) - start_batch, stop_batch = profile_range - # [0, 0], [-1, 100], [6, 5] are illegal. 
- if start_batch <= 0 or start_batch > stop_batch: - raise ValueError(profile_batch_error_message) + 'profile_batch must be a non-negative integer or 2-tuple of positive ' + 'integers. A pair of positive integers signifies a range of batches ' + 'to profile. Found: {}'.format(profile_batch)) + + # Support legacy way of specifying "start,stop" or "start" as str. + if isinstance(profile_batch, six.string_types): + profile_batch = str(profile_batch).split(',') + profile_batch = nest.map_structure(int, profile_batch) + + if isinstance(profile_batch, int): + self._start_batch = profile_batch + self._stop_batch = profile_batch + elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2: + self._start_batch, self._stop_batch = profile_batch else: raise ValueError(profile_batch_error_message) - return start_batch, stop_batch + + if self._start_batch < 0 or self._stop_batch < self._start_batch: + raise ValueError(profile_batch_error_message) + + if self._start_batch > 0: + profiler.warmup() # Improve the profiling accuracy. + # True when a trace is running. + self._is_tracing = False + + # Setting `profile_batch=0` disables profiling. + self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0) def on_train_begin(self, logs=None): - self._init_batch_steps() - if self._start_batch == 1: - self._enable_trace() + self._push_writer(self._train_writer, self._train_step) + + def on_train_end(self, logs=None): + self._pop_writer() + + if self._is_tracing: + self._stop_trace() + + self._close_writers() + self._delete_tmp_write_dir() def on_test_begin(self, logs=None): - self._set_default_writer(self._validation_run_name) + self._push_writer(self._val_writer, self._val_step) + + def on_test_end(self, logs=None): + self._pop_writer() + + def on_train_batch_begin(self, batch, logs=None): + if not self._should_trace: + return + + if self._epoch == 0 and batch == self._start_batch: + self._start_trace() def on_train_batch_end(self, batch, logs=None): - """Writes scalar summaries for metrics on every training batch. - - Performs profiling if current batch is in profiler_batches. + """Performs profiling if current batch is in profiler_batches. Arguments: batch: Integer, index of batch within the current epoch. logs: Dict. Metric results for this batch. """ - # TODO(b/150629188): Make TensorBoard callback not use batch hooks - # by default. 
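The `profile_batch` forms accepted after this rewrite, as a hedged sketch (the log directory is a placeholder; values outside these forms raise ValueError):

```python
import tensorflow as tf

log_dir = '/tmp/logs'
tf.keras.callbacks.TensorBoard(log_dir, profile_batch=0)         # profiling disabled
tf.keras.callbacks.TensorBoard(log_dir, profile_batch=5)         # profile batch 5 of the first epoch
tf.keras.callbacks.TensorBoard(log_dir, profile_batch=(10, 20))  # profile batches 10 through 20
tf.keras.callbacks.TensorBoard(log_dir, profile_batch='10,20')   # legacy comma-separated string still parsed
# A negative start, or a stop smaller than the start such as (5, 3), raises ValueError.
```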
- if self.update_freq == 'epoch' and self._start_batch is None: + if not self._should_trace: return - # Don't output batch_size and batch number as TensorBoard summaries - logs = logs or {} - train_batches = self._total_batches_seen[self._train_run_name] - if self.update_freq != 'epoch' and batch % self.update_freq == 0: - self._log_metrics(logs, prefix='batch_', step=train_batches) - - self._increment_step(self._train_run_name) - if self._is_tracing: - control_flow_ops.cond( - math_ops.greater_equal(train_batches, self._stop_batch), - lambda: self._log_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda - else: - control_flow_ops.cond( - math_ops.equal(train_batches, self._start_batch - 1), - lambda: self._enable_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda - - def on_test_batch_end(self, batch, logs=None): - if self.update_freq == 'epoch': - return - self._increment_step(self._validation_run_name) + if self._is_tracing and batch >= self._stop_batch: + self._stop_trace() def on_epoch_begin(self, epoch, logs=None): - self._set_default_writer(self._train_run_name) + # Keeps track of epoch for profiling. + self._epoch = epoch def on_epoch_end(self, epoch, logs=None): """Runs metrics and histogram summaries at epoch end.""" - self._log_metrics(logs, prefix='epoch_', step=epoch) + self._log_epoch_metrics(epoch, logs) if self.histogram_freq and epoch % self.histogram_freq == 0: self._log_weights(epoch) @@ -2002,124 +1980,57 @@ class TensorBoard(Callback): if self.embeddings_freq and epoch % self.embeddings_freq == 0: self._log_embeddings(epoch) - def on_train_end(self, logs=None): - if self._is_tracing: - self._log_trace() - self._close_writers() - - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - summary_state.is_recording = self._prev_summary_recording - summary_state.writer = self._prev_summary_writer - summary_state.step = self._prev_summary_step - - # In case this callback is used via native Keras, _get_distribution_strategy does not exist. - if hasattr(self.model, '_get_distribution_strategy'): - # Safely remove the unneeded temp files. - distributed_file_utils.remove_temp_dirpath( - self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access - - def _enable_trace(self): - """Starts to collect trace graph to TensorBoard. - - Collects both trace and graph in eager mode, and trace only in graph mode. - """ - if context.executing_eagerly(): - # Graph must be traced in eager mode. - summary_ops_v2.trace_on(graph=True, profiler=False) - profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) + def _start_trace(self): + summary_ops_v2.trace_on(graph=True, profiler=False) + profiler.start(logdir=self._train_dir) self._is_tracing = True - def _enable_trace_return_true(self): - """Starts to collect trace graph to TensorBoard and returns True. - - Returns: - True. - """ - self._enable_trace() - return True - - def _log_trace(self): - """Logs the trace graph to TensorBoard. - - Logs both trace and graph in eager mode, and trace only in graph mode. - """ - profiler.stop() - if context.executing_eagerly(): - # Graph must be traced in eager mode. 
- with self._get_writer(self._train_run_name).as_default(), \ - summary_ops_v2.always_record_summaries(): + def _stop_trace(self, batch=None): + """Logs the trace graph to TensorBoard.""" + if batch is None: + batch = self._stop_batch + with self._train_writer.as_default(): + with summary_ops_v2.always_record_summaries(): # TODO(b/126388999): Remove step info in the summary name. - step = K.get_value(self._total_batches_seen[self._train_run_name]) - summary_ops_v2.trace_export(name='batch_%d' % step, step=step) + summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch) + profiler.stop() self._is_tracing = False - def _log_trace_return_true(self): - """Logs the trace graph to TensorBoard and returns True. - - Returns: - True. - """ - self._log_trace() - return True - - def _log_metrics(self, logs, prefix, step): - """Writes metrics out as custom scalar summaries. + def _log_epoch_metrics(self, epoch, logs): + """Writes epoch metrics out as scalar summaries. Arguments: - logs: Dict. Keys are scalar summary names, values are NumPy scalars. - prefix: String. The prefix to apply to the scalar summary names. - step: Int. The global step to use for TensorBoard. + epoch: Int. The global step to use for TensorBoard. + logs: Dict. Keys are scalar summary names, values are scalars. """ - if logs is None: - logs = {} + if not logs: + return - # Group metrics by the name of their associated file writer. Values - # are lists of metrics, as (name, scalar_value) pairs. - logs_by_writer = { - self._train_run_name: [], - self._validation_run_name: [], - } - validation_prefix = 'val_' - for (name, value) in logs.items(): - if name in ('batch', 'size', 'num_steps'): - # Scrub non-metric items. - continue - if name.startswith(validation_prefix): - name = name[len(validation_prefix):] - writer_name = self._validation_run_name - else: - writer_name = self._train_run_name - name = prefix + name # assign batch or epoch prefix - logs_by_writer[writer_name].append((name, value)) + train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')} + val_logs = {k: v for k, v in logs.items() if k.startswith('val_')} - with context.eager_mode(): - with summary_ops_v2.always_record_summaries(): - for writer_name in logs_by_writer: - these_logs = logs_by_writer[writer_name] - if not these_logs: - # Don't create a "validation" events file if we don't - # actually have any validation data. - continue - writer = self._get_writer(writer_name) - with writer.as_default(): - for (name, value) in these_logs: - summary_ops_v2.scalar(name, value, step=step) + with summary_ops_v2.always_record_summaries(): + if train_logs: + with self._train_writer.as_default(): + for name, value in train_logs.items(): + summary_ops_v2.scalar('epoch_' + name, value, step=epoch) + if val_logs: + with self._val_writer.as_default(): + for name, value in val_logs.items(): + name = name[4:] # Remove 'val_' prefix. 
+ summary_ops_v2.scalar('epoch_' + name, value, step=epoch) def _log_weights(self, epoch): """Logs the weights of the Model to TensorBoard.""" - writer = self._get_writer(self._train_run_name) - with context.eager_mode(), \ - writer.as_default(), \ - summary_ops_v2.always_record_summaries(): - for layer in self.model.layers: - for weight in layer.weights: - weight_name = weight.name.replace(':', '_') - with ops.init_scope(): - weight = K.get_value(weight) - summary_ops_v2.histogram(weight_name, weight, step=epoch) - if self.write_images: - self._log_weight_as_image(weight, weight_name, epoch) - writer.flush() + with self._train_writer.as_default(): + with summary_ops_v2.always_record_summaries(): + for layer in self.model.layers: + for weight in layer.weights: + weight_name = weight.name.replace(':', '_') + summary_ops_v2.histogram(weight_name, weight, step=epoch) + if self.write_images: + self._log_weight_as_image(weight, weight_name, epoch) + self._train_writer.flush() def _log_weight_as_image(self, weight, weight_name, epoch): """Logs a weight as a TensorBoard image.""" @@ -2150,6 +2061,9 @@ class TensorBoard(Callback): 'keras_embedding.ckpt-{}'.format(epoch)) self.model.save_weights(embeddings_ckpt) + def _implements_train_batch_hooks(self): + return not (self._start_batch == 0 and self._stop_batch == 0) + @keras_export('keras.callbacks.ReduceLROnPlateau') class ReduceLROnPlateau(Callback): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index eb62d0b29ee..54f71402177 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -2079,17 +2079,19 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): model.fit( np.zeros((64, 1)), np.zeros((64, 1)), + batch_size=32, callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)], ) # Verifies trace exists in the first train_dir. - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertIsNotNone(self._get_trace_file(logdir=self.logdir)) model.fit( np.zeros((64, 1)), np.zeros((64, 1)), + batch_size=32, callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)], ) # Verifies trace exists in the second train_dir. - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertIsNotNone(self._get_trace_file(logdir=self.logdir)) def test_TensorBoard_autoTrace_profileBatchRange(self): model = self._get_seq_model() diff --git a/tensorflow/python/keras/callbacks_v1.py b/tensorflow/python/keras/callbacks_v1.py index 524e039f597..09af890b76c 100644 --- a/tensorflow/python/keras/callbacks_v1.py +++ b/tensorflow/python/keras/callbacks_v1.py @@ -39,7 +39,7 @@ from tensorflow.python.util.tf_export import keras_export @keras_export(v1=['keras.callbacks.TensorBoard']) -class TensorBoard(callbacks.Callback): +class TensorBoard(callbacks.TensorBoard): # pylint: disable=line-too-long """Enable visualizations for TensorBoard. @@ -127,7 +127,8 @@ class TensorBoard(callbacks.Callback): embeddings_data=None, update_freq='epoch', profile_batch=2): - super(TensorBoard, self).__init__() + # Don't call super's init since it is an eager-only version. 
+ callbacks.Callback.__init__(self) self.log_dir = log_dir self.histogram_freq = histogram_freq if self.histogram_freq and context.executing_eagerly(): @@ -342,6 +343,21 @@ class TensorBoard(callbacks.Callback): self.writer.add_summary(summary, step) self.writer.flush() + def on_train_batch_begin(self, batch, logs=None): + if (not self._is_profiling and + self._total_batches_seen == self._profile_batch - 1): + profiler.start(self.log_dir) + self._is_profiling = True + + def on_train_batch_end(self, batch, logs=None): + return self.on_batch_end(batch, logs) + + def on_test_begin(self, logs=None): + pass + + def on_test_end(self, logs=None): + pass + def on_batch_end(self, batch, logs=None): """Writes scalar summaries for metrics on every training batch. @@ -358,18 +374,13 @@ class TensorBoard(callbacks.Callback): self._write_custom_summaries(self._total_batches_seen, batch_logs) self._samples_seen_at_last_write = self._samples_seen self._total_batches_seen += 1 + if self._is_profiling: profiler.stop() self._is_profiling = False - elif (not self._is_profiling and - self._total_batches_seen == self._profile_batch - 1): - profiler.start(self.log_dir) - self._is_profiling = True def on_train_begin(self, logs=None): - if self._profile_batch == 1: - profiler.start(self.log_dir) - self._is_profiling = True + pass def on_epoch_begin(self, epoch, logs=None): """Add histogram op to Model eval_function callbacks, reset batch count.""" diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 7dcf10a506c..21361f680da 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import copy +import itertools from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import distribute_coordinator_context as dc_context @@ -28,6 +29,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import monitoring +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import callbacks as callbacks_module from tensorflow.python.keras import optimizers @@ -43,6 +45,8 @@ from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops import array_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import summary_ops_v2 +from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.profiler import trace @@ -161,6 +165,9 @@ class Model(network.Network, version_utils.ModelVersionSelector): Checkout [guide](https://www.tensorflow.org/guide/keras/overview) for additional details. 
""" + _TF_MODULE_IGNORED_PROPERTIES = frozenset( + itertools.chain(('_train_counter', '_test_counter', '_predict_counter'), + network.Network._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access def __init__(self, *args, **kwargs): super(Model, self).__init__(*args, **kwargs) @@ -186,6 +193,18 @@ class Model(network.Network, version_utils.ModelVersionSelector): self.compiled_loss = None self.compiled_metrics = None + self._init_batch_counters() + + @trackable.no_automatic_dependency_tracking + def _init_batch_counters(self): + # Untracked Variables, used to keep track of mini-batches seen in `fit`, + # `evaluate`, and `predict`. + agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA + self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg) + self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg) + self._predict_counter = variables.Variable( + 0, dtype='int64', aggregation=agg) + def get_weights(self): """Retrieves the weights of the model. @@ -499,11 +518,18 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.train_function def train_function(iterator): + """Runs one call to `self.train_function`.""" + + def run_step(data): + outputs = self.train_step(data) + self._train_counter.assign_add(1) + return outputs + data = next(iterator) - outputs = self.distribute_strategy.run( - self.train_step, args=(data,)) + outputs = self.distribute_strategy.run(run_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='first') + write_scalar_summaries(outputs, step=self._train_counter) return outputs if not self.run_eagerly: @@ -762,6 +788,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): self.stop_training = False train_function = self.make_train_function() + self._train_counter.assign(0) callbacks.on_train_begin() # Handle fault-tolerance for multi-worker. # TODO(omalleyt): Fix the ordering issues that mean this has to @@ -872,9 +899,15 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.test_function def test_function(iterator): + """Runs one call to `self.test_function`.""" + + def run_step(data): + outputs = self.test_step(data) + self._test_counter.assign_add(1) + return outputs + data = next(iterator) - outputs = self.distribute_strategy.run( - self.test_step, args=(data,)) + outputs = self.distribute_strategy.run(run_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='first') return outputs @@ -1003,6 +1036,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): steps=data_handler.inferred_steps) test_function = self.make_test_function() + self._test_counter.assign(0) callbacks.on_test_begin() for _, iterator in data_handler.enumerate_epochs(): # Single epoch. 
self.reset_metrics() @@ -1075,9 +1109,15 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.predict_function def predict_function(iterator): + """Runs one call to `self.predict_function`.""" + + def run_step(data): + outputs = self.predict_step(data) + self._predict_counter.assign_add(1) + return outputs + data = next(iterator) - outputs = self.distribute_strategy.run( - self.predict_step, args=(data,)) + outputs = self.distribute_strategy.run(run_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='concat') return outputs @@ -1192,6 +1232,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): steps=data_handler.inferred_steps) predict_function = self.make_predict_function() + self._predict_counter.assign(0) callbacks.on_predict_begin() for _, iterator in data_handler.enumerate_epochs(): # Single epoch. with data_handler.catch_stop_iteration(): @@ -1734,3 +1775,13 @@ def _minimize(tape, optimizer, loss, trainable_variables): all_reduce_sum_gradients=False) else: optimizer.apply_gradients(zip(gradients, trainable_variables)) + + +def _is_scalar(x): + return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0 + + +def write_scalar_summaries(logs, step): + for name, value in logs.items(): + if _is_scalar(value): + summary_ops_v2.scalar('batch_' + name, value, step=step) diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 1c0fea91337..710f9bf3497 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -162,6 +162,9 @@ class Model(training_lib.Model): self._v1_compile_was_called = False + def _init_batch_counters(self): + pass # Batch counters should not be created in legacy graph mode. 
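A small sketch of how the new counters behave (assuming eager TF 2.x; the model and data are placeholders):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile('sgd', 'mse')
model.fit(np.zeros((64, 4)), np.zeros((64, 1)), batch_size=32, epochs=2, verbose=0)

# _train_counter is reset to 0 at the start of fit() and incremented once per
# executed mini-batch, so after 2 epochs x 2 steps it reads 4. It is also the
# `step` used for the 'batch_*' scalars written by write_scalar_summaries.
print(int(model._train_counter))  # 4
```

Because the counters are created under `no_automatic_dependency_tracking` and listed in `_TF_MODULE_IGNORED_PROPERTIES`, they stay out of `model.variables`, which is what the new `test_batch_counters_not_in_variables` test below checks.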
+ @trackable.no_automatic_dependency_tracking def _set_strategy(self, strategy): self._compile_time_distribution_strategy = strategy diff --git a/tensorflow/python/keras/tests/model_subclassing_test.py b/tensorflow/python/keras/tests/model_subclassing_test.py index 761f720cea5..5af1148f4f0 100644 --- a/tensorflow/python/keras/tests/model_subclassing_test.py +++ b/tensorflow/python/keras/tests/model_subclassing_test.py @@ -737,6 +737,21 @@ class CustomCallSignatureTests(test.TestCase, parameterized.TestCase): self.assertLen(new_model.variables, 1) self.assertLen(new_model.layers, 1) + def test_batch_counters_not_in_variables(self): + + class MyModel(keras.Model): + + def __init__(self): + super(MyModel, self).__init__() + self.layer = keras.layers.Dense(4) + + def call(self, obs): + return self.layer(obs) + + model = MyModel() + model(np.ones((10, 10))) + self.assertLen(model.variables, 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/utils/version_utils.py b/tensorflow/python/keras/utils/version_utils.py index cf485e1080d..377f370430c 100644 --- a/tensorflow/python/keras/utils/version_utils.py +++ b/tensorflow/python/keras/utils/version_utils.py @@ -36,6 +36,13 @@ base_layer = lazy_loader.LazyLoader( base_layer_v1 = lazy_loader.LazyLoader( "base_layer_v1", globals(), "tensorflow.python.keras.engine.base_layer_v1") +callbacks = lazy_loader.LazyLoader( + "callbacks", globals(), + "tensorflow.python.keras.callbacks") +callbacks_v1 = lazy_loader.LazyLoader( + "callbacks_v1", globals(), + "tensorflow.python.keras.callbacks_v1") + # pylint: enable=g-inconsistent-quotes @@ -58,6 +65,21 @@ class LayerVersionSelector(object): return super(LayerVersionSelector, cls).__new__(cls) +class TensorBoardVersionSelector(object): + """Chooses between Keras v1 and v2 TensorBoard callback class.""" + + def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument + eager_enabled = ops.executing_eagerly_outside_functions() + start_cls = cls + cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard, + eager_enabled) + if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard: + # Since the v2 class is not a subclass of the v1 class, __init__ has to + # be called manually. 
+ return cls(*args, **kwargs) + return super(TensorBoardVersionSelector, cls).__new__(cls) + + def swap_class(cls, v2_cls, v1_cls, eager_enabled): """Swaps in v2_cls or v1_cls depending on graph mode.""" if cls == object: diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt index 4504633d4a1..2e0c6c97826 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.callbacks.TensorBoard" tf_class { is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt index 24385e2722a..51d6901e936 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.keras.callbacks.TensorBoard" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" From 67d89da9b5070b88c9d43009146945672f32c451 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Tue, 17 Mar 2020 10:29:40 -0700 Subject: [PATCH 078/492] For tensorflow::io, on windows use '\' as path separator. PiperOrigin-RevId: 301402270 Change-Id: I523e2fa22abe8771df52f624ca0df6481cf10a28 --- tensorflow/core/platform/path.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/core/platform/path.cc b/tensorflow/core/platform/path.cc index 1e88328aace..00e3f0eca28 100644 --- a/tensorflow/core/platform/path.cc +++ b/tensorflow/core/platform/path.cc @@ -38,7 +38,11 @@ namespace io { namespace internal { namespace { +#if defined(PLATFORM_WINDOWS) +const char kPathSep[] = "\\"; +#else const char kPathSep[] = "/"; +#endif // PLATFORM_WINDOWS bool FixBazelEnvPath(const char* path, string* out) { if (path == nullptr) return false; From fd1eef69653fef9b7701d1be93c6427fc72cf226 Mon Sep 17 00:00:00 2001 From: Taehee Jeong Date: Tue, 17 Mar 2020 10:43:49 -0700 Subject: [PATCH 079/492] internal code change PiperOrigin-RevId: 301405619 Change-Id: Ic11eb5922c8ff27675afda3d6b69d10e4dcbae79 --- .../swift/Sources/CoreMlDelegate.swift | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tensorflow/lite/experimental/swift/Sources/CoreMlDelegate.swift diff --git a/tensorflow/lite/experimental/swift/Sources/CoreMlDelegate.swift b/tensorflow/lite/experimental/swift/Sources/CoreMlDelegate.swift new file mode 100644 index 00000000000..21e0276578c --- /dev/null +++ b/tensorflow/lite/experimental/swift/Sources/CoreMlDelegate.swift @@ -0,0 +1,50 @@ +// Copyright 2020 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import TensorFlowLiteC + +/// A delegate that uses the `Core ML` framework for performing TensorFlow Lite graph operations. +/// +/// - Important: This is an experimental interface that is subject to change. +public final class CoreMLDelegate: Delegate { + /// The configuration options for the `CoreMLDelegate`. + public let options: Options + + // Conformance to the `Delegate` protocol. + public private(set) var cDelegate: CDelegate + + /// Creates a new instance configured with the given `options`. + /// + /// - Parameters: + /// - options: Configurations for the delegate. The default is a new instance of + /// `CoreMLDelegate.Options` with the default configuration values. + public init(options: Options = Options()) { + self.options = options + var delegateOptions = TfLiteCoreMlDelegateOptions() + cDelegate = TfLiteCoreMlDelegateCreate(&delegateOptions) + } + + deinit { + TfLiteCoreMlDelegateDelete(cDelegate) + } +} + +extension CoreMLDelegate { + /// Options for configuring the `CoreMLDelegate`. + // TODO(b/143931022): Add preferred device support. + public struct Options: Equatable, Hashable { + /// Creates a new instance with the default values. + public init() {} + } +} From 73278044ba9c5f5601c167e5233d3e107e00a9be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 10:49:44 -0700 Subject: [PATCH 080/492] Update PIP_BIN_PATH to work with all py versions. PiperOrigin-RevId: 301407013 Change-Id: I86e22da7e18e5c5e0f58d89b1cffaf65d422cef1 --- tensorflow/tools/ci_build/builds/pip_new.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/builds/pip_new.sh b/tensorflow/tools/ci_build/builds/pip_new.sh index 4b0a4914ede..330fa44b0de 100755 --- a/tensorflow/tools/ci_build/builds/pip_new.sh +++ b/tensorflow/tools/ci_build/builds/pip_new.sh @@ -431,7 +431,7 @@ install_tensorflow_pip() { fi # Set path to pip. - PIP_BIN_PATH="$(which pip${PY_MAJOR_MINOR_VER})" + PIP_BIN_PATH="${PYTHON_BIN_PATH} -m pip" # Print python and pip bin paths echo "PYTHON_BIN_PATH to be used to install the .whl: ${PYTHON_BIN_PATH}" From bdb77d1e2e755fb7a3aa5614a16bbc7fe152b64a Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Tue, 17 Mar 2020 10:50:24 -0700 Subject: [PATCH 081/492] Work around the following bug in python 3.8 OverflowError: Python int too large to convert to C long On windows, sizeof(long) is 4 bytes. 
Therefore, the large integers seem to be a problem when backed by some numpy types: https://stackoverflow.com/questions/38314118/overflowerror-python-int-too-large-to-convert-to-c-long-on-windows-but-not-ma PiperOrigin-RevId: 301407257 Change-Id: Ic0f7379eee360ff053283742f6495c1974b2c5e4 --- tensorflow/compiler/tests/categorical_op_test.py | 2 +- tensorflow/compiler/tests/stateless_random_ops_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index ef6df1f0879..afda99a7e06 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -187,7 +187,7 @@ class CategoricalTest(xla_test.XLATestCase): 0, seed=seed_t, output_dtype=dtypes.int32) - y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) self.assertEqual(y.shape, (42, 0)) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 6576e274300..56b49689607 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -101,7 +101,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_normal( shape=[10000], seed=seed_t, dtype=dtype) - y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) self.assertTrue(np.all(np.isfinite(y))) def testDistributionOfStatelessRandomNormal(self): From 81609059b4bc9b896d706e5b202fb884d0678dad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 11:23:46 -0700 Subject: [PATCH 082/492] Change CreateScalarComparisonComputation to have an option to have a vector of XlaOpGenerator as input. PiperOrigin-RevId: 301414874 Change-Id: Ib6eb8973ade305c2fd1556b8ba8ce5e53b33f418 --- .../compiler/xla/client/lib/comparators.cc | 80 ++++++++++++++----- .../compiler/xla/client/lib/comparators.h | 11 +++ 2 files changed, 69 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc index 11a79a262ef..74e89b767cf 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.cc +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -88,9 +88,39 @@ XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, return Select(is_negative, flipped_value, signed_value); } +void ConvertFloatingPoint(const PrimitiveType& operand_type, XlaOp* lhs_param, + XlaOp* rhs_param) { + if (primitive_util::IsFloatingPointType(operand_type)) { + PrimitiveType compare_type = operand_type; + // Special-case handling for BF16. We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. 
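+    // (Both parameters are then bitcast to integers whose ordering matches
+    // the floating-point ordering, so the comparator body only ever sees
+    // integral values.)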
+ if (compare_type == BF16) { + compare_type = F32; + *lhs_param = ConvertElementType(*lhs_param, F32); + *rhs_param = ConvertElementType(*rhs_param, F32); + } + int64 bit_width = primitive_util::BitWidth(compare_type); + *lhs_param = BitcastConvertFloatingPointToIntegral(*lhs_param, bit_width); + *rhs_param = BitcastConvertFloatingPointToIntegral(*rhs_param, bit_width); + } +} + XlaComputation CreateScalarComparisonComputation( const string& name, const std::vector& operand_types, XlaBuilder* builder, XlaOpGenerator generator) { + CHECK_NE(operand_types.size(), 0); + std::vector> generators(operand_types.size()); + generators[0] = generator; + return CreateScalarComparisonComputation(name, operand_types, generators, + builder); +} +} // namespace + +XlaComputation CreateScalarComparisonComputation( + const string& name, const std::vector& operand_types, + const std::vector>& generators, + XlaBuilder* builder) { // Create a default computation where we compare only the first two // parameters of type 'operand_types[0]'. auto b = builder->CreateSubBuilder(name); @@ -99,9 +129,11 @@ XlaComputation CreateScalarComparisonComputation( return b->BuildAndNoteError(); } + CHECK_EQ(operand_types.size(), generators.size()); int64 parameter_count = 0; - XlaOp first_lhs_param; - XlaOp first_rhs_param; + int64 last_generator_index = 0; + std::vector lhs_params; + std::vector rhs_params; // For each type in 'operand_types' we create two parameters of this type. The // idea is that this computation can be used by n-ary Sort, and potentially @@ -114,32 +146,36 @@ XlaComputation CreateScalarComparisonComputation( absl::StrCat("p.", parameter_count, ".lhs")); auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape, absl::StrCat("p.", parameter_count, ".rhs")); - if (parameter_count == 0) { - first_lhs_param = lhs_param; - first_rhs_param = rhs_param; + ConvertFloatingPoint(operand_type, &lhs_param, &rhs_param); + lhs_params.emplace_back(lhs_param); + rhs_params.emplace_back(rhs_param); + if (generators[parameter_count].has_value()) { + last_generator_index = parameter_count; } - ++parameter_count; + parameter_count++; } - if (primitive_util::IsFloatingPointType(operand_types[0])) { - PrimitiveType compare_type = operand_types[0]; - // Special-case handling for BF16. We currently do not support direct - // comparisons with BF16, so we convert to F32 and then use the F32 - // comparison logic. 
- if (compare_type == BF16) { - compare_type = F32; - first_lhs_param = ConvertElementType(first_lhs_param, F32); - first_rhs_param = ConvertElementType(first_rhs_param, F32); + + CHECK_NE(parameter_count, 0); + + Shape shape = b->GetShape(lhs_params[0]).ValueOrDie(); + shape.set_element_type(PRED); + XlaOp param_equal = Broadcast(One(b.get(), shape.element_type()), + AsInt64Slice(shape.dimensions())); + XlaOp result = param_equal; + + for (int64 i = 0; i < parameter_count; i++) { + if (generators[i].has_value()) { + result = Select(param_equal, + generators[i].value()(lhs_params[i], rhs_params[i], {}), + result); + if (i != last_generator_index) { + param_equal = And(param_equal, Eq(lhs_params[i], rhs_params[i])); + } } - int64 bit_width = primitive_util::BitWidth(compare_type); - first_lhs_param = - BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width); - first_rhs_param = - BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width); } - generator(first_lhs_param, first_rhs_param, {}); + return b->BuildAndNoteError(); } -} // namespace // Creates a scalar less-than computation and returns it. XlaComputation CreateScalarLtComputation( diff --git a/tensorflow/compiler/xla/client/lib/comparators.h b/tensorflow/compiler/xla/client/lib/comparators.h index cbcfc227dd4..25924d4a4f4 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.h +++ b/tensorflow/compiler/xla/client/lib/comparators.h @@ -42,6 +42,17 @@ XlaComputation CreateScalarLtComputation( XlaComputation CreateScalarGtComputation( const std::vector& operand_types, XlaBuilder* builder); +// Creates a scalar comparison computation and returns it. This function takes +// an std::vector> and compare the operands +// where the generator isn't nullopt with the specified comparator +// at that location. +XlaComputation CreateScalarComparisonComputation( + const string& name, const std::vector& operand_types, + const std::vector< + absl::optional)>>& + generators, + XlaBuilder* builder); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ From ce234ef4b79b926c89123933d38bc52468169ea2 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 17 Mar 2020 11:25:45 -0700 Subject: [PATCH 083/492] [NFC] Replace all usages of PatternMatchResult with LogicalResult This also replaces usages of matchSuccess/matchFailure with success/failure respectively. 
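For illustration, a minimal sketch of what the migration looks like on a
hypothetical pattern (FooOp and SomePrecondition are placeholders, not code
from this patch):

    // Before: match results use the dedicated PatternMatchResult type.
    PatternMatchResult matchAndRewrite(FooOp op,
                                       PatternRewriter &rewriter) const override {
      if (!SomePrecondition(op)) return matchFailure();
      rewriter.eraseOp(op.getOperation());
      return matchSuccess();
    }

    // After: the generic LogicalResult with mlir::success()/mlir::failure().
    LogicalResult matchAndRewrite(FooOp op,
                                  PatternRewriter &rewriter) const override {
      if (!SomePrecondition(op)) return failure();
      rewriter.eraseOp(op.getOperation());
      return success();
    }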
PiperOrigin-RevId: 301415284 Change-Id: I9a127687ac0586bb5da97f7171a0b156888d4f81 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.cc | 37 +- .../lite/quantization/quantization_utils.h | 40 +- .../quantization/tensorflow/tf_to_quant.cc | 14 +- .../mlir/lite/quantization/xla/materialize.cc | 14 +- .../mlir/lite/transforms/dilated_conv.h | 37 +- .../mlir/lite/transforms/legalize_tf.cc | 119 +++--- .../transforms/lower_static_tensor_list.cc | 68 +-- .../compiler/mlir/lite/transforms/optimize.cc | 111 +++-- .../transforms/optimize_functional_ops.cc | 16 +- .../mlir/lite/transforms/prepare_tf.cc | 54 +-- .../compiler/mlir/tensorflow/ir/tf_device.cc | 8 +- .../mlir/tensorflow/ir/tf_executor.cc | 40 +- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 16 +- .../transforms/batchmatmul_to_einsum.cc | 12 +- .../transforms/batchmatmul_to_einsum.h | 2 +- .../mlir/tensorflow/transforms/einsum.cc | 20 +- .../mlir/tensorflow/transforms/einsum.h | 4 +- .../mlir/tensorflow/transforms/gpu_fusion.cc | 16 +- .../mlir/tensorflow/transforms/lower_tf.cc | 35 +- .../transforms/unroll_batch_matmul.cc | 22 +- .../transforms/unroll_batch_matmul.h | 4 +- tensorflow/compiler/mlir/xla/ir/hlo_ops.cc | 12 +- .../xla/transforms/hlo_legalize_to_lhlo.cc | 44 +- .../mlir/xla/transforms/legalize_tf.cc | 396 +++++++++--------- .../xla/transforms/legalize_to_standard.cc | 32 +- .../xla/transforms/lhlo_legalize_to_affine.cc | 10 +- .../xla/transforms/lhlo_legalize_to_gpu.cc | 10 +- .../lhlo_legalize_to_parallel_loops.cc | 6 +- .../mlir/xla/transforms/lower_general_dot.cc | 12 +- .../xla/transforms/materialize_broadcasts.cc | 20 +- .../mlir/xla/transforms/unfuse_batch_norm.cc | 10 +- .../xla/transforms/xla_legalize_to_linalg.cc | 54 ++- 32 files changed, 643 insertions(+), 652 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 5f8e9c35b94..471e50e0a52 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -804,10 +804,10 @@ struct RemoveAdjacentReshape : public RewritePattern { RemoveAdjacentReshape(MLIRContext *context) : RewritePattern(ReshapeOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + LogicalResult match(Operation *op) const override { auto thisOp = cast(op); auto prevOp = thisOp.getOperand(0).getDefiningOp(); - return isa_and_nonnull(prevOp) ? matchSuccess() : matchFailure(); + return isa_and_nonnull(prevOp) ? success() : failure(); } void rewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -884,28 +884,27 @@ struct RemoveRedundantUnpackPack : public RewritePattern { explicit RemoveRedundantUnpackPack(MLIRContext *context) : RewritePattern(PackOp::getOperationName(), 2, context) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { TFL::PackOp pack_op = cast(op); Operation *first_input = pack_op.getOperand(0).getDefiningOp(); - if (!first_input) return matchFailure(); + if (!first_input) return failure(); auto input_unpack_op = dyn_cast_or_null(first_input); - if (!input_unpack_op) return matchFailure(); + if (!input_unpack_op) return failure(); // The unpack & pack should have the same axis & num inputs/outputs. 
if (pack_op.axis() != input_unpack_op.axis() || pack_op.values_count() != input_unpack_op.num()) - return matchFailure(); + return failure(); const int total_pack_inputs = pack_op.getNumOperands(); - if (total_pack_inputs != input_unpack_op.getNumResults()) - return matchFailure(); + if (total_pack_inputs != input_unpack_op.getNumResults()) return failure(); for (auto input_output : llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) { Value pack_input = std::get<0>(input_output); Value unpack_output = std::get<1>(input_output); // Make sure the ordering is the same for the pack op & unpack op. - if (pack_input != unpack_output) return matchFailure(); + if (pack_input != unpack_output) return failure(); } // Replace the pack's output to the unpack's input. @@ -913,7 +912,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern { // At this point, we don't manually remove the redundant pack op & unpack op // (we cannot actually), but trust the PatterRewriter to garbage collect // these two ops. - return matchSuccess(); + return success(); } }; @@ -1050,17 +1049,17 @@ struct DropFakeQuant : public RewritePattern { explicit DropFakeQuant(MLIRContext *context) : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + LogicalResult match(Operation *op) const override { // We only match the op with valid "minmax" attribute. - if (!HasValidMinMaxAttribute(op)) return matchFailure(); + if (!HasValidMinMaxAttribute(op)) return failure(); // If all the users of this op have valid "minmax" attributes, it is matched // and can be removed. auto fakeQuantOp = cast(op); for (auto *operand : fakeQuantOp.getResult().getUsers()) - if (!HasValidMinMaxAttribute(operand)) return matchFailure(); + if (!HasValidMinMaxAttribute(operand)) return failure(); - return matchSuccess(); + return success(); } void rewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -1789,8 +1788,8 @@ struct WhileResultOperandsMatchAndImplicitCapture : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(WhileOp while_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(WhileOp while_op, + PatternRewriter &rewriter) const override { // Replace values simply passed through the body with extern values. The // block arguments of body and while match and so the corresponding cond // argument can be easily found. @@ -1843,7 +1842,7 @@ struct WhileResultOperandsMatchAndImplicitCapture } // Done if no values removed from blocks and operands & results match. - if (unchanged) return matchFailure(); + if (unchanged) return failure(); // Replace with new While with matching operands and results. 
Operation *op = while_op.getOperation(); @@ -1866,7 +1865,7 @@ struct WhileResultOperandsMatchAndImplicitCapture rewriter.replaceOpWithNewOp(new_body_block.getTerminator(), new_body_yield); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 9bb1d677df2..e9d29758823 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -82,17 +82,17 @@ struct ConvertStatsToQDQs : public OpRewritePattern { narrow_range(narrow_range), is_signed(is_signed) {} - PatternMatchResult matchAndRewrite(quant::StatisticsOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(quant::StatisticsOp op, + PatternRewriter& rewriter) const override { Type expressed = op.getType().cast().getElementType(); quant::QuantizedType quant_type; SmallVector mins, maxs; if (op.axisStats().hasValue()) { int stats_num = op.axisStats()->getNumElements(); - if (stats_num == 0 || stats_num % 2 != 0) return this->matchFailure(); + if (stats_num == 0 || stats_num % 2 != 0) return failure(); auto stats = op.axisStats()->dyn_cast(); - if (!stats) return this->matchFailure(); + if (!stats) return failure(); for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { mins.push_back(FloatAttr::getValueAsDouble(*it++)); @@ -108,7 +108,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax, narrow_range, expressed, is_signed); } else { - return this->matchFailure(); + return failure(); } rewriter.setInsertionPointAfter(op); @@ -119,7 +119,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { q.getOperation()->replaceUsesOfWith(dq, op.arg()); op.erase(); - return this->matchSuccess(); + return success(); } private: @@ -156,16 +156,16 @@ struct QuantizationPattern : public RewritePattern { error_tolerance(error_tolerance), single_layer_verify(single_layer_verify) {} - PatternMatchResult matchAndRewrite(Operation* op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { if (op->getNumResults() != 1) { - return matchFailure(); + return failure(); } Value quantized_value = op->getResult(0); for (Operation* quantized_op : quantized_value.getUsers()) { // If it is requantize op, we shouldn't rewrite this op. 
if (llvm::isa(quantized_op) || llvm::isa(quantized_op)) { - return matchFailure(); + return failure(); } // If it is terminator or not quantizable or any ops form the mlir quant @@ -174,7 +174,7 @@ struct QuantizationPattern : public RewritePattern { quantized_op->hasTrait() || llvm::isa(quantized_op) || llvm::isa(quantized_op)) { - return matchFailure(); + return failure(); } // Collect all the quantized inputs and "clone" the matched op by these @@ -198,7 +198,7 @@ struct QuantizationPattern : public RewritePattern { } else if (static_cast(this)->AllowHybridOperand()) { inputs.push_back(operand); } else { - return matchFailure(); + return failure(); } } @@ -234,7 +234,7 @@ struct QuantizationPattern : public RewritePattern { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); } else { - return matchFailure(); + return failure(); } } @@ -299,7 +299,7 @@ struct QuantizationPattern : public RewritePattern { } } } - return matchSuccess(); + return success(); } bool enable_verify; @@ -317,11 +317,11 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { explicit ConvertUnsignedToSigned(MLIRContext* context) : OpRewritePattern(context, 1) {} - PatternMatchResult matchAndRewrite(Q op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(Q op, + PatternRewriter& rewriter) const override { Type output_type = op.getResult().getType(); auto qtype = QType::getQuantizedElementType(output_type); - if (!qtype || qtype.isSigned()) return this->matchFailure(); + if (!qtype || qtype.isSigned()) return failure(); int num_bits = qtype.getStorageTypeIntegralWidth(); // This is a positive value, and will be applied on zero points and fixed @@ -352,14 +352,14 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { aqtype.getStorageTypeMin() - offset, aqtype.getStorageTypeMax() - offset, op.getLoc()); } else { - return this->matchFailure(); + return failure(); } - if (!new_qtype) return this->matchFailure(); + if (!new_qtype) return failure(); Type new_output_type = new_qtype.castFromExpressedType( QType::castToExpressedType(output_type)); rewriter.replaceOpWithNewOp(op, new_output_type, op.arg()); - return this->matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 64fddd06da6..d2884edafdf 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -73,12 +73,12 @@ struct InsertQuantOpsAfterTFFakeQuantOp MLIRContext *ctx) : OpRewritePattern(ctx) {} - PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, + PatternRewriter &rewriter) const override { // We don't want to insert quantize/dequantize if the quantize op exists. auto res = tf_op.outputs(); if (!res.hasOneUse() || isa(*res.user_begin())) - return this->matchFailure(); + return failure(); // Extract the min/max constant values from the operands. 
We also consider // a special case that there are tf.Identity ops between the min/max @@ -95,8 +95,8 @@ struct InsertQuantOpsAfterTFFakeQuantOp max = tf_op.max(); rewriter.eraseOp(id2); } - if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure(); - if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure(); + if (!matchPattern(min, m_Constant(&min_value))) return failure(); + if (!matchPattern(max, m_Constant(&max_value))) return failure(); int quant_dim = -1; if (PerAxis) { @@ -114,7 +114,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp TypeAttr qtype = quant::GetQuantizedTypeAttr( rewriter, res_type, min_value, max_value, quant_dim, num_bits, narrow_range, /*is_signed=*/true); - if (!qtype) this->matchFailure(); + if (!qtype) failure(); // Finally, use the quantization parameter to create the quantize and // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp @@ -127,7 +127,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp value.replaceAllUsesWith(dequantize); quantize.getOperation()->replaceUsesOfWith(dequantize, value); - return this->matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc index 0c746d0c943..59704b4c73a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc @@ -58,15 +58,15 @@ class RewriteDequantize : public OpRewritePattern { explicit RewriteDequantize(int64_t size, MLIRContext *context) : OpRewritePattern(context), size_(size) {} - PatternMatchResult matchAndRewrite(quant::DequantizeCastOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(quant::DequantizeCastOp op, + PatternRewriter &rewriter) const override { // quant.dcast // xla_hlo dequantize only takes min/max, so let's recover them from // the quantization parameters. Value dcast = op.arg(); auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType()); if (!type || !type.isa()) { - return matchFailure(); + return failure(); } auto qtype = type.cast(); double scale = qtype.getScale(); @@ -77,7 +77,7 @@ class RewriteDequantize : public OpRewritePattern { // quant.qcast auto qcast = llvm::dyn_cast_or_null(dcast.getDefiningOp()); - if (!qcast) return matchFailure(); + if (!qcast) return failure(); // constant DenseFPElementsAttr attr; @@ -88,7 +88,7 @@ class RewriteDequantize : public OpRewritePattern { attr.getNumElements() <= size_ || attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) { op.getResult().replaceAllUsesWith(qcast.arg()); - return matchSuccess(); + return success(); } // TODO(fengliuai): implement transpose if it has high dimension. @@ -96,7 +96,7 @@ class RewriteDequantize : public OpRewritePattern { auto quantized_result = quant::Quantize(attr, qtype).dyn_cast_or_null(); if (!quantized_result) { - return matchFailure(); + return failure(); } // Pack the uint8 bits to uint32. 
The shape is changed from from @@ -133,7 +133,7 @@ class RewriteDequantize : public OpRewritePattern { // Convert bf16 output back to f32 rewriter.replaceOpWithNewOp(op, op.getResult().getType(), dequantize); - return matchSuccess(); + return success(); } private: diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index 65bed845bae..d8a26154b2b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -74,31 +74,31 @@ class ConvertTFDilatedConvOp : public OpRewritePattern { PatternRewriter& rewriter) const; public: - PatternMatchResult matchAndRewrite(Conv2dOpTy op, - PatternRewriter& rewriter) const override; + LogicalResult matchAndRewrite(Conv2dOpTy op, + PatternRewriter& rewriter) const override; }; template -PatternMatchResult ConvertTFDilatedConvOp::matchAndRewrite( +LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( Conv2dOpTy op, PatternRewriter& rewriter) const { // Make sure Conv2D has 'VALID' padding. if (op.template getAttrOfType("padding").getValue() != "VALID") { - return Pattern::matchFailure(); + return failure(); } // Make sure dilations are all ones if set. const ArrayAttr& dilations = op.template getAttrOfType("dilations"); if (dilations && !TFIntListIsAllOnes(dilations)) { - return Pattern::matchFailure(); + return failure(); } // Check if the ConvOp is preceded by a `Expand` op and succeeded by a // `Squeeze` op. Operation* prev_op = op.getOperation()->getPrevNode(); - if (!prev_op) return Pattern::matchFailure(); + if (!prev_op) return failure(); Operation* next_op = op.getOperation()->getNextNode(); - if (!next_op) return Pattern::matchFailure(); + if (!next_op) return failure(); TF::ExpandDimsOp expand_op; TF::SqueezeOp squeeze_op; @@ -107,7 +107,7 @@ PatternMatchResult ConvertTFDilatedConvOp::matchAndRewrite( if (llvm::isa(prev_op)) { if (!llvm::isa(next_op)) { // Expand/Squeeze op must come in pair. - return Pattern::matchFailure(); + return failure(); } expand_op = llvm::cast(prev_op); squeeze_op = llvm::cast(next_op); @@ -119,24 +119,24 @@ PatternMatchResult ConvertTFDilatedConvOp::matchAndRewrite( (*const_op.value().cast().getIntValues().begin()) .getSExtValue(); } else { - return Pattern::matchFailure(); + return failure(); } // Make sure that the `squeeze_dims` is equal to `expand_axis`. auto squeeze_dims = squeeze_op.squeeze_dims(); if (squeeze_dims.size() != 1 || squeeze_dims[0].cast().getInt() != expand_axis) { - return Pattern::matchFailure(); + return failure(); } // Update previous/next op pointer. prev_op = prev_op->getPrevNode(); - if (!prev_op) return Pattern::matchFailure(); + if (!prev_op) return failure(); next_op = next_op->getNextNode(); - if (!next_op) return Pattern::matchFailure(); + if (!next_op) return failure(); } // SpaceToBatchND op. - if (!llvm::isa(prev_op)) return Pattern::matchFailure(); + if (!llvm::isa(prev_op)) return failure(); // TODO(b/149936532): Check `padding` input, currently ignored. TF::SpaceToBatchNDOp stb_op = llvm::cast(prev_op); @@ -148,7 +148,7 @@ PatternMatchResult ConvertTFDilatedConvOp::matchAndRewrite( if (llvm::isa(next_op)) { pad_op = llvm::cast(next_op); next_op = next_op->getNextNode(); - if (!next_op) return Pattern::matchFailure(); + if (!next_op) return failure(); } // BatchToSpaceND + BiasAdd. @@ -160,8 +160,7 @@ PatternMatchResult ConvertTFDilatedConvOp::matchAndRewrite( // Must be BiasAdd + BatchToSpaceND. 
biasadd_op = llvm::cast(next_op); next_op = next_op->getNextNode(); - if (!next_op || !llvm::isa(next_op)) - return Pattern::matchFailure(); + if (!next_op || !llvm::isa(next_op)) return failure(); bts_op = llvm::cast(next_op); } else if (llvm::isa(next_op)) { // BatchToSpaceND + (optional) BiasAdd. @@ -172,12 +171,12 @@ PatternMatchResult ConvertTFDilatedConvOp::matchAndRewrite( final_op_is_bts = false; } } else { - return Pattern::matchFailure(); + return failure(); } llvm::Optional dilations_attr = ExtractDilationsAttrFromBlockShape( stb_op.block_shape(), bts_op.block_shape(), rewriter); - if (!dilations_attr.hasValue()) return Pattern::matchFailure(); + if (!dilations_attr.hasValue()) return failure(); op.setAttr("dilations", dilations_attr.getValue()); // Padding is set to 'SAME' when `stb_op` has non-zero paddings. @@ -228,7 +227,7 @@ PatternMatchResult ConvertTFDilatedConvOp::matchAndRewrite( } stb_op.getResult().dropAllUses(); - return Pattern::matchSuccess(); + return success(); } template diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index d2001db8b40..80689f7b7c4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -98,12 +98,12 @@ bool HasSameStaticShapes(Operation* op) { #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc" -#define DECL_CONVERT_OP(tf_op) \ - struct ConvertTF##tf_op##Op : public RewritePattern { \ - explicit ConvertTF##tf_op##Op(MLIRContext* context) \ - : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \ - PatternMatchResult matchAndRewrite( \ - Operation* op, PatternRewriter& rewriter) const override; \ +#define DECL_CONVERT_OP(tf_op) \ + struct ConvertTF##tf_op##Op : public RewritePattern { \ + explicit ConvertTF##tf_op##Op(MLIRContext* context) \ + : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \ + LogicalResult matchAndRewrite(Operation* op, \ + PatternRewriter& rewriter) const override; \ } // TODO(antiagainst): Define this pattern in a table-driven manner once variadic @@ -127,14 +127,14 @@ DECL_CONVERT_OP(BroadcastTo); #undef DECL_CONVERT_OP -PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite( +LogicalResult ConvertTFRandomUniformOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto random_uniform_op = cast(op); if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) { - return matchFailure(); + return failure(); } if (!random_uniform_op.dtype().isF32()) { - return matchFailure(); + return failure(); } typedef tensorflow::random::UniformDistribution< tensorflow::random::PhiloxRandom, float> @@ -149,7 +149,7 @@ PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite( random_uniform_op.output().getType().dyn_cast_or_null()) { if (auto ranked_output = output_type.dyn_cast_or_null()) { if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) { - return matchFailure(); + return failure(); } num_elements = output_type.getNumElements(); size_t offset = 0; @@ -165,13 +165,13 @@ PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite( } auto output_data = DenseFPElementsAttr::get(output_type, data); rewriter.replaceOpWithNewOp(op, output_type, output_data); - return matchSuccess(); + return success(); } } - return matchFailure(); + return failure(); } -PatternMatchResult ConvertTFConcatOp::matchAndRewrite( +LogicalResult ConvertTFConcatOp::matchAndRewrite( Operation* op, 
PatternRewriter& rewriter) const { auto tf_concat_op = cast(op); @@ -180,17 +180,17 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite( // Extract axis attribute from constant concat_dims tensor ElementsAttr axis; if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis))) - return matchFailure(); + return failure(); StringAttr fused_activation_function = StringAttr::get("NONE", rewriter.getContext()); rewriter.replaceOpWithNewOp( op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis), fused_activation_function); - return matchSuccess(); + return success(); } -PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite( +LogicalResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concat_op = cast(op); @@ -198,15 +198,14 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite( auto output_type = tf_concat_op.output().getType(); // Extract axis attribute from constant axis tensor ElementsAttr axis; - if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) - return matchFailure(); + if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); StringAttr fused_activation_function = StringAttr::get("NONE", rewriter.getContext()); rewriter.replaceOpWithNewOp( op, output_type, values, ExtractSingleElementAsInteger(axis), fused_activation_function); - return matchSuccess(); + return success(); } // The following is effectively: @@ -215,11 +214,11 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite( // ConstBoolAttrTrue:$transpose_b), // (TFL_FullyConnectedOp:$__0 $a, $b, // NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>; -PatternMatchResult ConvertTFMatMulOp::matchAndRewrite( +LogicalResult ConvertTFMatMulOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_matmul_op = cast(op); - if (tf_matmul_op.transpose_a()) return matchFailure(); - if (!tf_matmul_op.transpose_b()) return matchFailure(); + if (tf_matmul_op.transpose_a()) return failure(); + if (!tf_matmul_op.transpose_b()) return failure(); Type output_type = tf_matmul_op.getResult().getType(); // TODO(jpienaar): Follow up post shuffle discussion. 
@@ -230,10 +229,10 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite( op->getOperand(1), no_input, rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false)); rewriter.replaceOp(op, {fc_op.getResult(0)}); - return matchSuccess(); + return success(); } -PatternMatchResult ConvertTFPackOp::matchAndRewrite( +LogicalResult ConvertTFPackOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_pack_op = cast(op); @@ -245,10 +244,10 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite( rewriter.replaceOpWithNewOp(op, output_type, values, values_count, axis); - return matchSuccess(); + return success(); } -PatternMatchResult ConvertTFReshapeOp::matchAndRewrite( +LogicalResult ConvertTFReshapeOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_reshape_op = cast(op); @@ -269,10 +268,10 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite( } rewriter.replaceOpWithNewOp(op, tf_reshape_op.output().getType(), input, shape); - return matchSuccess(); + return success(); } -PatternMatchResult ConvertTFSplitOp::matchAndRewrite( +LogicalResult ConvertTFSplitOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_split_op = cast(op); @@ -284,10 +283,10 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite( rewriter.replaceOpWithNewOp(op, output_types, tf_split_op.split_dim(), tf_split_op.value(), num_split); - return matchSuccess(); + return success(); } -PatternMatchResult ConvertTFSplitVOp::matchAndRewrite( +LogicalResult ConvertTFSplitVOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_splitv_op = cast(op); @@ -299,7 +298,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, output_types, tf_splitv_op.value(), tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split); - return matchSuccess(); + return success(); } Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter, @@ -330,7 +329,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter, return rewriter.create(op->getLoc(), type, attr); } -PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite( +LogicalResult ConvertTFStridedSliceOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_strided_slice_op = cast(op); auto ranked_input_type = @@ -352,7 +351,7 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite( tf_strided_slice_op.new_axis_mask().getSExtValue()), rewriter.getI32IntegerAttr( tf_strided_slice_op.shrink_axis_mask().getSExtValue())); - return matchSuccess(); + return success(); } int num_input_dims = ranked_input_type.getRank(); @@ -382,10 +381,10 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite( tf_strided_slice_op.new_axis_mask().getSExtValue()), rewriter.getI32IntegerAttr( tf_strided_slice_op.shrink_axis_mask().getSExtValue())); - return matchSuccess(); + return success(); } -PatternMatchResult ConvertTFUnpackOp::matchAndRewrite( +LogicalResult ConvertTFUnpackOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_unpack_op = cast(op); @@ -397,7 +396,7 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite( auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue()); rewriter.replaceOpWithNewOp(op, output_types, input, num, axis); - return matchSuccess(); + return success(); } // MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. 
This attribute @@ -449,25 +448,25 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) { return true; } -PatternMatchResult ConvertTFMatrixDiagV2Op::matchAndRewrite( +LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { if (ConvertTFMatrixDiagV2orV3(op, &rewriter)) - return matchSuccess(); - return matchFailure(); + return success(); + return failure(); } -PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite( +LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { if (ConvertTFMatrixDiagV2orV3(op, &rewriter)) - return matchSuccess(); - return matchFailure(); + return success(); + return failure(); } // TF Lite doesn't support Assert, we just drop the assert from the graph. -PatternMatchResult ConvertTFAssertOp::matchAndRewrite( +LogicalResult ConvertTFAssertOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { rewriter.eraseOp(op); - return matchSuccess(); + return success(); } StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, @@ -545,7 +544,7 @@ StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, return rewriter->create(loc, scalar_type, attr); } -PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite( +LogicalResult ConvertTFReciprocalOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_reciprocal_op = cast(op); @@ -553,7 +552,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite( &rewriter, op->getLoc(), tf_reciprocal_op.x().getType().cast(), 1); if (!status_or_const_op.ok()) { - return matchFailure(); + return failure(); } StringAttr fused_activation_function = @@ -562,10 +561,10 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite( rewriter.replaceOpWithNewOp(op, status_or_const_op.ValueOrDie(), tf_reciprocal_op.x(), fused_activation_function); - return matchSuccess(); + return success(); } -PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite( +LogicalResult ConvertTFBroadcastToOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_broadcast_to_op = cast(op); auto element_type = tf_broadcast_to_op.input().getType().cast(); @@ -574,7 +573,7 @@ PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite( auto status_or_const_op = CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1); if (!status_or_const_op.ok()) { - return matchFailure(); + return failure(); } auto tfl_fill_op = rewriter.create( @@ -587,7 +586,7 @@ PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, output_type, tf_broadcast_to_op.input(), tfl_fill_op, fused_activation_function); - return matchSuccess(); + return success(); } // Legalize unidirectional sequence lstm. 
@@ -595,11 +594,11 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context) : RewritePattern(kUnidirectionalSequenceLstm, 1, context) {} - PatternMatchResult matchAndRewrite(Operation* op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { auto tflite_indices_attr = op->getAttrOfType(kTfLiteInputIndices); - if (!tflite_indices_attr) return matchFailure(); + if (!tflite_indices_attr) return failure(); SmallVector tflite_indices; for (auto index_attr : tflite_indices_attr.getValue()) { @@ -654,7 +653,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { // Rewire the output. op->getResult(2).replaceAllUsesWith(lstm_op.getResult()); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -663,24 +662,24 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern { explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context) : RewritePattern(kUnidirectionalSequenceRnn, 1, context) {} - PatternMatchResult matchAndRewrite(Operation* op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { auto tflite_indices_attr = op->getAttrOfType(kTfLiteInputIndices); - if (!tflite_indices_attr) return matchFailure(); + if (!tflite_indices_attr) return failure(); if (op->getNumOperands() != 5) { op->emitError() << "We're expecting 5 inputs for UnidirectionalSequenceRNN, only " << op->getNumOperands() << " provided"; - return matchFailure(); + return failure(); } if (op->getNumResults() != 2) { op->emitError() << "We're expecting 2 inputs for UnidirectionalSequenceRNN, only " << op->getNumResults() << " found"; - return matchFailure(); + return failure(); } // Populate inputs. @@ -714,7 +713,7 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern { op->getResult(1).replaceAllUsesWith(rnn_op.getResult()); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index a13490ddb9f..9df205d908c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -175,33 +175,33 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list, struct ConvertConst : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::ConstOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Verify that the opaque elements attribute contains tensor of type variant // and scalar shape. The variant type should hold a TensorList. 
auto opaque_attr = op.value().dyn_cast(); - if (!opaque_attr) return matchFailure(); + if (!opaque_attr) return failure(); tensorflow::Tensor tensor; if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok()) - return matchFailure(); - if (tensor.dtype() != tensorflow::DT_VARIANT) return matchFailure(); + return failure(); + if (tensor.dtype() != tensorflow::DT_VARIANT) return failure(); if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape())) - return matchFailure(); + return failure(); const tensorflow::TensorList *list = tensor.scalar()().get(); - if (!list) return matchFailure(); + if (!list) return failure(); // Verify output type is variant and contains exactly one ranked subtypes. auto variant_ty = getElementTypeOrSelf(op.getType()).dyn_cast(); - if (!variant_ty) return matchFailure(); + if (!variant_ty) return failure(); ArrayRef subtypes = variant_ty.getSubtypes(); - if (subtypes.size() != 1) return matchFailure(); + if (subtypes.size() != 1) return failure(); RankedTensorType list_element_ty = subtypes.front().dyn_cast(); - if (!list_element_ty) return matchFailure(); + if (!list_element_ty) return failure(); // Extract tensor elements for the TensorList and construct result type // based on the number of elements and element shape. @@ -225,9 +225,9 @@ struct ConvertConst : public OpConversionPattern { tensorflow::Tensor tensor(list->element_dtype, tensorflow::TensorShape(tf_shape)); auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter); - if (!attr_or.ok()) return matchFailure(); + if (!attr_or.ok()) return failure(); rewriter.replaceOpWithNewOp(op, attr_or.ValueOrDie()); - return matchSuccess(); + return success(); } // Extract individual tensor list element and combine them using the tf.Pack @@ -237,14 +237,14 @@ struct ConvertConst : public OpConversionPattern { values.reserve(tensors.size()); for (const tensorflow::Tensor &tensor : tensors) { auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter); - if (!attr_or.ok()) return matchFailure(); + if (!attr_or.ok()) return failure(); auto value = rewriter.create(loc, attr_or.ValueOrDie()); values.push_back(value); } rewriter.replaceOpWithNewOp( op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0)); - return matchSuccess(); + return success(); } }; @@ -264,7 +264,7 @@ struct ConvertTensorListSetItem // (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim = // 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice // $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::TensorListSetItemOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); @@ -311,7 +311,7 @@ struct ConvertTensorListSetItem rewriter.replaceOpWithNewOp( op, input.getType(), scalar_zero, ArrayRef({slice1, expanded_item, slice2})); - return matchSuccess(); + return success(); } }; @@ -330,7 +330,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern { // Rewrites the original op into `tf.fill`. The result tensor shape is // [num_element, element_shape]. All the values in the result tensor will be // initialized to 0. 
- PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( OpT op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Type dtype = op.element_dtype(); @@ -342,7 +342,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern { "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " "integer or 16-bit/32-bit/64-bit float type during TF Lite " "transformation pass"); - return ConversionPattern::matchFailure(); + return failure(); } Value element_shape = operands[0]; @@ -354,7 +354,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern { op.emitError( "requires element_shape to be 1D tensor during TF Lite " "transformation pass"); - return ConversionPattern::matchFailure(); + return failure(); } } @@ -434,7 +434,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern { auto zero = rewriter.create(loc, zero_type, zero_attr); rewriter.replaceOpWithNewOp(op, result_type, list_shape, zero); - return Pattern::matchSuccess(); + return success(); } }; @@ -472,7 +472,7 @@ struct ConvertTensorListPushBack : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::TensorListPushBackOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Value input_handle = operands[0]; @@ -498,7 +498,7 @@ struct ConvertTensorListPushBack rewriter.replaceOpWithNewOp( op, result_type, scalar_zero, ArrayRef({input_handle, expanded_item})); - return matchSuccess(); + return success(); } }; @@ -516,7 +516,7 @@ struct ConvertTensorListResize : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::TensorListResizeOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Value input_handle = operands[0]; @@ -582,7 +582,7 @@ struct ConvertTensorListResize /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op), /*output_shapes=*/rewriter.getStrArrayAttr({"{}"}), /*is_stateless=*/rewriter.getBoolAttr(true)); - return matchSuccess(); + return success(); } private: @@ -660,14 +660,14 @@ struct ConvertTensorListGetItem : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::TensorListGetItemOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Value input = operands[0]; Value index = operands[1]; rewriter.replaceOpWithNewOp(op, op.getType(), input, index, rewriter.getBoolAttr(true)); - return matchSuccess(); + return success(); } }; @@ -675,7 +675,7 @@ struct ConvertTensorListLength : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::TensorListLengthOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); @@ -687,7 +687,7 @@ struct ConvertTensorListLength rewriter.replaceOpWithNewOp( op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0), /*validate_indices=*/true_attr); - return matchSuccess(); + return success(); } }; @@ -695,7 +695,7 @@ struct ConvertTensorListStack : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::TensorListStackOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); @@ 
-713,7 +713,7 @@ struct ConvertTensorListStack !matchPattern(element_shape, m_Constant(&dense_elem_attr))) { // If no constant is spotted, just forward the operand. rewriter.replaceOp(op, {input}); - return matchSuccess(); + return success(); } RankedTensorType shape_type = @@ -726,20 +726,20 @@ struct ConvertTensorListStack RankedTensorType::get(output_shape, getElementTypeOrSelf(input)); rewriter.replaceOpWithNewOp(op, result_type, input, new_shape); - return matchSuccess(); + return success(); } }; struct ConvertIdentity : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::IdentityOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Value input = operands[0]; rewriter.replaceOpWithNewOp(op, input.getType(), operands, op.getAttrs()); - return matchSuccess(); + return success(); } }; @@ -804,7 +804,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { struct ConvertWhile : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( TF::WhileOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { llvm::SmallVector result_types; @@ -828,7 +828,7 @@ struct ConvertWhile : public OpConversionPattern { UpdateFunctionTypes(cloned); rewriter.replaceOp(op, cloned.getResults()); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index bc39c0cf74b..6137abfee4f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -211,18 +211,17 @@ DenseElementsAttr GetShape(Value output_val) { struct FuseFullyConnectedAndAdd : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TFL::AddOp add_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TFL::AddOp add_op, + PatternRewriter &rewriter) const override { // Match Add. DenseElementsAttr added_value; Value constant_val = add_op.rhs(); - if (!matchPattern(constant_val, m_Constant(&added_value))) - return matchFailure(); + if (!matchPattern(constant_val, m_Constant(&added_value))) return failure(); // Match Fully Connected. auto fc_op = dyn_cast_or_null(add_op.lhs().getDefiningOp()); - if (!fc_op) return matchFailure(); + if (!fc_op) return failure(); // Check if the constant RHS is either 0D (scalar), or a 1D with // `{num_channels}` shape. @@ -236,17 +235,17 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { if (constant_val_type.getRank() == 0) { is_scalar_rhs = true; } else if (constant_val_type.getRank() != 1) { - return matchFailure(); + return failure(); } Value filter = fc_op.filter(); Value bias = fc_op.bias(); ElementsAttr bias_value; const bool is_none_bias = bias.getType().isa(); - if (fc_op.fused_activation_function() != "NONE") return matchFailure(); + if (fc_op.fused_activation_function() != "NONE") return failure(); if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value))) - return matchFailure(); + return failure(); // Rewrite Location loc = fc_op.getLoc(); @@ -261,7 +260,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { // Filter must be a `2D` tensor with `{num_channels, num_features}` // shape. The following check is rejecting unknown rank (-1). 
if (filter_type.getRank() != 2) { - return matchFailure(); + return failure(); } int num_channels = filter_type.getShape()[0]; @@ -297,7 +296,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()), /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims())); - return matchSuccess(); + return success(); } }; @@ -305,13 +304,13 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { struct FuseFullyConnectedAndRelu : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TFL::ReluOp relu_op, + PatternRewriter &rewriter) const override { Operation *input = relu_op.getOperand().getDefiningOp(); - if (!isa_and_nonnull(input)) return matchFailure(); + if (!isa_and_nonnull(input)) return failure(); auto fully_connected_op = cast(input); if (fully_connected_op.fused_activation_function() != "NONE") - return matchFailure(); + return failure(); auto new_activation_func = rewriter.getStringAttr("RELU"); auto new_weights_format = @@ -323,7 +322,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern { fully_connected_op.filter(), fully_connected_op.bias(), new_activation_func, new_weights_format, new_keep_num_dims); - return matchSuccess(); + return success(); } }; @@ -332,25 +331,25 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern { struct FuseFullyConnectedAndMul : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TFL::MulOp mul_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TFL::MulOp mul_op, + PatternRewriter &rewriter) const override { // Mul. DenseElementsAttr cst; Value constant_val = mul_op.rhs(); - if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure(); + if (!matchPattern(constant_val, m_Constant(&cst))) return failure(); // Fully Connected. auto fc_op = dyn_cast_or_null(mul_op.lhs().getDefiningOp()); - if (!fc_op) return matchFailure(); + if (!fc_op) return failure(); Value filter = fc_op.filter(); Value bias = fc_op.bias(); ElementsAttr cst_tmp; - if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure(); + if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure(); if (!bias.getType().isa() && !matchPattern(bias, m_Constant(&cst_tmp))) - return matchFailure(); - if (fc_op.fused_activation_function() != "NONE") return matchFailure(); + return failure(); + if (fc_op.fused_activation_function() != "NONE") return failure(); // Broadcast the constant operand of Mul if it isn't compatible to the // filter input. 
We only support broadcasting the operand along the depth @@ -365,7 +364,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { normalized_shape, cst.getType().getElementType())); Type new_type = new_cst.getType(); if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) { - return matchFailure(); + return failure(); } auto new_op = rewriter.create(mul_op.getLoc(), new_type, new_cst); @@ -393,7 +392,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()), /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims())); - return matchSuccess(); + return success(); } }; @@ -425,36 +424,36 @@ template struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TFL::MulOp mul_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TFL::MulOp mul_op, + PatternRewriter &rewriter) const override { // Mul. Required 1-D rhs for batch normalization. DenseElementsAttr gamma_cst; Value gamma = mul_op.rhs(); - if (!matchPattern(gamma, m_Constant(&gamma_cst))) return matchFailure(); - if (gamma_cst.getType().getRank() != 1) return matchFailure(); + if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure(); + if (gamma_cst.getType().getRank() != 1) return failure(); // Affine op Operation *mul_op_lhs = mul_op.lhs().getDefiningOp(); auto fc_op = dyn_cast_or_null(mul_op_lhs); - if (!fc_op) return matchFailure(); + if (!fc_op) return failure(); Value filter = fc_op.filter(); Value bias = fc_op.bias(); // QDQs auto dq_op = dyn_cast_or_null(filter.getDefiningOp()); - if (!dq_op) return matchFailure(); + if (!dq_op) return failure(); auto q_op = dyn_cast_or_null(dq_op.input().getDefiningOp()); - if (!q_op) return matchFailure(); + if (!q_op) return failure(); filter = q_op.input(); // weight constant ElementsAttr cst_tmp; - if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure(); + if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure(); if (!bias.getType().isa() && !matchPattern(bias, m_Constant(&cst_tmp))) - return matchFailure(); - if (fc_op.fused_activation_function() != "NONE") return matchFailure(); + return failure(); + if (fc_op.fused_activation_function() != "NONE") return failure(); // Broadcast the constant operand of Mul if it isn't compatible to the // filter input. We only support broadcasting the operand along the depth @@ -469,7 +468,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst); broadcasted_gamma = rewriter.create(loc, mul_rhs); } else { - return matchFailure(); + return failure(); } // Rewrite filter constant. Since the folder of TFL::MulOp couldn't @@ -478,7 +477,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { rewriter.create(loc, filter, broadcasted_gamma).z(); // Update the scale in the quantize op. auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst); - if (!new_qtype) return matchFailure(); + if (!new_qtype) return failure(); rewriter.replaceOpWithNewOp(q_op, new_qtype.getValue(), new_filter, new_qtype); @@ -491,7 +490,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { // Remove the tailing mul op. 
mul_op.replaceAllUsesWith(fc_op.getResult()); - return matchSuccess(); + return success(); } }; @@ -504,20 +503,19 @@ template struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineOpType fc_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineOpType fc_op, + PatternRewriter &rewriter) const override { // Binary op. Operation *binary_op = fc_op.input().getDefiningOp(); - if (!binary_op || binary_op->getNumOperands() != 2) - return this->matchFailure(); + if (!binary_op || binary_op->getNumOperands() != 2) return failure(); // We only handle the cases the RHS is a scalar. // TODO(fengliuai): Currently the canonicalizer pass couldn't guarantee that // the constant operands are on the RHS, we need to consider LHS constant // operand if necessary. DenseFPElementsAttr cst; if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst))) - return this->matchFailure(); - if (cst.getNumElements() != 1) return this->matchFailure(); + return failure(); + if (cst.getNumElements() != 1) return failure(); APFloat cst_value = *cst.float_value_begin(); // Affine op. @@ -527,21 +525,21 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { if (!matchPattern(filter, m_Constant(&filter_cst))) { // The filter maybe quantized, then we should set it to the real constant. auto dq = llvm::dyn_cast_or_null(filter.getDefiningOp()); - if (!dq) return this->matchFailure(); + if (!dq) return failure(); auto q = llvm::dyn_cast_or_null(dq.input().getDefiningOp()); if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) { - return this->matchFailure(); + return failure(); } filter = q.input(); } if (!bias.getType().isa() && !matchPattern(bias, m_Constant(&bias_cst))) - return this->matchFailure(); + return failure(); ShapedType filter_type = filter_cst.getType(); if (llvm::isa(binary_op) || llvm::isa(binary_op)) { auto padding = fc_op.template getAttrOfType("padding"); - if (padding && padding.getValue() != "VALID") return this->matchFailure(); + if (padding && padding.getValue() != "VALID") return failure(); // The fusion of add/sub is actually applying the following // transformation: @@ -568,7 +566,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { bias_cst.float_value_begin(), bias_cst.float_value_end()); } else { - return this->matchFailure(); + return failure(); } int64_t flatten_index = 0; @@ -610,9 +608,9 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { fc_op.setOperand(1, new_filter_op); } } else { - return this->matchFailure(); + return failure(); } - return this->matchSuccess(); + return success(); } private: @@ -638,18 +636,17 @@ struct ConvertTrivialTransposeOpToReshapeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TFL::TransposeOp transpose_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op, + PatternRewriter &rewriter) const override { auto input_type = transpose_op.x().getType().cast(); auto output_type = transpose_op.y().getType().cast(); // It's possible to know if the transformation is safe only if the input // & output shapes are fully known and permutation is a constant. 
if (!input_type.hasStaticShape() || !output_type.hasStaticShape()) - return matchFailure(); + return failure(); Value perm = transpose_op.perm(); DenseElementsAttr perm_values_attr; - if (!matchPattern(perm, m_Constant(&perm_values_attr))) - return matchFailure(); + if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure(); auto input_shape = input_type.getShape(); SmallVector perm_values; @@ -674,7 +671,7 @@ struct ConvertTrivialTransposeOpToReshapeOp } } if (old_major_index_ordering != new_major_index_ordering) { - return matchFailure(); + return failure(); } // Rewrite. @@ -693,7 +690,7 @@ struct ConvertTrivialTransposeOpToReshapeOp rewriter.replaceOpWithNewOp( transpose_op, transpose_op.y().getType(), transpose_op.x(), new_shape); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 83ecf0be820..1c598fec08e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -75,14 +75,14 @@ class FoldIfOp : public OpRewritePattern { explicit FoldIfOp(MLIRContext* context, FuncSet* inlined_funcs) : OpRewritePattern(context), inlined_funcs_(inlined_funcs) {} - PatternMatchResult matchAndRewrite(TF::IfOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(TF::IfOp op, + PatternRewriter& rewriter) const override { // This pattern is restricted to if ops in functions with exactly one block // and therefore one terminator op. So, that function return type can be // updated if operands' shapes change after inlining. Without this // restriction, it would require tensor cast ops. FuncOp parent_op = op.getParentOfType(); - if (parent_op.getBlocks().size() != 1) return matchFailure(); + if (parent_op.getBlocks().size() != 1) return failure(); // Find the then and else branch functions. SymbolTable table(op.getParentOfType()); @@ -98,18 +98,18 @@ class FoldIfOp : public OpRewritePattern { inlined_funcs_->insert(then_branch); inlined_funcs_->insert(else_branch); rewriter.eraseOp(op.getOperation()); - return matchSuccess(); + return success(); } // Extract the constant cond value. DenseElementsAttr cond; - if (!matchPattern(op.cond(), m_Constant(&cond))) return matchFailure(); + if (!matchPattern(op.cond(), m_Constant(&cond))) return failure(); // TODO(hinsu): Handle constants that are not scalar booleans. auto cond_type = cond.getType().dyn_cast(); if (!cond_type || !cond_type.getShape().equals({}) || !cond_type.getElementType().isInteger(/*width=*/1)) - return matchFailure(); + return failure(); // Identify the branch to inline. bool cond_value = (*cond.int_value_begin()).getSExtValue(); @@ -118,7 +118,7 @@ class FoldIfOp : public OpRewritePattern { // Make sure that the function has exactly one block to simplify inlining. // TFLite doesn't use control flow with blocks so functions with more than // one blocks are not encountered in practice. - if (func.getBody().getBlocks().size() != 1) return matchFailure(); + if (func.getBody().getBlocks().size() != 1) return failure(); BlockAndValueMapping mapper; for (int i = 0, e = func.getNumArguments(); i != e; ++i) @@ -149,7 +149,7 @@ class FoldIfOp : public OpRewritePattern { // of the function. 
inlined_funcs_->insert(then_branch); inlined_funcs_->insert(else_branch); - return matchSuccess(); + return success(); } private: diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 7592f462f6b..1ff321780a4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -121,12 +121,12 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp MLIRContext *ctx) : OpRewritePattern(ctx) {} - PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, + PatternRewriter &rewriter) const override { // We don't want to insert quantize/dequantize if the quantize op exists. auto res = tf_op.outputs(); if (!res.hasOneUse() || isa(*res.user_begin())) - return this->matchFailure(); + return failure(); // Extract the min/max constant values from the operands. We also consider // a special case that there are tf.Identity ops between the min/max @@ -137,8 +137,8 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp min = id1.input(); if (auto id2 = dyn_cast_or_null(max.getDefiningOp())) max = id2.input(); - if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure(); - if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure(); + if (!matchPattern(min, m_Constant(&min_value))) return failure(); + if (!matchPattern(max, m_Constant(&max_value))) return failure(); int quant_dim = -1; if (PerAxis) { @@ -155,7 +155,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp TypeAttr qtype = quant::GetQuantizedTypeAttr( rewriter, res_type, min_value, max_value, quant_dim, num_bits, narrow_range, /*is_signed=*/false); - if (!qtype) this->matchFailure(); + if (!qtype) failure(); // Finally, use the quantization parameter to create the quantize and // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp @@ -168,7 +168,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp value.replaceAllUsesWith(dequantize); quantize.getOperation()->replaceUsesOfWith(dequantize, value); - return this->matchSuccess(); + return success(); } }; @@ -208,8 +208,8 @@ struct ConvertTFConvOp : public RewritePattern { : RewritePattern(TFConvOpType::getOperationName(), 1, context), intAttrOne(Builder(context).getI32IntegerAttr(1)) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { // Assumes TensorFlow convolution op is already verified to be // in valid form. 
@@ -223,10 +223,10 @@ struct ConvertTFConvOp : public RewritePattern { TFConvOpType tf_op = cast(op); if (!TFTypeIsFloatTensor(tf_op.input()) || !TFDataFormatIsNHWC(op)) - return matchFailure(); + return failure(); IntegerAttr height, width; - if (!TFIntListIs1XY1(op, "strides", &height, &width)) return matchFailure(); + if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure(); ConvertTFConvOpMatchState state; state.stride_height = height; @@ -242,14 +242,14 @@ struct ConvertTFConvOp : public RewritePattern { state.dilation_width_factor = intAttrOne; } - if (!TFPaddingIsSameOrValid(op, &state.padding)) return matchFailure(); + if (!TFPaddingIsSameOrValid(op, &state.padding)) return failure(); // Additionally, we require the filter operand to be of 4-D tensor type so // that we can extract info from the shape (e.g., for constructing bias // tensor, for setting depth_multiplier attribute, etc.). auto filter = tf_op.filter(); auto filter_type = filter.getType().template dyn_cast(); - if (!filter_type || filter_type.getRank() != 4) return matchFailure(); + if (!filter_type || filter_type.getRank() != 4) return failure(); // TensorFlow convolution op only has two inputs, while the TFLite one has // three, with the bias vector marked as optional. However, TOCO has a @@ -274,7 +274,7 @@ struct ConvertTFConvOp : public RewritePattern { bias); rewriter.replaceOp(op, conv_op.getResult()); - return matchSuccess(); + return success(); } const IntegerAttr intAttrOne; @@ -418,8 +418,8 @@ struct ConvertTFStridedSlice : public RewritePattern { explicit ConvertTFStridedSlice(MLIRContext *context) : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {} - PatternMatchResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask, - PatternRewriter &rewriter) const { + LogicalResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask, + PatternRewriter &rewriter) const { TF::StridedSliceOp strided_slice_op = llvm::cast(op); // Insert a new reshape op. 
@@ -474,11 +474,11 @@ struct ConvertTFStridedSlice : public RewritePattern { rewriter.getI64IntegerAttr(0), rewriter.getIntegerAttr(attribute_type, strided_slice_op.shrink_axis_mask())); - return matchSuccess(); + return success(); } - PatternMatchResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask, - PatternRewriter &rewriter) const { + LogicalResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask, + PatternRewriter &rewriter) const { TF::StridedSliceOp strided_slice_op = llvm::cast(op); DenseIntElementsAttr begin_dense_elem_attr; @@ -486,7 +486,7 @@ struct ConvertTFStridedSlice : public RewritePattern { auto begin_ranked_attr_type = begin.getType().dyn_cast(); if (!begin_ranked_attr_type || !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) { - return matchFailure(); + return failure(); } DenseIntElementsAttr end_dense_elem_attr; @@ -494,7 +494,7 @@ struct ConvertTFStridedSlice : public RewritePattern { auto end_ranked_attr_type = end.getType().dyn_cast(); if (!end_ranked_attr_type || !matchPattern(end, m_Constant(&end_dense_elem_attr))) { - return matchFailure(); + return failure(); } DenseIntElementsAttr stride_dense_elem_attr; @@ -503,7 +503,7 @@ struct ConvertTFStridedSlice : public RewritePattern { stride.getType().dyn_cast(); if (!stride_ranked_attr_type || !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) { - return matchFailure(); + return failure(); } Value input = strided_slice_op.input(); @@ -516,7 +516,7 @@ struct ConvertTFStridedSlice : public RewritePattern { const ArrayRef begin_shape = begin_type.getShape(); const int begin_dim = begin_shape.size(); - if (begin_dim != 1) return matchFailure(); + if (begin_dim != 1) return failure(); const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1; @@ -586,11 +586,11 @@ struct ConvertTFStridedSlice : public RewritePattern { strided_slice_op.new_axis_mask()), rewriter.getIntegerAttr(attribute_type, strided_slice_op.shrink_axis_mask())); - return matchSuccess(); + return success(); } - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { // TODO(renjieliu): Consider expand the transformation for shrink // mask as well. TF::StridedSliceOp strided_slice_op = llvm::cast(op); @@ -606,7 +606,7 @@ struct ConvertTFStridedSlice : public RewritePattern { if (ellipsis_mask != 0) { return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter); } - return matchFailure(); + return failure(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 38fb3154c48..163f4562d49 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -514,16 +514,16 @@ namespace { struct DropEmptyLaunch : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(LaunchOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(LaunchOp op, + PatternRewriter& rewriter) const override { Block& block = op.GetBody(); // Check if launch only has a return. - if (&block.front() != &block.back()) return matchFailure(); + if (&block.front() != &block.back()) return failure(); // Map launch results to return operands. 
rewriter.replaceOp(op, block.front().getOperands()); - return matchSuccess(); + return success(); } }; } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 85d87a56f01..36b747b7fb7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -1067,16 +1067,16 @@ bool HasSingleOpInBlock(Block *block) { struct DropEmptyGraph : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(GraphOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(GraphOp op, + PatternRewriter &rewriter) const override { Block &block = op.GetBody(); // Check if graph only has one fetch. - if (&block.front() != &block.back()) return matchFailure(); + if (&block.front() != &block.back()) return failure(); // Map graph results to fetch operands. rewriter.replaceOp(op, op.GetFetch().fetches()); - return matchSuccess(); + return success(); } }; @@ -1086,11 +1086,11 @@ struct DropEmptyGraph : public OpRewritePattern { struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(GraphOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(GraphOp op, + PatternRewriter &rewriter) const override { Block &block = op.GetBody(); // Check if graph only has one island. - if (!HasSingleOpInBlock(&block)) return matchFailure(); + if (!HasSingleOpInBlock(&block)) return failure(); FetchOp fetch_op = op.GetFetch(); auto island_op = llvm::cast(block.front()); @@ -1120,7 +1120,7 @@ struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern { std::prev(island_body.end())); rewriter.replaceOp(op, new_rets); - return matchSuccess(); + return success(); } }; } // anonymous namespace @@ -1142,18 +1142,18 @@ struct DropEmptyIslandNoOperandNoDataResult : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(IslandOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(IslandOp op, + PatternRewriter &rewriter) const override { if (op.getNumOperands() != 0 || op.getNumResults() != 1 || !HasSingleOpInBlock(&op.GetBody())) - return matchFailure(); + return failure(); for (auto &use : llvm::make_early_inc_range(op.control().getUses())) use.getOwner()->eraseOperand(use.getOperandNumber()); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -1165,16 +1165,16 @@ struct DropEmptyIslandNoOperandOneDataResult : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(IslandOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(IslandOp op, + PatternRewriter &rewriter) const override { if (op.getNumOperands() != 0 || op.getNumResults() != 2 || !op.control().use_empty() || !HasSingleOpInBlock(&op.GetBody())) - return matchFailure(); + return failure(); rewriter.replaceOp(op, {op.GetYield().getOperand(0), nullptr}); - return matchSuccess(); + return success(); } }; @@ -1199,16 +1199,16 @@ namespace { struct DropEmptyControlTrigger : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ControlTriggerOp op, - PatternRewriter &rewriter) const override { - if (op.getNumOperands() != 0) return matchFailure(); + LogicalResult matchAndRewrite(ControlTriggerOp op, + 
PatternRewriter &rewriter) const override { + if (op.getNumOperands() != 0) return failure(); for (auto &use : llvm::make_early_inc_range(op.control().getUses())) use.getOwner()->eraseOperand(use.getOperandNumber()); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 9cec3641d0a..008b18aafd5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -580,16 +580,16 @@ namespace { struct AssertWithTrue : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AssertOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AssertOp op, + PatternRewriter &rewriter) const override { ElementsAttr cst; if (matchPattern(op.condition(), m_Constant(&cst))) { if (cst.getValue({}).getValue()) { rewriter.eraseOp(op); - return matchSuccess(); + return success(); } } - return matchFailure(); + return failure(); } }; } // namespace @@ -3085,15 +3085,15 @@ namespace { // function and can be removed. class ToBoolOfZeroDBoolTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ToBoolOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(ToBoolOp op, + PatternRewriter &rewriter) const override { if (auto type = op.getOperand().getType().dyn_cast()) { if (type.getRank() == 0 && type.getElementType().isInteger(1)) { rewriter.replaceOp(op, op.getOperand()); - return matchSuccess(); + return success(); } } - return matchFailure(); + return failure(); } }; } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index 6cd82d1472d..0663ad8c52e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -60,7 +60,7 @@ void BatchMatMulToEinsumPass::runOnFunction() { } // namespace template -PatternMatchResult +LogicalResult ConvertTFBatchMatMulToEinsumOp::matchAndRewrite( BatchMatMulOpType op, PatternRewriter& rewriter) const { Value input_lhs = op.x(); @@ -68,18 +68,18 @@ ConvertTFBatchMatMulToEinsumOp::matchAndRewrite( if (!input_lhs.getType().isa()) { // LHS must be a ranked tensor type - return this->matchFailure(); + return failure(); } if (!input_rhs.getType().isa()) { // RHS must be a ranked tensor type - return this->matchFailure(); + return failure(); } auto lhs_type = input_lhs.getType().dyn_cast(); auto rhs_type = input_rhs.getType().dyn_cast(); if (!lhs_type || !rhs_type) { - return this->matchFailure(); + return failure(); } auto lhs_shape = lhs_type.getShape(); @@ -92,7 +92,7 @@ ConvertTFBatchMatMulToEinsumOp::matchAndRewrite( const int dims_b = rhs_shape.size(); if (dims_a < 2 || dims_b < 2) { // Both inputs must have rank >= 2 - return this->matchFailure(); + return failure(); } // einsum equation for batchmatmul @@ -110,7 +110,7 @@ ConvertTFBatchMatMulToEinsumOp::matchAndRewrite( /*inputs=*/ValueRange(inputs), /*equation=*/equation); - return this->matchSuccess(); + return success(); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h index 
cd836892ae9..b0a1b59fb94 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h @@ -32,7 +32,7 @@ class ConvertTFBatchMatMulToEinsumOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( BatchMatMulOpType op, PatternRewriter& rewriter) const override; // NOLINT }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 5410ce4faf7..833b52e3e89 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -167,7 +167,7 @@ TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, } // namespace -PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( +LogicalResult ConvertTFEinsumOp::matchAndRewrite( TF::EinsumOp op, PatternRewriter& rewriter) const { Type output_type = op.getResult().getType(); Value lhs = op.getOperand(0); @@ -176,11 +176,11 @@ PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( if (!lhs.getType().isa()) { // LHS must be a ranked tensor type - return matchFailure(); + return failure(); } if (!rhs.getType().isa()) { // RHS must be a ranked tensor type - return matchFailure(); + return failure(); } auto lhs_type = lhs.getType().cast(); @@ -190,14 +190,14 @@ PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( // Currently only support static shapes. if (!(lhs_type.hasStaticShape() && rhs_type.hasStaticShape())) { - return matchFailure(); + return failure(); } // Currently support use cases of LHS, RHS dims = 3 or 4 const int dims_lhs = lhs_shape.size(); const int dims_rhs = rhs_shape.size(); if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) { - return matchFailure(); + return failure(); } EinsumEquation einsum_eqn = tokenizeAndParse(op.equation()); @@ -207,7 +207,7 @@ PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); rewriter.replaceOp(op, bmm_op.getResult()); - return matchSuccess(); + return success(); } if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) { // Case "BFD,DNH->BFNH" @@ -235,7 +235,7 @@ PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim1, rhs_dim2}, bmm_element_type, loc, &rewriter); rewriter.replaceOp(op, {final_reshape.getResult()}); - return matchSuccess(); + return success(); } if (einsum_eqn == EinsumEquation::FourDMatrixDotProd) { // Case "BFND,NDH->BFH" @@ -259,7 +259,7 @@ PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( loc, ArrayRef{output_type}, reshaped_lhs, reshaped_rhs, rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); rewriter.replaceOp(op, {bmm_op.getResult()}); - return matchSuccess(); + return success(); } if (einsum_eqn == EinsumEquation::FourDBatchMatMul) { // Case "BFNH,BTNH->BNFT" @@ -271,9 +271,9 @@ PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); rewriter.replaceOp(op, {bmm_op.getResult()}); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } // Transform Einsum to other TF Ops for the supported variants. 
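Every hunk in this patch performs the same mechanical migration: rewrite patterns stop returning the old PatternMatchResult via matchSuccess()/matchFailure() and instead return LogicalResult via the free functions success()/failure(). For reference only (this sketch is not part of the patch; the pattern name and the trivial match condition are invented for illustration), a pattern written against the updated API looks like this:

    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Support/LogicalResult.h"

    namespace {
    // Illustrative pattern only: erases an op whose results are all unused.
    // A real pattern would also have to prove the op has no side effects;
    // the point here is just the LogicalResult return convention.
    template <typename OpTy>
    struct EraseTriviallyDeadOp : public mlir::OpRewritePattern<OpTy> {
      using mlir::OpRewritePattern<OpTy>::OpRewritePattern;

      mlir::LogicalResult matchAndRewrite(
          OpTy op, mlir::PatternRewriter &rewriter) const override {
        // Old API: return matchFailure(); / return matchSuccess();
        // New API: plain failure() / success(), no PatternMatchResult object.
        if (!op.getOperation()->use_empty()) return mlir::failure();
        rewriter.eraseOp(op);
        return mlir::success();
      }
    };
    }  // namespace

The ConversionPattern/OpConversionPattern variants touched later in this patch (e.g. in hlo_legalize_to_lhlo.cc) follow the same convention, differing only in taking a ConversionPatternRewriter and an operand list in matchAndRewrite.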
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h index 77b0c72aaef..734d22432a1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h @@ -45,8 +45,8 @@ struct ConvertTFEinsumOp : public OpRewritePattern { explicit ConvertTFEinsumOp(MLIRContext* context) : OpRewritePattern(context) {} - PatternMatchResult matchAndRewrite(TF::EinsumOp op, - PatternRewriter& rewriter) const override; + LogicalResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter& rewriter) const override; }; } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index 0a8d261ee39..de830d879dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -60,10 +60,10 @@ class GpuOpFusionPass : public FunctionPass { struct ReluToFusedBatchNorm : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ReluOp relu_op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(ReluOp relu_op, + PatternRewriter &rewriter) const override { Operation *relu_input = relu_op.features().getDefiningOp(); - if (!relu_input) return matchFailure(); + if (!relu_input) return failure(); auto batch_norm = dyn_cast_or_null(relu_input); AddV2Op add_op; Value side_input; @@ -71,7 +71,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { // We don't have a FusedBatchNorm as input to the ReLu, but we can get // through an AddV2 as well. add_op = dyn_cast_or_null(relu_input); - if (!add_op) return matchFailure(); + if (!add_op) return failure(); batch_norm = dyn_cast_or_null(add_op.x().getDefiningOp()); @@ -81,13 +81,13 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { // Didn't get a FusedBatchNorm on the LHS of the AddV2, try the RHS. batch_norm = dyn_cast_or_null(add_op.y().getDefiningOp()); - if (!batch_norm) return matchFailure(); + if (!batch_norm) return failure(); side_input = add_op.x(); } } assert(batch_norm); - if (batch_norm.is_training()) return matchFailure(); - if (!batch_norm.y().hasOneUse()) return matchFailure(); + if (batch_norm.is_training()) return failure(); + if (!batch_norm.y().hasOneUse()) return failure(); // Build the newly fused operation to replace the batch norm OperationState state(batch_norm.getLoc(), @@ -110,7 +110,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { rewriter.replaceOp(add_op, op->getResult(0)); } - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index e5676239e93..9268881cb71 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -125,12 +125,11 @@ class LowerAddNOp : public OpRewritePattern { explicit LowerAddNOp(MLIRContext *context) : OpRewritePattern(context) {} - PatternMatchResult matchAndRewrite(TF::AddNOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::AddNOp op, + PatternRewriter &rewriter) const override { // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't // support variant type so variant types require special handling. 
- if (getElementTypeOrSelf(op.getType()).isa()) - return matchFailure(); + if (getElementTypeOrSelf(op.getType()).isa()) return failure(); // TODO(hinsu): Improve parallelism by splitting operands in two halves and // accumulating them first. @@ -140,7 +139,7 @@ class LowerAddNOp : public OpRewritePattern { } rewriter.replaceOp(op, result); - return matchSuccess(); + return success(); } }; @@ -176,13 +175,13 @@ class LowerDynamicStitchOp : public OpRewritePattern { explicit LowerDynamicStitchOp(MLIRContext *context) : OpRewritePattern(context) {} - PatternMatchResult matchAndRewrite(DynamicStitchOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(DynamicStitchOp op, + PatternRewriter &rewriter) const override { // Static output type is used to compute intermediate values. Note that the // output type doesn't have to be static but if input types and indices are // constant, then the output type can be statically determined. RankedTensorType out_ty = op.getType().dyn_cast(); - if (!out_ty || !out_ty.hasStaticShape()) return matchFailure(); + if (!out_ty || !out_ty.hasStaticShape()) return failure(); // Extract out all the constant indices' attributes and verify that data // types are static. @@ -193,11 +192,11 @@ class LowerDynamicStitchOp : public OpRewritePattern { Value data = std::get<1>(it); DenseIntElementsAttr index_attr; - if (!matchPattern(index, m_Constant(&index_attr))) return matchFailure(); + if (!matchPattern(index, m_Constant(&index_attr))) return failure(); indices.push_back(index_attr); RankedTensorType data_ty = data.getType().dyn_cast(); - if (!data_ty || !data_ty.hasStaticShape()) return matchFailure(); + if (!data_ty || !data_ty.hasStaticShape()) return failure(); } // Compute type of each of the items and shape to use while reshaping inputs @@ -235,7 +234,7 @@ class LowerDynamicStitchOp : public OpRewritePattern { auto axis = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); rewriter.replaceOpWithNewOp(op, op.getType(), values, axis); - return matchSuccess(); + return success(); } }; @@ -266,15 +265,15 @@ class LowerInvertPermutationOp explicit LowerInvertPermutationOp(MLIRContext *context) : OpRewritePattern(context) {} - PatternMatchResult matchAndRewrite(TF::InvertPermutationOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::InvertPermutationOp op, + PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto x_type = op.x().getType().cast(); Type int_type = x_type.getElementType(); // Could be i32 or i64. // x input must have static shape. 
if (!x_type.hasStaticShape()) { - return matchFailure(); + return failure(); } auto result_type = x_type; @@ -298,7 +297,7 @@ class LowerInvertPermutationOp rewriter.replaceOpWithNewOp( op, result_type, op.x(), indices, updates); - return matchSuccess(); + return success(); } }; @@ -317,8 +316,8 @@ class LowerPackOp : public OpRewritePattern { explicit LowerPackOp(MLIRContext *context) : OpRewritePattern(context) {} - PatternMatchResult matchAndRewrite(TF::PackOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::PackOp op, + PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto axis_value = rewriter.create( loc, @@ -344,7 +343,7 @@ class LowerPackOp : public OpRewritePattern { rewriter.replaceOpWithNewOp(op, op.getType(), expanded_inputs, axis_value); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index 27939cba63c..c6223ed13f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -186,18 +186,18 @@ TF::PackOp ConvertTFBatchMatMulOp::createMatMulOps( } template -PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( +LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( BatchMatMulOpType op, PatternRewriter& rewriter) const { Value input_lhs = op.x(); Value input_rhs = op.y(); if (!input_lhs.getType().isa()) { // LHS must be a ranked tensor type - return this->matchFailure(); + return failure(); } if (!input_rhs.getType().isa()) { // RHS must be a ranked tensor type - return this->matchFailure(); + return failure(); } auto lhs_type = input_lhs.getType().cast(); @@ -207,7 +207,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( if (element_type != rhs_type.getElementType()) { // The element type of LHS must be the same with element type of RHS - return this->matchFailure(); + return failure(); } auto lhs_shape = lhs_type.getShape(); @@ -220,7 +220,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( const int dims_b = rhs_shape.size(); if (dims_a < 2 || dims_b < 2) { // Both inputs must have rank >= 2 - return this->matchFailure(); + return failure(); } // Transpose LHS input if necessary. @@ -241,7 +241,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) { // Input dimensions must be compatible for multiplication. - return this->matchFailure(); + return failure(); } if (dims_a == 2 && dims_b == 2) { @@ -254,19 +254,19 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( /*b=*/input_rhs, /*transpose_a=*/false_attr, /*transpose_b=*/false_attr); - return this->matchSuccess(); + return success(); } // Input dimensions must be defined. MatMulBCast does not support partial // shapes. for (auto dim : lhs_shape) { if (dim == -1) { - return this->matchFailure(); + return failure(); } } for (auto dim : rhs_shape) { if (dim == -1) { - return this->matchFailure(); + return failure(); } } // Ensure that batch shapes are broadcastable. @@ -277,7 +277,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( if (!bcast.IsValid()) { // Input batch dimensions must be broadcastable - return this->matchFailure(); + return failure(); } // Compute slices for each batch in the LHS and RHS. 
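A small wrinkle visible in the templated patterns (ConvertTFBatchMatMulToEinsumOp earlier and ConvertTFBatchMatMulOp here): the old matchFailure()/matchSuccess() were inherited members of the dependent base class OpRewritePattern<OpTy>, so inside a class template they had to be spelled this->matchFailure(). Their replacements are free functions in namespace mlir, which is why the this-> qualification disappears in these hunks. Schematically (lines paraphrased from the hunks above, shown only to illustrate the change):

    // before: member of the dependent base class, this-> required in a class template
    if (dims_a < 2 || dims_b < 2) return this->matchFailure();
    // after: free function in namespace mlir, no qualification needed
    if (dims_a < 2 || dims_b < 2) return failure();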
@@ -302,7 +302,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( auto reshape_op = createReshapeOp(pack_op.output(), result_shape, element_type, loc, rewriter); rewriter.replaceOp(op, reshape_op.output()); - return this->matchSuccess(); + return success(); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h index f3dc6d10503..c725930a484 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h @@ -49,8 +49,8 @@ class ConvertTFBatchMatMulOp : public OpRewritePattern { int rows, int cols, Type element_type, Location loc, PatternRewriter& rewriter); - PatternMatchResult matchAndRewrite(BatchMatMulOpType op, - PatternRewriter& rewriter) const override; + LogicalResult matchAndRewrite(BatchMatMulOpType op, + PatternRewriter& rewriter) const override; }; } // namespace TF diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 023ab46a66f..87f6eaecc52 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -500,21 +500,21 @@ struct ExtractElementFromScalarsToDimensionTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ExtractElementOp extract, - PatternRewriter& rewriter) const override { - if (extract.indices().size() != 1) return matchFailure(); + LogicalResult matchAndRewrite(ExtractElementOp extract, + PatternRewriter& rewriter) const override { + if (extract.indices().size() != 1) return failure(); if (auto scalars_to_tensor = dyn_cast_or_null( extract.aggregate().getDefiningOp())) { APInt index; if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) { - return matchFailure(); + return failure(); } rewriter.replaceOp(extract, scalars_to_tensor.getOperand(index.getZExtValue())); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index cc6ca472c23..eb6e7e1cd3d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -123,7 +123,7 @@ class HloToLhloOpConverter : public ConversionPattern { explicit HloToLhloOpConverter(MLIRContext* context) : ConversionPattern(HloOpTy::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { const auto& original_results = op->getResults(); @@ -132,7 +132,7 @@ class HloToLhloOpConverter : public ConversionPattern { RankedTensorType resultType = result.value().getType().dyn_cast(); if (!resultType) { - return matchFailure(); + return failure(); } if (resultType.hasStaticShape()) { buffer_args.push_back( @@ -140,10 +140,10 @@ class HloToLhloOpConverter : public ConversionPattern { } else { SmallVector results_shape; auto shape_type_op = dyn_cast(op); - if (!shape_type_op) return matchFailure(); + if (!shape_type_op) return failure(); if (failed( shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) - return matchFailure(); + return failure(); buffer_args.push_back(InsertDynamicAllocAndDealloc( op->getLoc(), result.value(), 
results_shape.front(), &rewriter)); } @@ -151,7 +151,7 @@ class HloToLhloOpConverter : public ConversionPattern { rewriter.create>(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); - return matchSuccess(); + return success(); } }; @@ -160,13 +160,13 @@ struct HloToLhloDynamicBroadcastInDimOpConverter public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); auto broadcast_dimensions = op.broadcast_dimensions(); if (!broadcast_dimensions.hasValue()) { - return matchFailure(); + return failure(); } Value resultBuffer = InsertDynamicAllocAndDealloc( loc, op.getResult(), op.output_dimensions(), &rewriter); @@ -175,7 +175,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter rewriter.replaceOp(op, {resultBuffer}); - return matchSuccess(); + return success(); } }; @@ -184,16 +184,16 @@ struct HloToLhloReduceOpConverter public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( xla_hlo::ReduceOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); // TODO(b/137624192) Implement variadic reduce. - if (op.getNumResults() != 1) return matchFailure(); + if (op.getNumResults() != 1) return failure(); if (op.getParentRegion()->getBlocks().size() != 1) { op.emitOpError() << "tensor to buffer conversion expects a single block " "in the region containing the operation"; - return matchFailure(); + return failure(); } const auto& original_results = op.getResults(); SmallVector buffer_args(operands.begin(), operands.end()); @@ -230,7 +230,7 @@ struct HloToLhloReduceOpConverter rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); - return matchSuccess(); + return success(); } }; @@ -238,11 +238,11 @@ class HloToLhloTensorLoadOpConverter : public ConversionPattern { public: explicit HloToLhloTensorLoadOpConverter(MLIRContext* context) : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOp(op, operands); - return matchSuccess(); + return success(); } }; @@ -252,12 +252,12 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { explicit HloToLhloTensorStoreOpConverter(MLIRContext* context) : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOpWithNewOp( op, llvm::None, operands.front(), operands.back()); - return matchSuccess(); + return success(); } }; @@ -373,13 +373,13 @@ class HloToLhloFuncOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { if (funcOp.getBody().getBlocks().size() > 1) { funcOp.emitOpError() << "tensor to buffer conversion expects a single " "block in the region containing the operation"; - return matchFailure(); + return failure(); } auto funcType = funcOp.getType(); @@ -396,7 +396,7 @@ class 
HloToLhloFuncOpConverter : public OpConversionPattern { rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); rewriter.applySignatureConversion(&funcOp.getBody(), conversion); }); - return matchSuccess(); + return success(); } }; @@ -406,7 +406,7 @@ class StdToLhloReturnOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( mlir::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { auto numReturnValues = returnOp.getNumOperands(); @@ -426,7 +426,7 @@ class StdToLhloReturnOpConverter : public OpConversionPattern { if (dealloc == nullptr) { returnOp.emitOpError() << "Missing dealloc for operand " << operand.index(); - return matchFailure(); + return failure(); } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(dealloc); @@ -434,7 +434,7 @@ class StdToLhloReturnOpConverter : public OpConversionPattern { funcOp.getArgument(returnArgNumber)); } rewriter.replaceOpWithNewOp(returnOp); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 7d4b17ef291..2fc98ebd676 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -782,11 +782,11 @@ class ConvertConv : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(OpT op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { tensorflow::TensorFormat format; std::string data_format = op.data_format().str(); - if (!FormatFromString(data_format, &format)) return Pattern::matchFailure(); + if (!FormatFromString(data_format, &format)) return failure(); auto input_ty = op.input().getType().template dyn_cast(); auto filter_ty = @@ -796,13 +796,13 @@ class ConvertConv : public OpRewritePattern { // Input, filter and the result needs to have static shape for calculation // of HLO paddings and feature group count attributes. 
for (RankedTensorType ty : {input_ty, filter_ty, result_ty}) { - if (!ty || !ty.hasStaticShape()) return Pattern::matchFailure(); + if (!ty || !ty.hasStaticShape()) return failure(); } int num_dims = num_spatial_dims + 2; tensorflow::Padding padding; if (!GetPaddingFromString(op.padding().str(), &padding).ok()) - return Pattern::matchFailure(); + return failure(); auto get_int = [](Attribute attr) { return attr.template cast().getInt(); @@ -844,7 +844,7 @@ class ConvertConv : public OpRewritePattern { tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( input_ty.getDimSize(dim), filter_ty.getDimSize(i), dilation, stride, padding, &output_size, &pad_low_int64, &pad_high_int64); - if (!status.ok()) return Pattern::matchFailure(); + if (!status.ok()) return failure(); pad_low = pad_low_int64; pad_high = pad_high_int64; } @@ -886,7 +886,7 @@ class ConvertConv : public OpRewritePattern { batch_group_count_attr, paddings_attr}; rewriter.replaceOpWithNewOp(op, op.getType(), operands, llvm::makeArrayRef(attrs)); - return Pattern::matchSuccess(); + return success(); } }; @@ -908,12 +908,12 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::FloorDivOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::FloorDivOp op, + PatternRewriter &rewriter) const override { auto l = op.x(); auto r = op.y(); auto element_type = getElementTypeOrSelf(l.getType()); - if (!element_type.isBF16()) return matchFailure(); + if (!element_type.isBF16()) return failure(); auto out_type = op.z().getType().cast(); @@ -928,7 +928,7 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { auto floor_op = rewriter.create(op.getLoc(), out_type, intermediate); rewriter.replaceOp(op, floor_op.getResult()); - return Pattern::matchSuccess(); + return success(); } }; @@ -938,8 +938,8 @@ class ConvertEinsumOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::EinsumOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter &rewriter) const override { StringAttr equation = op.getAttrOfType("equation"); if (op.N() == 1) { rewriter.replaceOpWithNewOp( @@ -951,9 +951,9 @@ class ConvertEinsumOp : public OpRewritePattern { } else { // TensorFlow EinsumOp verifies that the number of operands are at most // two. - return Pattern::matchFailure(); + return failure(); } - return Pattern::matchSuccess(); + return success(); } }; @@ -966,8 +966,8 @@ class ConvertFusedBatchNormGradBase public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(FusedBatchNormGradOpT op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(FusedBatchNormGradOpT op, + PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value grad = op.y_backprop(); Value act = op.x(); @@ -980,7 +980,7 @@ class ConvertFusedBatchNormGradBase // TensorFlow to absolute indices required by HLO. RankedTensorType act_type = act.getType().template dyn_cast(); - if (!act_type) return Pattern::matchFailure(); + if (!act_type) return failure(); Type act_ele_type = act_type.getElementType(); // To support mixed precision, the statistics type, which maybe more // precise than the input types, are used for this op. 
@@ -1060,7 +1060,7 @@ class ConvertFusedBatchNormGradBase {/*x_backprop=*/x_backprop, /*scale_backprop=*/scale_backprop, /*offset_backprop=*/offset_backprop, op.x(), op.x()}); - return Pattern::matchSuccess(); + return success(); } }; @@ -1079,8 +1079,8 @@ class ConvertFusedBatchNormV3Op public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::FusedBatchNormV3Op op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::FusedBatchNormV3Op op, + PatternRewriter &rewriter) const override { auto feature_dim = getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x()); @@ -1092,7 +1092,7 @@ class ConvertFusedBatchNormV3Op // In the training case, dimensions of input tensors must be static. if (op.is_training() && ((!input_type_tensor.hasStaticShape()) || (!scale_type_tensor.hasStaticShape()))) { - return matchFailure(); + return failure(); } // TODO(b/69928690): Support mixed precision in the XLA batch @@ -1180,7 +1180,7 @@ class ConvertFusedBatchNormV3Op /*batch_variance=*/op.x(), /*reserve_space_1=*/op.x(), /*reserve_space_2=*/op.x(), /*reserve_space_3=*/op.x()}); } - return Pattern::matchSuccess(); + return success(); } }; @@ -1226,15 +1226,15 @@ class ConvertAvgPoolOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::AvgPoolOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::AvgPoolOp op, + PatternRewriter &rewriter) const override { auto input_type = op.value().getType().dyn_cast(); - if (!input_type) return matchFailure(); + if (!input_type) return failure(); // TODO(b/147217034): support other data formats. - if (!IsDefaultDataFormat(op.data_format())) return matchFailure(); + if (!IsDefaultDataFormat(op.data_format())) return failure(); // TODO(b/147217034): support "SAME" padding. - if (op.padding() != "VALID") return matchFailure(); + if (op.padding() != "VALID") return failure(); // We will do accumulation first; use a larger bitwidth if suitable. 
Type input_element_type = input_type.getElementType(); @@ -1289,7 +1289,7 @@ class ConvertAvgPoolOp : public OpRewritePattern { rewriter.create(op.getLoc(), result, input_element_type); rewriter.replaceOp(op, result); - return matchSuccess(); + return success(); } }; @@ -1306,16 +1306,16 @@ class ConvertMaxPoolOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::MaxPoolOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::MaxPoolOp op, + PatternRewriter &rewriter) const override { Type element_type = op.input().getType().cast().getElementType(); - if (!element_type.isSignlessIntOrFloat()) return matchFailure(); + if (!element_type.isSignlessIntOrFloat()) return failure(); Location loc = op.getLoc(); ConstOp init = GetMinValueForType(element_type, loc, &rewriter); auto input_ty = op.input().getType().dyn_cast(); - if (!input_ty) return matchFailure(); + if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPadding( input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); auto reduce = rewriter.create( @@ -1326,7 +1326,7 @@ class ConvertMaxPoolOp : public OpRewritePattern { BuildReduceBody(element_type, &reduce.body(), &rewriter); rewriter.replaceOp(op, reduce.getResult()); - return matchSuccess(); + return success(); } }; @@ -1352,8 +1352,8 @@ class ConvertSelectV2Op : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::SelectV2Op op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::SelectV2Op op, + PatternRewriter &rewriter) const override { llvm::SmallVector broadcast_then_else_shape; auto ranked_then_type = op.t().getType().dyn_cast(); auto ranked_else_type = op.e().getType().dyn_cast(); @@ -1362,18 +1362,18 @@ class ConvertSelectV2Op : public OpRewritePattern { if (!ranked_then_type || !ranked_then_type.hasStaticShape() || !ranked_else_type || !ranked_else_type.hasStaticShape() || !ranked_cond_type || !ranked_cond_type.hasStaticShape()) - return matchFailure(); + return failure(); if (!OpTrait::util::getBroadcastedShape(ranked_then_type.getShape(), ranked_else_type.getShape(), broadcast_then_else_shape)) - return matchFailure(); + return failure(); llvm::SmallVector broadcast_shape; if (!OpTrait::util::getBroadcastedShape(broadcast_then_else_shape, ranked_cond_type.getShape(), broadcast_shape)) - return matchFailure(); + return failure(); auto broadcast_or_self = [&](Value value) { RankedTensorType type = value.getType().cast(); @@ -1404,7 +1404,7 @@ class ConvertSelectV2Op : public OpRewritePattern { rewriter.replaceOpWithNewOp(op, on_true.getType(), pred, on_true, on_false); - return matchSuccess(); + return success(); }; }; @@ -1432,8 +1432,8 @@ class ConvertSigmoidOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::SigmoidOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::SigmoidOp op, + PatternRewriter &rewriter) const override { auto operand = op.getOperand(); auto scalar_one = rewriter.create( @@ -1460,7 +1460,7 @@ class ConvertSigmoidOp : public OpRewritePattern { /*DenseIntElementsAttr=*/DenseIntElementsAttr()); rewriter.replaceOp(op, add_op.getResult()); - return matchSuccess(); + return success(); } }; @@ -1494,14 +1494,14 @@ class ConvertSoftmaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; 
- PatternMatchResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { Value logits = op.logits(); // Softmax converter requires ranked type because the XLA reduce ops used // while lowering requires dimensions attribute to reduce along. RankedTensorType type = logits.getType().dyn_cast(); - if (!type) return Pattern::matchFailure(); + if (!type) return failure(); auto loc = op.getLoc(); int rank = type.getRank(); @@ -1540,7 +1540,7 @@ class ConvertSoftmaxOp : public OpRewritePattern { } else { rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims); } - return Pattern::matchSuccess(); + return success(); } }; @@ -1570,11 +1570,11 @@ class ConvertSizeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::SizeOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::SizeOp op, + PatternRewriter &rewriter) const override { Value input = op.input(); auto input_ty = input.getType().dyn_cast(); - if (!input_ty) return Pattern::matchFailure(); + if (!input_ty) return failure(); const int64_t rank = input_ty.getRank(); auto result_type = op.getResult().getType(); @@ -1591,7 +1591,7 @@ class ConvertSizeOp : public OpRewritePattern { } rewriter.replaceOp(op, size->getResult(0)); - return Pattern::matchSuccess(); + return success(); } }; @@ -1627,22 +1627,22 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::BatchMatMulV2Op op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op, + PatternRewriter &rewriter) const override { // TODO(silvasean): Handle adj_x/adj_y // Should be able to just set the contracting_dimensions attribute // appropriately. // For complex types, need to do a complex conjugation. - if (op.adj_x() || op.adj_y()) return matchFailure(); + if (op.adj_x() || op.adj_y()) return failure(); Value lhs = op.x(); Value rhs = op.y(); auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); - if (!lhs_type || !rhs_type) return matchFailure(); + if (!lhs_type || !rhs_type) return failure(); // TODO(silvasean): Support dynamic shapes. if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) { - return matchFailure(); + return failure(); } // Broadcast both operands. @@ -1667,7 +1667,7 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs, dimension_numbers, /*precision_config=*/nullptr); - return matchSuccess(); + return success(); } }; @@ -1708,16 +1708,16 @@ class ConvertSplitOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::SplitOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::SplitOp op, + PatternRewriter &rewriter) const override { // We can only split along static dimensions. auto input_type = op.value().getType().dyn_cast(); - if (!input_type) return matchFailure(); + if (!input_type) return failure(); // We can only match when the split dimension is a constant scalar. DenseIntElementsAttr split_dim_attr; if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr))) - return matchFailure(); + return failure(); // Get the dimension we are splitting at. Offset properly if it's negative. 
int64_t input_rank = input_type.getRank(); @@ -1728,7 +1728,7 @@ class ConvertSplitOp : public OpRewritePattern { int64_t input_dim_size = input_type.getDimSize(dim_index); // If we are splitting along the dynamic dimension then we cannot compute // the static dimension length. - if (TensorType::isDynamic(input_dim_size)) return matchFailure(); + if (TensorType::isDynamic(input_dim_size)) return failure(); int64_t num_splits = op.getNumResults(); int64_t slice_size = input_dim_size / num_splits; @@ -1759,7 +1759,7 @@ class ConvertSplitOp : public OpRewritePattern { } rewriter.replaceOp(op, slices); - return matchSuccess(); + return success(); } }; @@ -1799,22 +1799,22 @@ class ConvertSplitVOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::SplitVOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::SplitVOp op, + PatternRewriter &rewriter) const override { // We can only split along static dimensions. // TODO(b/145731001): enhance to support dynamic-shaped inputs. auto input_type = op.value().getType().dyn_cast(); - if (!input_type) return matchFailure(); + if (!input_type) return failure(); // We can only match when the split dimension is a constant scalar. DenseIntElementsAttr split_dim_attr; if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr))) - return matchFailure(); + return failure(); // We can only match when the split sizes is a constant int vector. DenseIntElementsAttr split_sizes_attr; if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) - return matchFailure(); + return failure(); // Get each chunck's size along the dimension to split. It may contain // dynamic sizes and we need to update it if so. @@ -1841,7 +1841,7 @@ class ConvertSplitVOp : public OpRewritePattern { if (dim_index < 0) dim_index += input_rank; int64_t input_dim_size = input_type.getDimSize(dim_index); - if (TensorType::isDynamic(input_dim_size)) return matchFailure(); + if (TensorType::isDynamic(input_dim_size)) return failure(); assert(((dynamic_dim_index && total_dim_size <= input_dim_size) || (!dynamic_dim_index && total_dim_size == input_dim_size)) && @@ -1871,7 +1871,7 @@ class ConvertSplitVOp : public OpRewritePattern { } rewriter.replaceOp(op, slices); - return matchSuccess(); + return success(); } }; @@ -1893,15 +1893,15 @@ class ConvertStridedSliceOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::StridedSliceOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::StridedSliceOp op, + PatternRewriter &rewriter) const override { // Input shape needs to be static to convert negative indices in TensorFlow // to absolute indices required by HLO. // // TODO(hinsu): Relax this constraint for ops without negative indices and // strides. auto input_ty = op.input().getType().dyn_cast(); - if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); ArrayRef input_shape = input_ty.getShape(); // Output shape needs to be static to apply 'new_axis_mask' or @@ -1909,11 +1909,11 @@ class ConvertStridedSliceOp : public OpRewritePattern { // // TODO(hinsu): Relax this constraint for ops without the above masks. 
auto result_ty = op.getType().dyn_cast(); - if (!result_ty || !result_ty.hasStaticShape()) return matchFailure(); + if (!result_ty || !result_ty.hasStaticShape()) return failure(); SmallVector begin_indices, end_indices, strides; if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) - return matchFailure(); + return failure(); SmallVector hlo_begin_indices, hlo_end_indices, hlo_strides, dims_to_reverse; @@ -1923,7 +1923,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { hlo_strides.reserve(input_rank); int64_t indices_elements = begin_indices.size(); - if (input_rank < indices_elements) return matchFailure(); + if (input_rank < indices_elements) return failure(); // Convert from TensorFlow negative or out of range indices and strides // values to legal HLO Slice attributes. @@ -1967,7 +1967,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // Reshape slice result so that the shape is updated depending on // 'new_axis_mask' or 'shrink_axis_mask' attributes. rewriter.replaceOpWithNewOp(op, op.getType(), sliced); - return matchSuccess(); + return success(); } }; @@ -1982,12 +1982,12 @@ class ConvertStridedSliceGradOp public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::StridedSliceGradOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::StridedSliceGradOp op, + PatternRewriter &rewriter) const override { // We need constant input shape to perform padding calculations later. DenseIntElementsAttr input_shape_attr; if (!matchPattern(op.shape(), m_Constant(&input_shape_attr))) - return matchFailure(); + return failure(); // We also need constant begin/end indices and strides to perform padding // calculations. @@ -1997,7 +1997,7 @@ class ConvertStridedSliceGradOp SmallVector begin_indices, end_indices, strides; if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices, &strides)) - return matchFailure(); + return failure(); Value grad = op.dy(); Type element_type = grad.getType().cast().getElementType(); @@ -2050,7 +2050,7 @@ class ConvertStridedSliceGradOp GetI64ElementsAttr(padding_low, &rewriter), GetI64ElementsAttr(padding_high, &rewriter), GetI64ElementsAttr(padding_interm, &rewriter)); - return matchSuccess(); + return success(); } }; @@ -2075,12 +2075,12 @@ class ConvertStridedSliceGradOp class ConvertRangeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::RangeOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::RangeOp op, + PatternRewriter &rewriter) const override { auto result = op.getResult(); auto result_type = result.getType(); if (!result_type.cast().hasStaticShape()) { - return matchFailure(); + return failure(); } auto iota = rewriter.create(op.getLoc(), result_type, @@ -2091,7 +2091,7 @@ class ConvertRangeOp : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); - return matchSuccess(); + return success(); } }; @@ -2102,12 +2102,12 @@ class ConvertRangeOp : public OpRewritePattern { class ConvertLinSpaceOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::LinSpaceOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::LinSpaceOp op, + PatternRewriter &rewriter) const override { auto result = op.getResult(); auto result_type = result.getType().dyn_cast(); if 
(!result_type || !result_type.hasStaticShape()) { - return matchFailure(); + return failure(); } // Calculate the scaling that needs to be applied to the iota. @@ -2137,7 +2137,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); - return matchSuccess(); + return success(); } }; @@ -2151,18 +2151,18 @@ template { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { // TODO(b/141785544): Update this to not require static shapes. // Input shape needs to be static to convert negative indices in TensorFlow // to absolute indices required by HLO. auto input_ty = op.input().getType().template dyn_cast(); - if (!input_ty) return this->matchFailure(); + if (!input_ty) return failure(); ArrayRef input_shape = input_ty.getShape(); DenseIntElementsAttr dimensions; if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions))) - return this->matchFailure(); + return failure(); // Build the final shape from input_shape and dimensions using a bitmap // to mark the reduced dimensions. @@ -2171,7 +2171,7 @@ class GenericConvertReductionOp : public OpRewritePattern { for (APInt index_raw : dimensions.getValues()) { int64_t index = index_raw.getSExtValue(); int64_t rank = input_shape.size(); - if ((index < -rank || index >= rank)) return this->matchFailure(); + if ((index < -rank || index >= rank)) return failure(); index = (index + rank) % rank; reduced_dimensions_bitmap[index] = true; xla_dimensions.push_back(index); @@ -2202,7 +2202,7 @@ class GenericConvertReductionOp : public OpRewritePattern { for (size_t i = 0; i < input_shape.size(); ++i) { if (reduced_dimensions_bitmap[i]) { if (TensorType::isDynamic(input_shape[i])) { - return this->matchFailure(); + return failure(); } divisor_count *= input_shape[i]; } @@ -2223,7 +2223,7 @@ class GenericConvertReductionOp : public OpRewritePattern { } rewriter.replaceOp(op, {result}); - return this->matchSuccess(); + return success(); } }; @@ -2349,18 +2349,18 @@ template class ConvertArgMinMaxOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { RankedTensorType input_type = op.input().getType().template dyn_cast(); if (!input_type) { - return this->matchFailure(); + return failure(); } Type input_element_type = input_type.getElementType(); // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If // tf.ArgMax doesn't support complex data types, this check can be removed. 
- if (!input_element_type.isSignlessIntOrFloat()) return this->matchFailure(); + if (!input_element_type.isSignlessIntOrFloat()) return failure(); Location loc = op.getLoc(); Value init_value = @@ -2369,7 +2369,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { RankedTensorType output_type = op.output().getType().template dyn_cast(); if (!output_type) { - return this->matchFailure(); + return failure(); } Type index_element_type = output_type.getElementType(); @@ -2382,7 +2382,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { llvm::Optional optional_axis = GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank()); if (!optional_axis.hasValue()) { - return this->matchFailure(); + return failure(); } int64_t axis = optional_axis.getValue(); @@ -2408,7 +2408,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { direction, &reduction.body(), &rewriter); rewriter.replaceOp(op, {reduction.getResult(1)}); - return this->matchSuccess(); + return success(); } }; @@ -2442,18 +2442,18 @@ class ConvertTensorScatterUpdateOp public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::TensorScatterUpdateOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::TensorScatterUpdateOp op, + PatternRewriter &rewriter) const override { auto tensor_ty = op.tensor().getType().dyn_cast(); auto indices_ty = op.indices().getType().dyn_cast(); auto updates_ty = op.updates().getType().dyn_cast(); - if (!tensor_ty || !indices_ty || !updates_ty) return matchFailure(); + if (!tensor_ty || !indices_ty || !updates_ty) return failure(); // Last dimension of the indices needs to known at compile time for // computation of the 'update_window_dims' attribute in the dimensions // struct. int64_t num_index_dims = indices_ty.getShape().back(); - if (ShapedType::isDynamic(num_index_dims)) return matchFailure(); + if (ShapedType::isDynamic(num_index_dims)) return failure(); int64_t tensor_rank = tensor_ty.getRank(); int64_t indices_rank = indices_ty.getRank(); @@ -2484,7 +2484,7 @@ class ConvertTensorScatterUpdateOp }(&scatter.update_computation()); rewriter.replaceOp(op, scatter.getResult()); - return matchSuccess(); + return success(); } }; @@ -2501,19 +2501,19 @@ class ConvertTileOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::TileOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::TileOp op, + PatternRewriter &rewriter) const override { auto input_ty = op.input().getType().dyn_cast(); - if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); ArrayRef input_shape = input_ty.getShape(); Type element_type = input_ty.getElementType(); DenseIntElementsAttr multiples; if (!matchPattern(op.multiples(), m_Constant(&multiples)) || multiples.getType().getRank() != 1) - return matchFailure(); + return failure(); - if (multiples.getNumElements() != input_shape.size()) return matchFailure(); + if (multiples.getNumElements() != input_shape.size()) return failure(); SmallVector broadcasted_shape; SmallVector broadcast_dimensions; @@ -2524,7 +2524,7 @@ class ConvertTileOp : public OpRewritePattern { int64_t multiple = std::get<0>(multiple_and_input).getSExtValue(); int64_t input_size = std::get<1>(multiple_and_input); - if (multiple < 0) return matchFailure(); + if (multiple < 0) return failure(); // Line input up with the next dimension in broadcasted_shape // when 
broadcasting. @@ -2554,7 +2554,7 @@ class ConvertTileOp : public OpRewritePattern { rewriter.replaceOp(op, {result}); - return matchSuccess(); + return success(); } }; @@ -2562,8 +2562,8 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::MaxPoolGradOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::MaxPoolGradOp op, + PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type element_type = @@ -2573,7 +2573,7 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { // Here, ReduceWindow op as used as the MaxPool op is lowered to the // ReduceWindow op. auto input_ty = op.orig_input().getType().dyn_cast(); - if (!input_ty) return matchFailure(); + if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPadding( input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); @@ -2601,7 +2601,7 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { rewriter.replaceOp(op, {result}); - return matchSuccess(); + return success(); } }; @@ -2613,24 +2613,23 @@ class ConvertConv2DBackpropInputOp public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::Conv2DBackpropInputOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::Conv2DBackpropInputOp op, + PatternRewriter &rewriter) const override { // Unpack all of the attributes. tensorflow::TensorFormat data_format; if (!FormatFromString(op.data_format().str(), &data_format)) { - return matchFailure(); + return failure(); } tensorflow::Padding padding; if (!GetPaddingFromString(op.padding().str(), &padding).ok()) - return Pattern::matchFailure(); + return failure(); auto out_backprop_ty = op.out_backprop().getType().dyn_cast(); - if (!out_backprop_ty || !out_backprop_ty.hasStaticShape()) - return matchFailure(); + if (!out_backprop_ty || !out_backprop_ty.hasStaticShape()) return failure(); ArrayRef out_backprop_shape = out_backprop_ty.getShape(); auto filter_ty = op.filter().getType().dyn_cast(); - if (!filter_ty || !filter_ty.hasStaticShape()) return matchFailure(); + if (!filter_ty || !filter_ty.hasStaticShape()) return failure(); ArrayRef filter_shape = filter_ty.getShape(); int num_spatial_dims = 2; Location loc = op.getLoc(); @@ -2643,11 +2642,11 @@ class ConvertConv2DBackpropInputOp DenseIntElementsAttr input_shape_attr; if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) || input_shape_attr.getType().getRank() != 1) { - return matchFailure(); + return failure(); } auto input_shape = llvm::to_vector<4>(input_shape_attr.getValues()); - if (input_shape.size() != num_dims) return matchFailure(); + if (input_shape.size() != num_dims) return failure(); auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim); auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim); @@ -2676,7 +2675,7 @@ class ConvertConv2DBackpropInputOp ToTensorShape(out_backprop_shape), dilations, strides, padding, explicit_paddings, data_format, &dims) .ok()) { - return matchFailure(); + return failure(); } // Compute ConvDimensionNumbers, dilation, and padding. @@ -2709,7 +2708,7 @@ class ConvertConv2DBackpropInputOp filter = TransposeFilterForGroupConvolutionBackpropInput( filter, filter_shape, feature_group_count, attrs.num_spatial_dims); */ - return matchFailure(); + return failure(); } // Mirror the filter in the spatial dimensions. 
@@ -2746,7 +2745,7 @@ class ConvertConv2DBackpropInputOp rewriter.replaceOp(op, {result}); - return matchSuccess(); + return success(); } }; @@ -2757,30 +2756,29 @@ class ConvertConv2DBackpropFilterOp public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::Conv2DBackpropFilterOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::Conv2DBackpropFilterOp op, + PatternRewriter &rewriter) const override { // Unpack all of the attributes. tensorflow::TensorFormat data_format; if (!FormatFromString(op.data_format().str(), &data_format)) { - return matchFailure(); + return failure(); } tensorflow::Padding padding; if (!GetPaddingFromString(op.padding().str(), &padding).ok()) - return Pattern::matchFailure(); + return failure(); auto out_backprop_ty = op.out_backprop().getType().dyn_cast(); - if (!out_backprop_ty || !out_backprop_ty.hasStaticShape()) - return matchFailure(); + if (!out_backprop_ty || !out_backprop_ty.hasStaticShape()) return failure(); ArrayRef out_backprop_shape = out_backprop_ty.getShape(); auto input_ty = op.input().getType().dyn_cast(); - if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); ArrayRef input_shape = input_ty.getShape(); DenseIntElementsAttr filter_shape_attr; if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) || filter_shape_attr.getType().getRank() != 1) { - return matchFailure(); + return failure(); } auto strides_attr = GetI64ElementsAttr(op.strides()); @@ -2803,7 +2801,7 @@ class ConvertConv2DBackpropFilterOp auto filter_shape = llvm::to_vector<4>(filter_shape_attr.getValues()); - if (filter_shape.size() != num_dims) return matchFailure(); + if (filter_shape.size() != num_dims) return failure(); // Reuse dimension computation logic from conv_grad_shape_utils.cc. tensorflow::ConvBackpropDimensions dims; @@ -2813,7 +2811,7 @@ class ConvertConv2DBackpropFilterOp ToTensorShape(out_backprop_shape), dilations, strides, padding, explicit_paddings, data_format, &dims) .ok()) { - return matchFailure(); + return failure(); } // The activations (inputs) form the LHS of the convolution. @@ -2832,7 +2830,7 @@ class ConvertConv2DBackpropFilterOp activations, input_shape, feature_group_count, batch_dim, feature_dim); */ - return matchFailure(); + return failure(); } // Compute ConvDimensionNumbers, dilation, and padding. 
@@ -2948,7 +2946,7 @@ class ConvertConv2DBackpropFilterOp rewriter.replaceOp(op, {result}); - return matchSuccess(); + return success(); } }; @@ -2956,16 +2954,16 @@ class ConvertOneHotOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::OneHotOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::OneHotOp op, + PatternRewriter &rewriter) const override { auto indices_ty = op.indices().getType().dyn_cast(); - if (!indices_ty || !indices_ty.hasStaticShape()) return matchFailure(); + if (!indices_ty || !indices_ty.hasStaticShape()) return failure(); ArrayRef indices_shape = indices_ty.getShape(); Type element_type = indices_ty.getElementType(); DenseIntElementsAttr depth_attr; if (!matchPattern(op.depth(), m_Constant(&depth_attr))) { - return matchFailure(); + return failure(); } int64_t depth = depth_attr.getValue({}).getSExtValue(); @@ -3000,7 +2998,7 @@ class ConvertOneHotOp : public OpRewritePattern { rewriter.replaceOp(op, {result}); - return matchSuccess(); + return success(); } }; @@ -3032,8 +3030,8 @@ class ConvertInfeedDequeueTupleOp public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::InfeedDequeueTupleOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op, + PatternRewriter &rewriter) const override { std::vector result_types(op.outputs().size()); for (auto idx_and_output : llvm::enumerate(op.outputs())) { result_types[idx_and_output.index()] = (idx_and_output.value().getType()); @@ -3069,7 +3067,7 @@ class ConvertInfeedDequeueTupleOp results.push_back(tuple_element); } rewriter.replaceOp(op, ValueRange(results)); - return matchSuccess(); + return success(); } }; @@ -3096,15 +3094,15 @@ class ConvertOutfeedEnqueueTupleOp public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, + PatternRewriter &rewriter) const override { auto token_type = xla_hlo::TokenType::get(rewriter.getContext()); auto tuple = rewriter.create(op.getLoc(), op.inputs()); auto token = rewriter.create(op.getLoc(), token_type); rewriter.create(op.getLoc(), token_type, tuple, token, /*outfeed_config=*/rewriter.getStringAttr("")); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -3142,20 +3140,20 @@ class ConvertTopKV2Op : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::TopKV2Op op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::TopKV2Op op, + PatternRewriter &rewriter) const override { // We can only match when the `k` operand is a constant scalar. DenseIntElementsAttr k_attr; - if (!matchPattern(op.k(), m_Constant(&k_attr))) return matchFailure(); + if (!matchPattern(op.k(), m_Constant(&k_attr))) return failure(); // The last dimension of the input tensor's shape should be known so we can // have clamped end_indices for slices. 
TensorType input_type = op.input().getType().cast(); - if (!input_type.hasRank()) return matchFailure(); + if (!input_type.hasRank()) return failure(); int64_t input_rank = input_type.getRank(); int64_t last_dim_index = input_rank - 1; int64_t last_dim_size = input_type.getDimSize(last_dim_index); - if (last_dim_size == ShapedType::kDynamicSize) return matchFailure(); + if (last_dim_size == ShapedType::kDynamicSize) return failure(); // Create an Itoa op for indices. auto i32_type = rewriter.getIntegerType(32); @@ -3199,7 +3197,7 @@ class ConvertTopKV2Op : public OpRewritePattern { GetI64ElementsAttr(strides, &rewriter)); rewriter.replaceOp(op, {values, indices}); - return matchSuccess(); + return success(); } }; @@ -3213,10 +3211,10 @@ class ConvertUnpackOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::UnpackOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::UnpackOp op, + PatternRewriter &rewriter) const override { auto value_type = op.value().getType().cast(); - if (!value_type) return matchFailure(); + if (!value_type) return failure(); int64_t value_rank = value_type.getRank(); int64_t axis = op.axis().getSExtValue(); @@ -3246,7 +3244,7 @@ class ConvertUnpackOp : public OpRewritePattern { } rewriter.replaceOp(op, results); - return matchSuccess(); + return success(); } }; @@ -3271,20 +3269,20 @@ template class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { auto data_type = op.data().getType().template dyn_cast(); - if (!data_type) return this->matchFailure(); + if (!data_type) return failure(); int64_t data_rank = data_type.getRank(); auto segment_ids_type = op.segment_ids().getType().template dyn_cast(); - if (!segment_ids_type) return this->matchFailure(); + if (!segment_ids_type) return failure(); int64_t segment_ids_rank = segment_ids_type.getRank(); DenseIntElementsAttr num_segments_attr; if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) - return this->matchFailure(); + return failure(); // The final shape for TF unsorted segment reduction op is [num_segments] + // data_shape[segment_ids_rank:]. @@ -3322,7 +3320,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { &scatter.update_computation(), &rewriter); rewriter.replaceOp(op, scatter.getResult()); - return this->matchSuccess(); + return success(); } }; @@ -3390,20 +3388,20 @@ class ConvertRandomShuffleOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::RandomShuffleOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::RandomShuffleOp op, + PatternRewriter &rewriter) const override { auto input_type = op.value().getType().dyn_cast(); - if (!input_type) return matchFailure(); + if (!input_type) return failure(); int64_t input_rank = input_type.getRank(); int64_t first_dim_size = input_type.getDimSize(0); - if (ShapedType::isDynamic(first_dim_size)) return matchFailure(); + if (ShapedType::isDynamic(first_dim_size)) return failure(); // We are shuffling along the first dimension. If its size is <= 1, then // shuffling is a no-op. 
if (first_dim_size <= 1) { rewriter.replaceOp(op, op.value()); - return matchSuccess(); + return success(); } // For vectors, shuffle values by sorting instead of the obvious @@ -3464,7 +3462,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { sorted.getResult(), 1); } rewriter.replaceOp(op, current); - return matchSuccess(); + return success(); } // The Fisher-Yates algorithm. @@ -3540,7 +3538,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { op, op.getType(), op.value(), swaped_indices, dims_attr, GetI64ElementsAttr(slice_sizes, &rewriter)); - return matchSuccess(); + return success(); } }; @@ -3550,18 +3548,18 @@ class ConvertVariableShapeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::VariableShapeOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::VariableShapeOp op, + PatternRewriter &rewriter) const override { // The input type should be a tensor>. We need // to get the inner resource type. auto input_type = op.input().getType().cast(); auto subtypes = input_type.getElementType().cast().getSubtypes(); // It can be missing; then we cannot convert. - if (subtypes.empty()) return matchFailure(); + if (subtypes.empty()) return failure(); auto resource_type = subtypes[0].cast(); - if (!resource_type.hasStaticShape()) return matchFailure(); + if (!resource_type.hasStaticShape()) return failure(); auto resource_shape = resource_type.getShape(); Attribute const_attr; @@ -3579,7 +3577,7 @@ class ConvertVariableShapeOp : public OpRewritePattern { } rewriter.replaceOpWithNewOp(op, const_attr); - return matchSuccess(); + return success(); } }; @@ -3588,13 +3586,13 @@ class ConvertXlaShardingOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::XlaShardingOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::XlaShardingOp op, + PatternRewriter &rewriter) const override { // TODO(b/148313088): define sharding attribute struct in MLIR intead of // using a string. 
auto sharding = op.getAttrOfType("_XlaSharding"); if (!sharding) { - return matchFailure(); + return failure(); } // _XlaSharding attribute in TF is a serialized string of the OpSharding @@ -3602,11 +3600,11 @@ class ConvertXlaShardingOp : public OpRewritePattern { ::xla::OpSharding sharding_proto; std::string sharding_str; if (!sharding_proto.ParseFromString(sharding.getValue().str())) { - return matchFailure(); + return failure(); } if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, &sharding_str)) { - return matchFailure(); + return failure(); } auto custom_call = rewriter.create( @@ -3618,7 +3616,7 @@ class ConvertXlaShardingOp : public OpRewritePattern { rewriter.getStringAttr(sharding_str)); rewriter.replaceOp(op, custom_call.getResult()); - return matchSuccess(); + return success(); } }; @@ -3628,12 +3626,12 @@ class ConvertXlaDynamicUpdateSliceOp public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, + PatternRewriter &rewriter) const override { auto indices_type = op.indices().getType().dyn_cast(); if (!indices_type || !indices_type.hasStaticShape() || indices_type.getShape().size() != 1) - return matchFailure(); + return failure(); SmallVector unpacked_indices_type( indices_type.getDimSize(0), @@ -3643,7 +3641,7 @@ class ConvertXlaDynamicUpdateSliceOp IntegerAttr::get(rewriter.getIntegerType(64), 0)); rewriter.replaceOpWithNewOp( op, op.getType(), op.input(), op.update(), unpacked_indices.output()); - return matchSuccess(); + return success(); } }; @@ -3654,24 +3652,24 @@ class ConvertXlaDynamicUpdateSliceOp class ConvertCumsumOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TF::CumsumOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TF::CumsumOp op, + PatternRewriter &rewriter) const override { auto input = op.x(); auto input_type = input.getType().dyn_cast(); if (!input_type || !input_type.hasStaticShape()) { - return matchFailure(); + return failure(); } // TODO(jennik): Add support for the optional 'exclusive' and 'reverse' // arguments. if (op.exclusive() || op.reverse()) { - return matchFailure(); + return failure(); } // We can only match when the axis is a constant scalar. 
DenseIntElementsAttr axis_attr; if (!matchPattern(op.axis(), m_Constant(&axis_attr))) { - return matchFailure(); + return failure(); } // Convert if we need to enlarge the element type's bitwidth to avoid @@ -3717,7 +3715,7 @@ class ConvertCumsumOp : public OpRewritePattern { rewriter.create(op.getLoc(), result, input_element_type); rewriter.replaceOp(op, result); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index aeaceeb27d5..0769e92b8ce 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -35,19 +35,19 @@ class CompareIConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(xla_hlo::CompareOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(xla_hlo::CompareOp op, + PatternRewriter &rewriter) const override { auto lhs = op.lhs(); auto rhs = op.rhs(); auto lhs_type = lhs.getType().cast(); auto rhs_type = rhs.getType().cast(); // Broadcasting not supported by this rewrite. - if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure(); + if (lhs_type.getShape() != rhs_type.getShape()) return failure(); if (!lhs_type.getElementType().isSignlessInteger() || !rhs_type.getElementType().isSignlessInteger()) - return matchFailure(); + return failure(); auto comparison_direction = op.comparison_direction(); auto compare_predicate = @@ -60,11 +60,11 @@ class CompareIConvert : public OpRewritePattern { .Case("GE", CmpIPredicate::sge) .Default(llvm::None); - if (!compare_predicate.hasValue()) return matchFailure(); + if (!compare_predicate.hasValue()) return failure(); rewriter.replaceOpWithNewOp(op, compare_predicate.getValue(), lhs, rhs); - return matchSuccess(); + return success(); } }; @@ -72,19 +72,19 @@ class CompareFConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(xla_hlo::CompareOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(xla_hlo::CompareOp op, + PatternRewriter &rewriter) const override { auto lhs = op.lhs(); auto rhs = op.rhs(); auto lhs_type = lhs.getType().cast(); auto rhs_type = rhs.getType().cast(); // Broadcasting not supported by this rewrite. 
- if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure(); + if (lhs_type.getShape() != rhs_type.getShape()) return failure(); if (!lhs_type.getElementType().isa() || !rhs_type.getElementType().isa()) - return matchFailure(); + return failure(); auto comparison_direction = op.comparison_direction(); auto compare_predicate = @@ -97,11 +97,11 @@ class CompareFConvert : public OpRewritePattern { .Case("GE", CmpFPredicate::OGE) .Default(llvm::None); - if (!compare_predicate.hasValue()) return matchFailure(); + if (!compare_predicate.hasValue()) return failure(); rewriter.replaceOpWithNewOp(op, compare_predicate.getValue(), lhs, rhs); - return matchSuccess(); + return success(); } }; @@ -113,8 +113,8 @@ class ConvertIotaOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(xla_hlo::IotaOp op, + PatternRewriter &rewriter) const override { auto output_type = op.getType().cast(); auto output_size = output_type.getNumElements(); auto dimension = op.iota_dimension().getSExtValue(); @@ -159,7 +159,7 @@ class ConvertIotaOp : public OpRewritePattern { // For int/float types we are done, replace op and return. if (!complex_ty) { rewriter.replaceOp(op, iota_const.getResult()); - return matchSuccess(); + return success(); } // For complex types, generate a constant tensor of zeroes for the imaginary @@ -170,7 +170,7 @@ class ConvertIotaOp : public OpRewritePattern { rewriter.create(loc, int_or_float_shape_ty, zeroes); rewriter.replaceOpWithNewOp(op, iota_const, imag_zeroes); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index 32053950fed..43f0116ef0d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -35,8 +35,8 @@ template struct BinaryOpConverter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(LhloOpTy op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(LhloOpTy op, + PatternRewriter& rewriter) const override { const auto& lhs = op.lhs(); const auto& rhs = op.rhs(); const auto& lhs_type = lhs.getType().template cast(); @@ -44,7 +44,7 @@ struct BinaryOpConverter : public OpRewritePattern { const auto& element_type = lhs_type.getElementType(); if (lhs_type.getShape() != rhs_type.getShape()) { - return this->matchFailure(); + return failure(); } const auto& shape = lhs_type.getShape(); SmallVector induction_vars; @@ -59,11 +59,11 @@ struct BinaryOpConverter : public OpRewritePattern { Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( op, element_type, {l, r}, &rewriter); if (opResult == nullptr) { - return this->matchFailure(); + return failure(); } rewriter.create(loc, opResult, op.out(), induction_vars); rewriter.eraseOp(op); - return this->matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index c9245d93e56..537703302c0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -48,7 +48,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { public: 
using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( ReduceOp reduce_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = reduce_op.getLoc(); @@ -57,11 +57,11 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { for (auto result : reduce_op.out()) { auto shaped_type = result.getType().dyn_cast(); if (!shaped_type || shaped_type.getRank() != 1) { - return matchFailure(); + return failure(); } auto dim_size = shaped_type.getDimSize(0); if (size && size != dim_size) { - return matchFailure(); + return failure(); } size = dim_size; } @@ -73,7 +73,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { for (auto input : reduce_op.operands()) { auto shaped_type = input.getType().dyn_cast(); if (!shaped_type || !shaped_type.hasStaticShape()) { - return matchFailure(); + return failure(); } reduce_dim_size = shaped_type.getDimSize(reducing_dimension.getSExtValue()); @@ -164,7 +164,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { } rewriter.eraseOp(reduce_op); - return matchSuccess(); + return success(); }; }; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index 8ef08e4f9f3..894e4d039b8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -66,18 +66,18 @@ class ReduceOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( xla_lhlo::ReduceOp xla_reduce_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // TODO(b/137624192) Implement variadic reduce. - if (xla_reduce_op.out().size() != 1) return matchFailure(); + if (xla_reduce_op.out().size() != 1) return failure(); loop::ReduceOp reduce_op = CreateParallelLoopsWithReduceOp(xla_reduce_op, args, &rewriter); ConvertReductionOperator(xla_reduce_op, &reduce_op.reductionOperator().front(), &rewriter); rewriter.replaceOp(xla_reduce_op, llvm::None); - return matchSuccess(); + return success(); } private: diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc index f18607dfffb..2e901094348 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc @@ -32,14 +32,16 @@ limitations under the License. 
using mlir::DenseIntElementsAttr; using mlir::ElementsAttr; +using mlir::failure; using mlir::FunctionPass; +using mlir::LogicalResult; using mlir::MLIRContext; using mlir::OpRewritePattern; using mlir::OwningRewritePatternList; using mlir::PassRegistration; -using mlir::PatternMatchResult; using mlir::PatternRewriter; using mlir::RankedTensorType; +using mlir::success; using mlir::Value; namespace { @@ -135,14 +137,14 @@ struct GeneralDotConvert explicit GeneralDotConvert(MLIRContext *context) : OpRewritePattern(context) {} - PatternMatchResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op, + PatternRewriter &rewriter) const override { auto dot_element_type = mlir::getElementTypeOrSelf(op); auto dot_numbers = op.dot_dimension_numbers(); if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 || dot_numbers.rhs_batching_dimensions().getNumElements() != 0) { - return matchFailure(); + return failure(); } auto lhs = ProcessDotArg(op.lhs(), op.getLoc(), @@ -164,7 +166,7 @@ struct GeneralDotConvert rewriter.replaceOpWithNewOp(op, op.getType(), new_dot_op); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index 4c20a589ce0..157029a04dc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -260,22 +260,22 @@ struct BinaryOpWithBroadcastConvert : public OpRewritePattern { explicit BinaryOpWithBroadcastConvert(MLIRContext *context) : OpRewritePattern(context) {} - PatternMatchResult matchAndRewrite(SrcOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { Value new_lhs; Value new_rhs; auto op_ranked_type = op.getType().template dyn_cast(); - if (!op_ranked_type) return this->matchFailure(); + if (!op_ranked_type) return failure(); if (op_ranked_type.hasStaticShape()) { if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) { - return this->matchFailure(); + return failure(); } } else { if (!CreateDynamicBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) { - return this->matchFailure(); + return failure(); } } @@ -283,7 +283,7 @@ struct BinaryOpWithBroadcastConvert : public OpRewritePattern { // New args are broadcasts, so no dims are needed on the replacement op. 
rewriter.replaceOpWithNewOp(op, op.getType(), new_lhs, new_rhs, /*broadcast_dims=*/nullptr); - return this->matchSuccess(); + return success(); } }; @@ -292,18 +292,18 @@ struct CompareWithBroadcastConvert : public OpRewritePattern { explicit CompareWithBroadcastConvert(MLIRContext *context) : OpRewritePattern(context) {} - PatternMatchResult matchAndRewrite(CompareOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(CompareOp op, + PatternRewriter &rewriter) const override { Value new_lhs; Value new_rhs; if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) { - return this->matchFailure(); + return failure(); } rewriter.replaceOpWithNewOp(op, op.getType(), new_lhs, new_rhs, /*broadcast_dims=*/nullptr, op.comparison_direction()); - return this->matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc index 071cc575656..2b785c4ba06 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc @@ -112,7 +112,7 @@ class UnfuseBatchNormInferencePattern public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( xla_hlo::BatchNormInferenceOp bn_op, ArrayRef raw_operands, ConversionPatternRewriter& rewriter) const override { xla_hlo::BatchNormInferenceOpOperandAdaptor operands(raw_operands); @@ -124,11 +124,11 @@ class UnfuseBatchNormInferencePattern auto variance_type = operands.variance().getType().dyn_cast(); if (!input_type || !variance_type) { - return matchFailure(); + return failure(); } auto fp_type = variance_type.getElementType().dyn_cast(); if (!fp_type) { - return matchFailure(); + return failure(); } int64_t feature_dim = bn_op.feature_index().getSExtValue(); @@ -138,7 +138,7 @@ class UnfuseBatchNormInferencePattern MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type, operands.variance(), variance_type, rewriter); if (!epsilon) { - return matchFailure(); + return failure(); } Value stddev = rewriter.create(bn_op.getLoc(), operands.variance(), @@ -174,7 +174,7 @@ class UnfuseBatchNormInferencePattern rewriter.replaceOpWithNewOp(bn_op, result, broadcast_offset, nullptr); - return matchSuccess(); + return success(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 0daec32fbab..983e5795253 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -73,7 +73,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); @@ -81,10 +81,10 @@ class PointwiseToLinalgConverter : public OpConversionPattern { op.getOperation()->getOperand(0).getType().template cast(); if (!argType.hasRank()) { emitError(loc, "lhlo to linalg conversion expects ranked args"); - return ConversionPattern::matchFailure(); + return failure(); } if (!argType.getElementType().isSignlessIntOrFloat()) { - return ConversionPattern::matchFailure(); + return failure(); } // Construct the indexing maps needed for linalg.generic ops. 
@@ -95,7 +95,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { // here is that are broadcasts have been made explicit. unsigned nloops = argType.getRank(); if (!nloops) { - return ConversionPattern::matchFailure(); + return failure(); } int operandCount = (isLHLO ? args.size() - 1 : args.size()); auto verifyArgOrResultType = [&](Value val) -> ShapedType { @@ -111,7 +111,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { }; for (const auto& arg : llvm::enumerate(args)) { auto shapedType = verifyArgOrResultType(arg.value()); - if (!shapedType) return ConversionPattern::matchFailure(); + if (!shapedType) return failure(); auto& result_or_body_arg = arg.index() < operandCount ? bodyArgTypes : bodyResultTypes; result_or_body_arg.emplace_back(shapedType.getElementType()); @@ -122,7 +122,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { "When lowering HLO ops result can't be part of arguments"); Value result = op.getOperation()->getResult(0); auto shapedType = verifyArgOrResultType(result); - if (!shapedType) return ConversionPattern::matchFailure(); + if (!shapedType) return failure(); bodyResultTypes.push_back(shapedType.getElementType()); opResultTypes.push_back(shapedType); } @@ -152,11 +152,11 @@ class PointwiseToLinalgConverter : public OpConversionPattern { Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( op, bodyResultTypes, bodyArgs, &rewriter); if (!opResult) { - return ConversionPattern::matchFailure(); + return failure(); } rewriter.create(loc, opResult); rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); - return ConversionPattern::matchSuccess(); + return success(); } }; @@ -165,7 +165,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( LhloOp lhlo_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = lhlo_op.getLoc(); @@ -173,7 +173,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { lhlo_op.getOperand(0).getType().template dyn_cast(); if (!argType || !argType.getElementType().isSignlessIntOrFloat() || (argType.getRank() != 0)) { - return ConversionPattern::matchFailure(); + return failure(); } // Create two loads from the input. 
@@ -185,7 +185,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { &rewriter); rewriter.create(loc, opResult, lhlo_op.out()); rewriter.eraseOp(lhlo_op); - return ConversionPattern::matchSuccess(); + return success(); } }; @@ -199,18 +199,16 @@ class DataMovementOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - if (!verifyXLAOpBufferOrTensorSemantics(op)) - return ConversionPattern::matchFailure(); + if (!verifyXLAOpBufferOrTensorSemantics(op)) return failure(); auto operandType = op.operand().getType().template cast(); auto resultType = getXLAOpResultType(op); - if (!verifyXLAOpBufferOrTensorSemantics(op)) - return ConversionPattern::matchFailure(); + if (!verifyXLAOpBufferOrTensorSemantics(op)) return failure(); ArrayAttr indexingMapsAttr = static_cast(*this).getIndexingMapsAttr(op, &rewriter); - if (!indexingMapsAttr) return ConversionPattern::matchFailure(); + if (!indexingMapsAttr) return failure(); OpBuilder::InsertionGuard linalgOpGuard(rewriter); auto nloops = resultType.getRank(); @@ -230,7 +228,7 @@ class DataMovementOpConverter : public OpConversionPattern { rewriter.create(loc, block->getArgument(0)); rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); - return ConversionPattern::matchSuccess(); + return success(); } }; @@ -377,15 +375,15 @@ class IotaConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( xla_lhlo::IotaOp iotaOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto resultMemrefType = iotaOp.getOperand().getType().dyn_cast(); - if (!resultMemrefType) return matchFailure(); + if (!resultMemrefType) return failure(); auto resultElementType = resultMemrefType.getElementType(); - if (!resultElementType.isSignlessIntOrFloat()) return matchFailure(); + if (!resultElementType.isSignlessIntOrFloat()) return failure(); // Construct the indexing maps needed for linalg.generic ops. 
unsigned nloops = resultMemrefType.getRank(); @@ -420,7 +418,7 @@ class IotaConverter : public OpConversionPattern { } rewriter.create(loc, castOp->getResult(0)); rewriter.eraseOp(iotaOp); - return matchSuccess(); + return success(); } }; @@ -428,17 +426,17 @@ class ConstConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( xla_lhlo::ConstOp constOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = constOp.getLoc(); auto valueAttr = constOp.value().cast(); - if (valueAttr.getType().getRank() != 0) return matchFailure(); + if (valueAttr.getType().getRank() != 0) return failure(); auto stdConstOp = rewriter.create(loc, valueAttr.getValue({})); rewriter.create(loc, stdConstOp, constOp.getOperand()); rewriter.eraseOp(constOp); - return matchSuccess(); + return success(); } }; @@ -446,7 +444,7 @@ class SliceConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - PatternMatchResult matchAndRewrite( + LogicalResult matchAndRewrite( xla_lhlo::SliceOp sliceOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = sliceOp.getLoc(); @@ -454,7 +452,7 @@ class SliceConverter : public OpConversionPattern { sliceOp.getOperand(0).getType().template dyn_cast(); if (!argType || !argType.hasRank()) { emitError(loc, "lhlo to linalg conversion expects known-rank args"); - return ConversionPattern::matchFailure(); + return failure(); } SmallVector ranges; @@ -472,7 +470,7 @@ class SliceConverter : public OpConversionPattern { rewriter.create(loc, sliceOp.getOperand(0), ranges); rewriter.create(loc, linalg_slice, sliceOp.getOperand(1)); rewriter.eraseOp(sliceOp); - return matchSuccess(); + return success(); } }; From a581bff59ed66a3496fab4557f62b3d49f844cac Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Tue, 17 Mar 2020 11:37:49 -0700 Subject: [PATCH 084/492] Apply name change(experimental_run_v2 -> run) for all callers in Tensorflow. 
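The rename is mechanical: only the call-site spelling changes, while the replica function, its arguments, and the reduction of the per-replica results stay the same. A minimal before/after sketch of the usage (illustrative only; the strategy object and `replica_fn` below are placeholders, not code touched by this patch):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    def replica_fn(x):
      return tf.reduce_sum(x)  # placeholder per-replica computation

    @tf.function
    def step(x):
      # Old spelling: strategy.experimental_run_v2(replica_fn, args=(x,))
      per_replica = strategy.run(replica_fn, args=(x,))
      return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

    step(tf.ones([8, 4]))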
PiperOrigin-RevId: 301417899 Change-Id: I8083161d769163e0ac95534b4d78b34c8af69350 --- .../mixed_precision/experimental/loss_scale_benchmark.py | 4 ++-- tensorflow/python/keras/optimizer_v2/optimizer_v2.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_benchmark.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_benchmark.py index c3835efa702..8f8f50b4052 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_benchmark.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_benchmark.py @@ -127,7 +127,7 @@ class LossScaleBenchmark(test.Benchmark): return opt.minimize(get_loss, var_list) if mode == 'graph': - run_op = strategy.experimental_run_v2(minimize_fn) + run_op = strategy.run(minimize_fn) init_op = variables.global_variables_initializer() with session_module.Session() as sess: sess.run(init_op) @@ -136,7 +136,7 @@ class LossScaleBenchmark(test.Benchmark): return def run_fn(): - strategy.experimental_run_v2(minimize_fn) + strategy.run(minimize_fn) if mode == 'tf_function': run_fn = def_function.function(run_fn) diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 72cdda616b5..2a4d4cf86e8 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -486,7 +486,7 @@ class OptimizerV2(trackable.Trackable): if distribute_ctx.in_cross_replica_context(): raise RuntimeError( "`apply_gradients() cannot be called in cross-replica context. " - "Use `tf.distribute.Strategy.experimental_run_v2` to enter replica " + "Use `tf.distribute.Strategy.run` to enter replica " "context.") apply_state = self._prepare(var_list) From ac303a810d440d56033766651ddeffd3e2ee5f57 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Tue, 17 Mar 2020 11:45:26 -0700 Subject: [PATCH 085/492] [tf.data] Unify reshuffling state management. - Consolidate FixedSeedDataset and ReshufflingDataset into `ShuffleDatasetOp::Dataset`. FixedSeedDataset is kept around for forwards compatibility. It will be removed after the forward compatibility window has expired on 4/10. - Rename ReshufflingDatasetV2 to DatasetV2 and have it support both reshuffle=true and reshuffle=false. 
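From the Python API, the choice between the two generators is still driven by `reshuffle_each_iteration`; with this change it selects either a fixed or a random seed generator resource behind the same dataset op. A small sketch of the observable behavior (illustrative only, not part of this change):

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)

    # reshuffle_each_iteration=True: seeds are drawn from a random seed
    # generator, so each epoch sees a different order.
    reshuffled = ds.shuffle(10, seed=1, reshuffle_each_iteration=True)

    # reshuffle_each_iteration=False: seeds come from a fixed seed generator,
    # so every epoch replays the same order.
    fixed = ds.shuffle(10, seed=1, reshuffle_each_iteration=False)

    print(list(reshuffled.as_numpy_iterator()))
    print(list(reshuffled.as_numpy_iterator()))  # typically a different permutation
    print(list(fixed.as_numpy_iterator()))
    print(list(fixed.as_numpy_iterator()))       # identical both times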
PiperOrigin-RevId: 301419373 Change-Id: Ia489d8a85a9da0866b5144f97da2b511f3bc5d10 --- .../api_def_AnonymousSeedGenerator.pbtxt | 4 + .../api_def_DeleteSeedGenerator.pbtxt | 4 + .../core/kernels/data/random_seed_ops.cc | 91 +++++++------ .../core/kernels/data/random_seed_ops.h | 63 ++++++--- .../core/kernels/data/shuffle_dataset_op.cc | 126 ++++++++++-------- .../core/kernels/data/shuffle_dataset_op.h | 4 +- tensorflow/core/ops/dataset_ops.cc | 19 +++ .../python/data/kernel_tests/shuffle_test.py | 8 +- tensorflow/python/data/ops/dataset_ops.py | 58 +++++++- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 8 ++ .../api/golden/v2/tensorflow.raw_ops.pbtxt | 8 ++ 11 files changed, 270 insertions(+), 123 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_AnonymousSeedGenerator.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_DeleteSeedGenerator.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousSeedGenerator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousSeedGenerator.pbtxt new file mode 100644 index 00000000000..a8bbfb4cf22 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AnonymousSeedGenerator.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "AnonymousSeedGenerator" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_DeleteSeedGenerator.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeleteSeedGenerator.pbtxt new file mode 100644 index 00000000000..3c8d9d96c63 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DeleteSeedGenerator.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "DeleteSeedGenerator" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/data/random_seed_ops.cc b/tensorflow/core/kernels/data/random_seed_ops.cc index 09a7687a919..ea403086818 100644 --- a/tensorflow/core/kernels/data/random_seed_ops.cc +++ b/tensorflow/core/kernels/data/random_seed_ops.cc @@ -26,16 +26,39 @@ namespace tensorflow { namespace data { namespace { +const char kAnonymousRandomSeedGenerator[] = "AnonymousRandomSeedGenerator"; const char kNumRandomSamples[] = "num_random_samples"; +const char kFixedSeedGenerator[] = "FixedSeedGenerator"; const char kRandomSeedGenerator[] = "RandomSeedGenerator"; +const char kSeedGenerator[] = "SeedGenerator"; const char kSeed[] = "seed"; const char kSeed2[] = "seed2"; +const char kReshuffle[] = "reshuffle"; } // namespace +int64 SeedGenerator::num_random_samples() { + tf_shared_lock l(mu_); + return num_random_samples_; +} + +void SeedGenerator::set_num_random_samples(int64 num_random_samples) { + mutex_lock l(mu_); + num_random_samples_ = num_random_samples; +} + +string FixedSeedGenerator::DebugString() const { return kFixedSeedGenerator; } + +void FixedSeedGenerator::GenerateSeeds(int64* seed1, int64* seed2) { + mutex_lock l(mu_); + num_random_samples_++; + *seed1 = seed_; + *seed2 = seed2_; +} + string RandomSeedGenerator::DebugString() const { return kRandomSeedGenerator; } -void RandomSeedGenerator::GenerateRandomSeeds(int64* seed1, int64* seed2) { +void RandomSeedGenerator::GenerateSeeds(int64* seed1, int64* seed2) { mutex_lock l(mu_); num_random_samples_++; *seed1 = generator_(); @@ -43,16 +66,6 @@ void RandomSeedGenerator::GenerateRandomSeeds(int64* seed1, int64* seed2) { *seed2 = generator_(); } -int64 RandomSeedGenerator::num_random_samples() { - tf_shared_lock l(mu_); - return num_random_samples_; -} - -void RandomSeedGenerator::set_num_random_samples(int64 num_random_samples) { - mutex_lock l(mu_); - num_random_samples_ = num_random_samples; -} - 
void RandomSeedGenerator::Reset() { mutex_lock l(mu_); // Reset the generators based on the current seeds. @@ -62,25 +75,11 @@ void RandomSeedGenerator::Reset() { generator_.Skip(num_random_samples_); } -void RandomSeedGenerator::Serialize(OpKernelContext* ctx) { - mutex_lock l(mu_); - Tensor* num_random_samples; - OP_REQUIRES_OK(ctx, ctx->allocate_output(kNumRandomSamples, TensorShape({}), - &num_random_samples)); - num_random_samples->scalar()() = num_random_samples_; - Tensor* seed; - OP_REQUIRES_OK(ctx, ctx->allocate_output(kSeed, TensorShape({}), &seed)); - seed->scalar()() = seed_; - Tensor* seed2; - OP_REQUIRES_OK(ctx, ctx->allocate_output(kSeed2, TensorShape({}), &seed2)); - seed2->scalar()() = seed2_; -} - -AnonymousRandomSeedGeneratorHandleOp::AnonymousRandomSeedGeneratorHandleOp( +AnonymousSeedGeneratorHandleOp::AnonymousSeedGeneratorHandleOp( OpKernelConstruction* ctx) - : AnonymousResourceOp(ctx) {} + : AnonymousResourceOp(ctx) {} -void AnonymousRandomSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) { +void AnonymousSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) { int64 seed; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSeed, &seed)); int64 seed2; @@ -91,22 +90,33 @@ void AnonymousRandomSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) { } seed_ = seed; seed2_ = seed2; - AnonymousResourceOp::Compute(ctx); + + // TODO(b/151115950): Remove this case when the forward compatibility window + // expires. + if (ctx->op_kernel().def().op() == kAnonymousRandomSeedGenerator) { + reshuffle_ = true; + } else { + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, kReshuffle, &reshuffle_)); + } + AnonymousResourceOp::Compute(ctx); } -string AnonymousRandomSeedGeneratorHandleOp::name() { - return kRandomSeedGenerator; -} +std::string AnonymousSeedGeneratorHandleOp::name() { return kSeedGenerator; } -Status AnonymousRandomSeedGeneratorHandleOp::CreateResource( +Status AnonymousSeedGeneratorHandleOp::CreateResource( OpKernelContext* ctx, std::unique_ptr flib_def, std::unique_ptr pflr, - FunctionLibraryRuntime* lib, RandomSeedGenerator** resource) { - *resource = new RandomSeedGenerator(seed_, seed2_); + FunctionLibraryRuntime* lib, SeedGenerator** resource) { + if (reshuffle_) { + *resource = new RandomSeedGenerator(seed_, seed2_); + } else { + *resource = new FixedSeedGenerator(seed_, seed2_); + } return Status::OK(); } -void DeleteRandomSeedGeneratorOp::Compute(OpKernelContext* ctx) { +void DeleteSeedGeneratorOp::Compute(OpKernelContext* ctx) { ResourceHandle handle = ctx->input(0).flat()(0); // The resource is guaranteed to exist because the variant tensor wrapping the // deleter is provided as an unused input to this op, which guarantees that it @@ -115,12 +125,17 @@ void DeleteRandomSeedGeneratorOp::Compute(OpKernelContext* ctx) { } namespace { +REGISTER_KERNEL_BUILDER(Name("AnonymousSeedGenerator").Device(DEVICE_CPU), + AnonymousSeedGeneratorHandleOp); + +REGISTER_KERNEL_BUILDER(Name("DeleteSeedGenerator").Device(DEVICE_CPU), + DeleteSeedGeneratorOp); REGISTER_KERNEL_BUILDER(Name("AnonymousRandomSeedGenerator").Device(DEVICE_CPU), - AnonymousRandomSeedGeneratorHandleOp); + AnonymousSeedGeneratorHandleOp); REGISTER_KERNEL_BUILDER(Name("DeleteRandomSeedGenerator").Device(DEVICE_CPU), - DeleteRandomSeedGeneratorOp); + DeleteSeedGeneratorOp); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/random_seed_ops.h b/tensorflow/core/kernels/data/random_seed_ops.h index 1a336466ffa..54332c7a820 100644 --- a/tensorflow/core/kernels/data/random_seed_ops.h +++ 
b/tensorflow/core/kernels/data/random_seed_ops.h @@ -25,8 +25,37 @@ limitations under the License. namespace tensorflow { namespace data { -// A random seed generator resource. -class RandomSeedGenerator : public ResourceBase { +// Base class for seed generator resources. Subclasses customize how seeds are +// generated. +class SeedGenerator : public ResourceBase { + public: + virtual void GenerateSeeds(int64* seed1, int64* seed2) = 0; + virtual void Reset() = 0; + + virtual int64 num_random_samples(); + virtual void set_num_random_samples(int64 num_random_samples); + + protected: + mutex mu_; + int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0; +}; + +// Always generates the specified seed values. +class FixedSeedGenerator : public SeedGenerator { + public: + FixedSeedGenerator(int64 seed, int64 seed2) : seed_(seed), seed2_(seed2) {} + + std::string DebugString() const override; + void GenerateSeeds(int64* seed1, int64* seed2) override; + void Reset() override {} + + private: + const int64 seed_; + const int64 seed2_; +}; + +// Generates different (but deterministically chosen) seed values. +class RandomSeedGenerator : public SeedGenerator { public: RandomSeedGenerator(int64 seed, int64 seed2) : seed_(seed), @@ -34,30 +63,24 @@ class RandomSeedGenerator : public ResourceBase { parent_generator_(seed, seed2), generator_(&parent_generator_) {} - int64 num_random_samples(); - void set_num_random_samples(int64 num_random_samples); - - string DebugString() const override; - void GenerateRandomSeeds(int64* seed1, int64* seed2); - void Reset(); - void Serialize(OpKernelContext* ctx); + std::string DebugString() const override; + void GenerateSeeds(int64* seed1, int64* seed2) override; + void Reset() override; private: const int64 seed_; const int64 seed2_; - mutex mu_; random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_); random::SingleSampleAdapter generator_ TF_GUARDED_BY(mu_); - int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0; }; -// Creates an instance of random seed generator resource and transfers ownership +// Creates an instance of seed generator resource and transfers ownership // to the caller. -class AnonymousRandomSeedGeneratorHandleOp - : public AnonymousResourceOp { +class AnonymousSeedGeneratorHandleOp + : public AnonymousResourceOp { public: - explicit AnonymousRandomSeedGeneratorHandleOp(OpKernelConstruction* ctx); + explicit AnonymousSeedGeneratorHandleOp(OpKernelConstruction* ctx); void Compute(OpKernelContext* ctx) override; private: @@ -66,17 +89,17 @@ class AnonymousRandomSeedGeneratorHandleOp std::unique_ptr flib_def, std::unique_ptr pflr, FunctionLibraryRuntime* lib, - RandomSeedGenerator** resource) override; + SeedGenerator** resource) override; int64 seed_; int64 seed2_; + bool reshuffle_; }; -// Deletes an instance of random seed generator resource. -class DeleteRandomSeedGeneratorOp : public OpKernel { +// Deletes an instance of seed generator resource. 
+class DeleteSeedGeneratorOp : public OpKernel { public: - explicit DeleteRandomSeedGeneratorOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} + explicit DeleteSeedGeneratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override; }; diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index ce68f533664..337c82c1c61 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -65,11 +65,12 @@ constexpr char kSlicesStart[] = "slices_start"; constexpr char kSlicesEnd[] = "slices_end"; constexpr char kBuffer[] = "buffer"; constexpr char kSize[] = "size"; -constexpr char kRandomSeedGenerator[] = "RandomSeedGenerator"; +constexpr char kSeedGenerator[] = "SeedGenerator"; constexpr char kTFData[] = "tf_data"; constexpr char kDSNumRandomSamples[] = "ds_num_random_samples"; constexpr char kFixedSeedDatasetPrefix[] = "FixedSeed"; -constexpr char kReshufflingDatasetPrefix[] = "Reshuffling"; +constexpr char kDatasetPrefix[] = "Dataset"; +constexpr char kDatasetV2Prefix[] = "DatasetV2"; constexpr char kShuffleDataset[] = "ShuffleDataset"; namespace { @@ -448,16 +449,17 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { const TraceMeMetadata traceme_metadata_; }; -// A dataset that uses a pseudorandom sequence of seeds for the iterators -// created from it. Used when `reshuffle_each_iteration` is true. -class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { +class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase { public: - ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input, - int64 buffer_size, Seeds seeds, int64 count) - : ShuffleDatasetBase(ctx, input, buffer_size, count), seeds_(seeds) {} + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size, + Seeds seeds, int64 count, bool reshuffle_each_iteration) + : ShuffleDatasetBase(ctx, input, buffer_size, count), + seeds_(seeds), + reshuffle_each_iteration_(reshuffle_each_iteration) {} + string DebugString() const override { name_utils::DatasetDebugStringParams params; - params.dataset_prefix = kReshufflingDatasetPrefix; + params.dataset_prefix = kDatasetPrefix; params.set_args(buffer_size_, seeds_.seed_, seeds_.seed2_); return name_utils::DatasetDebugString(kDatasetType, params); } @@ -471,11 +473,10 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { } protected: - class Iterator : public ShuffleDatasetBase::Iterator { + class Iterator : public ShuffleDatasetBase::Iterator { public: Iterator(const Params& params, int64 seed, int64 seed2) - : ShuffleDatasetBase::Iterator(params, seed, - seed2) {} + : ShuffleDatasetBase::Iterator(params, seed, seed2) {} ~Iterator() override { seed_generator_->Unref(); } @@ -483,10 +484,10 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { // Firstly, lookup or create a seed generator from the IteratorResource // resource_mgr. 
ResourceMgr* mgr = ctx->resource_mgr(); - RandomSeedGenerator* seed_generator; + SeedGenerator* seed_generator; const string name = strings::StrCat( prefix(), name_utils::kDelimiter, dataset()->type_string(), - name_utils::kDelimiter, kRandomSeedGenerator); + name_utils::kDelimiter, kSeedGenerator); int64 dataset_seed, dataset_seed2; { @@ -496,18 +497,23 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { dataset_seed = seed_; dataset_seed2 = seed2_; } - TF_RETURN_IF_ERROR(mgr->LookupOrCreate( + TF_RETURN_IF_ERROR(mgr->LookupOrCreate( kTFData, name, &seed_generator, - [dataset_seed, dataset_seed2](RandomSeedGenerator** seed_generator) { + [this, dataset_seed, dataset_seed2](SeedGenerator** seed_generator) { // On the first iterator creation, use the original seeds from the - // dataset to seed a `RandomSeedGenerator` that will provide seeds + // dataset to seed a `SeedGenerator` that will provide seeds // for subsequent repetitions of the same dataset. - *seed_generator = - new RandomSeedGenerator(dataset_seed, dataset_seed2); + if (dataset()->reshuffle_each_iteration_) { + *seed_generator = + new RandomSeedGenerator(dataset_seed, dataset_seed2); + } else { + *seed_generator = + new FixedSeedGenerator(dataset_seed, dataset_seed2); + } return Status::OK(); })); seed_generator_ = seed_generator; - seed_generator_->GenerateRandomSeeds(&seed_, &seed2_); + seed_generator_->GenerateSeeds(&seed_, &seed2_); mutex_lock l(mu_); ResetRngs(); return Status::OK(); @@ -527,8 +533,7 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { seed_generator_->num_random_samples())); // Save the Iterator. - return ShuffleDatasetBase::Iterator::SaveInternal( - writer); + return ShuffleDatasetBase::Iterator::SaveInternal(writer); } Status RestoreInternal(IteratorContext* ctx, @@ -541,12 +546,12 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { seed_generator_->Reset(); // Restore the Iterator. - return ShuffleDatasetBase::Iterator::RestoreInternal( - ctx, reader); + return ShuffleDatasetBase::Iterator::RestoreInternal(ctx, + reader); } private: - RandomSeedGenerator* seed_generator_; + SeedGenerator* seed_generator_; }; Status AsGraphDefInternal(SerializationContext* ctx, @@ -561,8 +566,8 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); TF_RETURN_IF_ERROR( - AddSeeds(seeds_, ctx->preserve_random_seeds(), b, &seed, &seed2)); - b->BuildAttrValue(true, &reshuffle_each_iteration); + AddSeeds(seeds_, /*preserve_random_seeds=*/true, b, &seed, &seed2)); + b->BuildAttrValue(reshuffle_each_iteration_, &reshuffle_each_iteration); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node, buffer_size, seed, seed2}, // Inputs {std::make_pair(kReshuffleEachIteration, @@ -573,25 +578,25 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { private: const Seeds seeds_; + const bool reshuffle_each_iteration_; }; -// A dataset that uses a pseudorandom sequence of seeds for the iterators -// created from it. Used in TF 2.0 when `reshuffle_each_iteration` is true. -class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase { +// A shuffle dataset that uses an external seed generator resource to choose the +// shuffle seeds for each iteration. 
+class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase { public: - ReshufflingDatasetV2(OpKernelContext* ctx, const DatasetBase* input, - int64 buffer_size, int64 count, - RandomSeedGenerator* seed_generator, - std::unique_ptr handle) + DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size, + int64 count, SeedGenerator* seed_generator, + std::unique_ptr handle) : ShuffleDatasetBase(ctx, input, buffer_size, count), seed_generator_(seed_generator), handle_(std::move(handle)) {} - ~ReshufflingDatasetV2() override { seed_generator_->Unref(); } + ~DatasetV2() override { seed_generator_->Unref(); } string DebugString() const override { name_utils::DatasetDebugStringParams params; - params.dataset_prefix = kReshufflingDatasetPrefix; + params.dataset_prefix = kDatasetV2Prefix; params.set_args(buffer_size_); return name_utils::DatasetDebugString(kDatasetType, params); } @@ -610,15 +615,15 @@ class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase { } protected: - class Iterator : public ShuffleDatasetBase::Iterator { + class Iterator : public ShuffleDatasetBase::Iterator { public: - Iterator(const Params& params, RandomSeedGenerator* seed_generator) - : ShuffleDatasetBase::Iterator(params, 0, 0), + Iterator(const Params& params, SeedGenerator* seed_generator) + : ShuffleDatasetBase::Iterator(params, 0, 0), seed_generator_(seed_generator) {} Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); - seed_generator_->GenerateRandomSeeds(&seed_, &seed2_); + seed_generator_->GenerateSeeds(&seed_, &seed2_); ResetRngs(); return Status::OK(); } @@ -636,8 +641,7 @@ class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase { seed_generator_->num_random_samples())); // Save the tterator state. - return ShuffleDatasetBase::Iterator::SaveInternal( - writer); + return ShuffleDatasetBase::Iterator::SaveInternal(writer); } Status RestoreInternal(IteratorContext* ctx, @@ -650,12 +654,12 @@ class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase { seed_generator_->Reset(); // Restore the iterator state. - return ShuffleDatasetBase::Iterator< - ReshufflingDatasetV2>::RestoreInternal(ctx, reader); + return ShuffleDatasetBase::Iterator::RestoreInternal(ctx, + reader); } private: - RandomSeedGenerator* seed_generator_; + SeedGenerator* seed_generator_; }; Status AsGraphDefInternal(SerializationContext* ctx, @@ -678,12 +682,13 @@ class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase { } private: - RandomSeedGenerator* seed_generator_ = nullptr; + SeedGenerator* seed_generator_ = nullptr; std::unique_ptr handle_; }; // A dataset that uses the same fixed seed for all iterators created from it. // Used when `reshuffle_each_iteration` is false. +// TODO(b/151115950): delete this class. class ShuffleDatasetOp::FixedSeedDataset : public ShuffleDatasetBase { public: FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input, @@ -752,20 +757,20 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, int64 count = 1; if (op_version_ == 2) { - RandomSeedGenerator* seed_generator = nullptr; + SeedGenerator* seed_generator = nullptr; OP_REQUIRES_OK( ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &seed_generator)); // Create a fresh handle for the resource because the input handle can // become invalid after this op executes. 
std::unique_ptr handle; - OP_REQUIRES_OK(ctx, - OwnedResourceHandle::Create(ctx, seed_generator, - kRandomSeedGenerator, &handle)); + OP_REQUIRES_OK( + ctx, OwnedResourceHandle::Create( + ctx, seed_generator, seed_generator->DebugString(), &handle)); - // Ownership of seed generator is transferred onto `ReshufflingDatasetV2`. - *output = new ReshufflingDatasetV2(ctx, input, buffer_size, count, - seed_generator, std::move(handle)); + // Ownership of seed generator is transferred onto `DatasetV2`. + *output = new ShuffleDatasetOp::DatasetV2( + ctx, input, buffer_size, count, seed_generator, std::move(handle)); return; } @@ -775,12 +780,17 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, int64 seed2; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSeed2, &seed2)); - if (reshuffle_each_iteration_) { - *output = new ReshufflingDataset(ctx, input, buffer_size, - Seeds(seed, seed2), count); - } else { + if (!reshuffle_each_iteration_) { + // This dataset is only needed to support old clients running v2 eager with + // reshuffle_each_iteration_=false. We can't tell here whether we are in v2 + // eager, so we conservatively always use FixedSeedDataset when + // reshuffle_each_iteration=false. *output = new FixedSeedDataset(ctx, input, buffer_size, Seeds(seed, seed2), count); + } else { + *output = new ShuffleDatasetOp::Dataset(ctx, input, buffer_size, + Seeds(seed, seed2), count, + reshuffle_each_iteration_); } } @@ -817,7 +827,7 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); TF_RETURN_IF_ERROR( - AddSeeds(seeds_, ctx->preserve_random_seeds(), b, &seed, &seed2)); + AddSeeds(seeds_, /*preserve_random_seeds=*/true, b, &seed, &seed2)); TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.h b/tensorflow/core/kernels/data/shuffle_dataset_op.h index 33b33f8d7e0..165a1db4c45 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.h +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.h @@ -48,8 +48,8 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { DatasetBase** output) override; private: - class ReshufflingDataset; - class ReshufflingDatasetV2; + class Dataset; + class DatasetV2; class FixedSeedDataset; int op_version_; bool reshuffle_each_iteration_; diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 3329867dc89..74e0d5bcf84 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -403,6 +403,24 @@ REGISTER_OP("RangeDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("AnonymousSeedGenerator") + .Input("seed: int64") + .Input("seed2: int64") + .Input("reshuffle: bool") + .Output("handle: resource") + .Output("deleter: variant") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("DeleteSeedGenerator") + .Input("handle: resource") + .Input("deleter: variant") + .SetShapeFn(shape_inference::NoOutputs); + +// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator. 
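[Illustrative note, not part of the patch: once AnonymousSeedGenerator and DeleteSeedGenerator are registered, the resource can also be driven directly through tf.raw_ops, matching the argspecs added to the golden API files later in this patch. A hedged sketch with hypothetical seed values:

import tensorflow as tf

# Create an anonymous seed generator resource; `reshuffle` chooses between the
# RandomSeedGenerator (True) and FixedSeedGenerator (False) implementations.
handle, deleter = tf.raw_ops.AnonymousSeedGenerator(seed=1, seed2=2, reshuffle=True)

# The deleter variant must be handed back to release the resource.
tf.raw_ops.DeleteSeedGenerator(handle=handle, deleter=deleter)
]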
REGISTER_OP("AnonymousRandomSeedGenerator") .Input("seed: int64") .Input("seed2: int64") @@ -414,6 +432,7 @@ REGISTER_OP("AnonymousRandomSeedGenerator") return Status::OK(); }); +// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator. REGISTER_OP("DeleteRandomSeedGenerator") .Input("handle: resource") .Input("deleter: variant") diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py index 7a1273c9d47..81a97867b71 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_test.py @@ -329,9 +329,13 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + # We skip v2 eager since the v2 eager shuffle dataset is not serializable due + # to its use of an external seed generator resource. @combinations.generate( - combinations.times(test_base.default_test_combinations(), - combinations.combine(reshuffle=[True, False]))) + combinations.times( + test_base.graph_only_combinations() + + combinations.combine(mode=["eager"], tf_api_version=1), + combinations.combine(reshuffle=[True, False]))) def testRerandomizeOnReplicate(self, reshuffle): random_seed.set_random_seed(None) # When no seeds are fixed, each instantiation of the shuffle dataset should diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 4c957b45f68..b81de54f16e 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -3523,6 +3523,51 @@ class CacheDataset(UnaryUnchangedStructureDataset): super(CacheDataset, self).__init__(input_dataset, variant_tensor) +class _SeedGeneratorDeleter(object): + """An object which cleans up an anonymous seed generator resource. + + An alternative to defining a __del__ method on an object. Even if the parent + object is part of a reference cycle, the cycle will be collectable. + """ + + def __init__(self, handle, device, deleter): + self._deleter = deleter + self._handle = handle + self._device = device + self._eager_mode = context.executing_eagerly() + + def __del__(self): + with ops.device(self._device): + # Make sure the resource is deleted in the same mode as it was created in. + if self._eager_mode: + with context.eager_mode(): + gen_dataset_ops.delete_seed_generator( + handle=self._handle, deleter=self._deleter) + else: + with context.graph_mode(): + gen_dataset_ops.delete_seed_generator( + handle=self._handle, deleter=self._deleter) + + +class _SeedGenerator(object): + """Represents a fixed seed generator resource.""" + + def __init__(self, seed, seed2, reshuffle): + super(_SeedGenerator, self).__init__() + self._device = context.context().device_name + self._handle, self._deleter = ( + gen_dataset_ops.anonymous_seed_generator( + seed=seed, seed2=seed2, reshuffle=reshuffle)) + self._resource_deleter = _SeedGeneratorDeleter( + handle=self._handle, device=self._device, deleter=self._deleter) + + @property + def handle(self): + return self._handle + + +# TODO(b/151115950): Remove this class after forward compatibility window +# expires class _RandomSeedGeneratorDeleter(object): """An object which cleans up an anonymous random seed generator resource. 
@@ -3549,6 +3594,8 @@ class _RandomSeedGeneratorDeleter(object): handle=self._handle, deleter=self._deleter) +# TODO(b/151115950): Remove this class after forward compatibility window +# expires class _RandomSeedGenerator(object): """Represents a random seed generator resource.""" @@ -3602,9 +3649,14 @@ class ShuffleDataset(UnaryUnchangedStructureDataset): else: self._reshuffle_each_iteration = reshuffle_each_iteration - if tf2.enabled() and self._reshuffle_each_iteration and ( - context.executing_eagerly() or ops.inside_function()): - self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2) + if (tf2.enabled() and (self._reshuffle_each_iteration or + compat.forward_compatible(2020, 4, 10)) and + (context.executing_eagerly() or ops.inside_function())): + if compat.forward_compatible(2020, 4, 10): + self._seed_generator = _SeedGenerator(self._seed, self._seed2, + self._reshuffle_each_iteration) + else: + self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2) variant_tensor = gen_dataset_ops.shuffle_dataset_v2( input_dataset._variant_tensor, # pylint: disable=protected-access buffer_size=self._buffer_size, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 45c92b94119..8df5fe219f6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -108,6 +108,10 @@ tf_module { name: "AnonymousRandomSeedGenerator" argspec: "args=[\'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "AnonymousSeedGenerator" + argspec: "args=[\'seed\', \'seed2\', \'reshuffle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Any" argspec: "args=[\'input\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -1064,6 +1068,10 @@ tf_module { name: "DeleteRandomSeedGenerator" argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DeleteSeedGenerator" + argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DeleteSessionTensor" argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 45c92b94119..8df5fe219f6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -108,6 +108,10 @@ tf_module { name: "AnonymousRandomSeedGenerator" argspec: "args=[\'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "AnonymousSeedGenerator" + argspec: "args=[\'seed\', \'seed2\', \'reshuffle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Any" argspec: "args=[\'input\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -1064,6 +1068,10 @@ tf_module { name: "DeleteRandomSeedGenerator" argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DeleteSeedGenerator" + argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: 
"DeleteSessionTensor" argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " From 9450882b98f8ebd2e566f4d86c610149f01d186e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 11:50:51 -0700 Subject: [PATCH 086/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301420666 Change-Id: I74828c10fc62a71ca39269977b9d2632d6631a96 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 52a9bf9551b..6456f104ad3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. 
-// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From 6e41335c8369145e4871aee08c505ad304aac631 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Tue, 17 Mar 2020 12:01:23 -0700 Subject: [PATCH 087/492] internal code change PiperOrigin-RevId: 301422909 Change-Id: I00f70d1d429a0980bcbfc3c45958ecefaefe4280 --- .../swift/Sources/CoreMlDelegate.swift | 50 ------------------- 1 file changed, 50 deletions(-) delete mode 100644 tensorflow/lite/experimental/swift/Sources/CoreMlDelegate.swift diff --git a/tensorflow/lite/experimental/swift/Sources/CoreMlDelegate.swift b/tensorflow/lite/experimental/swift/Sources/CoreMlDelegate.swift deleted file mode 100644 index 21e0276578c..00000000000 --- a/tensorflow/lite/experimental/swift/Sources/CoreMlDelegate.swift +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2020 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import TensorFlowLiteC - -/// A delegate that uses the `Core ML` framework for performing TensorFlow Lite graph operations. -/// -/// - Important: This is an experimental interface that is subject to change. -public final class CoreMLDelegate: Delegate { - /// The configuration options for the `CoreMLDelegate`. - public let options: Options - - // Conformance to the `Delegate` protocol. 
- public private(set) var cDelegate: CDelegate - - /// Creates a new instance configured with the given `options`. - /// - /// - Parameters: - /// - options: Configurations for the delegate. The default is a new instance of - /// `CoreMLDelegate.Options` with the default configuration values. - public init(options: Options = Options()) { - self.options = options - var delegateOptions = TfLiteCoreMlDelegateOptions() - cDelegate = TfLiteCoreMlDelegateCreate(&delegateOptions) - } - - deinit { - TfLiteCoreMlDelegateDelete(cDelegate) - } -} - -extension CoreMLDelegate { - /// Options for configuring the `CoreMLDelegate`. - // TODO(b/143931022): Add preferred device support. - public struct Options: Equatable, Hashable { - /// Creates a new instance with the default values. - public init() {} - } -} From a12180b8025ca4b9bb7d9c795b4bcf8c42640df6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 17 Mar 2020 12:09:04 -0700 Subject: [PATCH 088/492] Remove the failing test on macos. PiperOrigin-RevId: 301424820 Change-Id: Ib0623bab8459c758473e16b6c4be2751ab0f7577 --- tensorflow/python/ops/image_ops_impl.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index b73bfa44b03..6a9fed66800 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -3253,19 +3253,6 @@ def rgb_to_yuv(images): value of the pixels. The output is only well defined if the value in images are in [0,1]. - Usage Example: - - >>> x = [[[0.1, 0.2, 0.3], - ... [0.4, 0.5, 0.6]], - ... [[0.7, 0.8, 0.9], - ... [0.10, 0.11, 0.12]]] - >>> tf.image.rgb_to_yuv(x) - - Args: images: 2-D or higher rank. Image data to convert. Last dimension must be size 3. From 5e37185b8ea3759ac24935f2dd077ffa01ef7a0c Mon Sep 17 00:00:00 2001 From: "Ahmed S. 
Taei" Date: Tue, 17 Mar 2020 12:15:08 -0700 Subject: [PATCH 089/492] Enable xla_hlo -> linalg.generic conversion for rank-0 tensors PiperOrigin-RevId: 301426047 Change-Id: Ief3131d37a03f8d135b82e78825b2c4d12860c40 --- .../mlir/xla/tests/hlo-legalize-to-linalg.mlir | 13 +++++++++++++ .../mlir/xla/transforms/xla_legalize_to_linalg.cc | 11 ++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index 1f4c9c6ea6c..c2fb840ad10 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -336,3 +336,16 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK-NEXT: %[[CMP:.*]] = cmpi "sgt", %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()> +// CHECK-LABEL: func @add_scalar +func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { + %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]] +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]] +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 983e5795253..17ba38e8c40 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -94,9 +94,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern { // This doesnt account for implicit broadcast, but the working assumption // here is that are broadcasts have been made explicit. unsigned nloops = argType.getRank(); - if (!nloops) { - return failure(); - } + + if (isLHLO && !nloops) ConversionPattern::matchFailure(); + int operandCount = (isLHLO ? args.size() - 1 : args.size()); auto verifyArgOrResultType = [&](Value val) -> ShapedType { auto shapedType = val.getType().dyn_cast(); @@ -105,8 +105,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern { !shapedType.isa()) || shapedType.getRank() != nloops) return nullptr; - indexingMaps.emplace_back( - AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops))); + indexingMaps.emplace_back(AffineMapAttr::get( + nloops ? rewriter.getMultiDimIdentityMap(nloops) + : AffineMap::get(nloops, 0, rewriter.getContext()))); return shapedType; }; for (const auto& arg : llvm::enumerate(args)) { From b1e79c7a4f3231987469e52c8b97d44837599557 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 12:19:33 -0700 Subject: [PATCH 090/492] Create Variables to track mini-batches seen in Model.fit / evaluate / predict. Use these counters in the TensorBoard Callback. 
PiperOrigin-RevId: 301427015 Change-Id: I73fade763d8dcb58ad471d333e4e5b53992356db --- tensorflow/python/keras/callbacks.py | 450 +++++++++++------- tensorflow/python/keras/callbacks_test.py | 6 +- tensorflow/python/keras/callbacks_v1.py | 29 +- tensorflow/python/keras/engine/training.py | 63 +-- tensorflow/python/keras/engine/training_v1.py | 3 - .../keras/tests/model_subclassing_test.py | 15 - .../python/keras/utils/version_utils.py | 22 - ...orflow.keras.callbacks.-tensor-board.pbtxt | 2 - ...orflow.keras.callbacks.-tensor-board.pbtxt | 1 - 9 files changed, 285 insertions(+), 306 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 9177d89c67b..bb9e61d01a2 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -35,19 +35,21 @@ import six from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import distributed_file_utils from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras.distribute import multi_worker_training_state as training_state from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils -from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.training import checkpoint_management @@ -1612,7 +1614,7 @@ class LearningRateScheduler(Callback): @keras_export('keras.callbacks.TensorBoard', v1=[]) -class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): +class TensorBoard(Callback): # pylint: disable=line-too-long """Enable visualizations for TensorBoard. @@ -1674,10 +1676,11 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): batches. Note that writing too frequently to TensorBoard can slow down your training. profile_batch: Profile the batch(es) to sample compute characteristics. - profile_batch must be a non-negative integer or a tuple of integers. - A pair of positive integers signify a range of batches to profile. - By default, it will profile the second batch. Set profile_batch=0 - to disable profiling. Must run in TensorFlow eager mode. + profile_batch must be a non-negative integer or a comma separated string + of pair of positive integers. A pair of positive integers signify a + range of batches to profile. By default, it will profile the second + batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow + eager mode. embeddings_freq: frequency (in epochs) at which embedding layers will be visualized. If set to 0, embeddings won't be visualized. 
embeddings_metadata: a dictionary which maps layer name to a file name in @@ -1710,18 +1713,30 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): self.histogram_freq = histogram_freq self.write_graph = write_graph self.write_images = write_images - self.update_freq = 1 if update_freq == 'batch' else update_freq + if update_freq == 'batch': + self.update_freq = 1 + else: + self.update_freq = update_freq self.embeddings_freq = embeddings_freq self.embeddings_metadata = embeddings_metadata - self._init_profile_batch(profile_batch) - self._epoch = 0 - # Lazily initialized in order to avoid creating event files when - # not needed. + self._samples_seen = 0 + self._samples_seen_at_last_write = 0 + self._current_batch = 0 + + # A collection of file writers currently in use, to be closed when + # training ends for this callback. Writers are keyed by the + # directory name under the root logdir: e.g., "train" or + # "validation". + self._train_run_name = 'train' + self._validation_run_name = 'validation' self._writers = {} - - # Used to restore any existing `SummaryWriter` after training ends. - self._prev_summary_state = [] + self._start_batch, self._stop_batch = self._init_profile_batch( + profile_batch) + if self._start_batch > 0: + profiler.warmup() # Improve the profiling accuracy. + # True when a trace is running. + self._is_tracing = False def _validate_kwargs(self, kwargs): """Handle arguments were supported in V1.""" @@ -1753,56 +1768,37 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): def set_model(self, model): """Sets Keras model and writes graph if specified.""" self.model = model - self._log_write_dir = self._get_log_write_dir() - self._train_dir = os.path.join(self._log_write_dir, 'train') - self._train_step = self.model._train_counter # pylint: disable=protected-access + # In case this callback is used via native Keras, _get_distribution_strategy does not exist. + if hasattr(self.model, '_get_distribution_strategy'): + # TensorBoard callback involves writing a summary file in a + # possibly distributed settings. + self._log_write_dir = distributed_file_utils.write_dirpath( + self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access + else: + self._log_write_dir = self.log_dir - self._val_dir = os.path.join(self._log_write_dir, 'validation') - self._val_step = self.model._test_counter # pylint: disable=protected-access + with context.eager_mode(): + self._close_writers() + if self.write_graph: + with self._get_writer(self._train_run_name).as_default(): + with summary_ops_v2.always_record_summaries(): + if not model.run_eagerly: + summary_ops_v2.graph(K.get_graph(), step=0) - self._writers = {} # Resets writers. 
+ summary_writable = ( + self.model._is_graph_network or # pylint: disable=protected-access + self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access + if summary_writable: + summary_ops_v2.keras_model('keras', self.model, step=0) - if self.write_graph: - self._write_keras_model_graph() if self.embeddings_freq: self._configure_embeddings() - @property - def _train_writer(self): - if 'train' not in self._writers: - self._writers['train'] = summary_ops_v2.create_file_writer_v2( - self._train_dir) - return self._writers['train'] - - @property - def _val_writer(self): - if 'val' not in self._writers: - self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir) - return self._writers['val'] - - def _get_log_write_dir(self): - """For multi-worker, only chief should write, others write to '/tmp'.""" - return distributed_file_utils.write_dirpath(self.log_dir, - self.model.distribute_strategy) - - def _delete_tmp_write_dir(self): - """Deletes tmp write directories for multi-worker.""" - distributed_file_utils.remove_temp_dirpath(self.log_dir, - self.model.distribute_strategy) - - def _write_keras_model_graph(self): - """Writes Keras graph networks to TensorBoard.""" - with self._train_writer.as_default(): - with summary_ops_v2.always_record_summaries(): - if not self.model.run_eagerly: - summary_ops_v2.graph(K.get_graph(), step=0) - - summary_writable = ( - self.model._is_graph_network or # pylint: disable=protected-access - self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access - if summary_writable: - summary_ops_v2.keras_model('keras', self.model, step=0) + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + self._prev_summary_recording = summary_state.is_recording + self._prev_summary_writer = summary_state.writer + self._prev_summary_step = summary_state.step def _configure_embeddings(self): """Configure the Projector for embeddings.""" @@ -1843,44 +1839,74 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): writer = DummyWriter(self._log_write_dir) projector.visualize_embeddings(writer, config) - def _push_writer(self, writer, step): + def _close_writers(self): + """Close all remaining open file writers owned by this callback. + + If there are no such file writers, this is a no-op. + """ + with context.eager_mode(): + for writer in six.itervalues(self._writers): + writer.close() + self._writers.clear() + + def _get_writer(self, writer_name): + """Get a summary writer for the given subdirectory under the logdir. + + A writer will be created if it does not yet exist. + + Arguments: + writer_name: The name of the directory for which to create or + retrieve a writer. Should be either `self._train_run_name` or + `self._validation_run_name`. + + Returns: + A `SummaryWriter` object. + """ + if writer_name not in self._writers: + path = os.path.join(self._log_write_dir, writer_name) + writer = summary_ops_v2.create_file_writer_v2(path) + self._writers[writer_name] = writer + return self._writers[writer_name] + + def _set_default_writer(self, writer_name): """Sets the default writer for custom batch-level summaries.""" if self.update_freq == 'epoch': + # Writer is only used for custom summaries, which are written + # batch-by-batch. 
return - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - self._prev_summary_state.append({ - 'is_recording': summary_state.is_recording, - 'writer': summary_state.writer, - 'step': summary_state.step - }) + step = self._total_batches_seen[writer_name] - if self.update_freq == 'epoch': - should_record = False - writer = None + def _should_record(): + return math_ops.equal(step % self.update_freq, 0) + + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + summary_state.is_recording = _should_record + summary_state.writer = self._get_writer(writer_name) + summary_ops_v2.set_step(step) + + def _init_batch_steps(self): + """Create the total batch counters.""" + if ops.executing_eagerly_outside_functions(): + # Variables are needed for the `step` value of custom tf.summaries + # to be updated inside a tf.function. + self._total_batches_seen = { + self._train_run_name: variables.Variable(0, dtype='int64'), + self._validation_run_name: variables.Variable(0, dtype='int64') + } else: - should_record = lambda: math_ops.equal(step % self.update_freq, 0) + # Custom tf.summaries are not supported in legacy graph mode. + self._total_batches_seen = { + self._train_run_name: 0, + self._validation_run_name: 0 + } - summary_state.is_recording = should_record - summary_state.writer = writer - # TODO(b/151339474): Fix deadlock when not using .value() here. - summary_ops_v2.set_step(step.value()) - - def _pop_writer(self): - """Pops the current writer.""" - if self.update_freq == 'epoch': - return - - prev_state = self._prev_summary_state.pop() - - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - summary_state.is_recording = prev_state['is_recording'] - summary_state.writer = prev_state['writer'] - summary_ops_v2.set_step(prev_state['step']) - - def _close_writers(self): - for writer in self._writers.values(): - writer.close() + def _increment_step(self, writer_name): + step = self._total_batches_seen[writer_name] + if isinstance(step, variables.Variable): + step.assign_add(1) + else: + self._total_batches_seen[writer_name] += 1 def _init_profile_batch(self, profile_batch): """Validate profile_batch value and set the range of batches to profile. @@ -1900,79 +1926,75 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): """ profile_batch_error_message = ( - 'profile_batch must be a non-negative integer or 2-tuple of positive ' - 'integers. A pair of positive integers signifies a range of batches ' - 'to profile. Found: {}'.format(profile_batch)) - - # Support legacy way of specifying "start,stop" or "start" as str. - if isinstance(profile_batch, six.string_types): - profile_batch = str(profile_batch).split(',') - profile_batch = nest.map_structure(int, profile_batch) - - if isinstance(profile_batch, int): - self._start_batch = profile_batch - self._stop_batch = profile_batch - elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2: - self._start_batch, self._stop_batch = profile_batch + 'profile_batch must be a non-negative integer or a comma separated ' + 'string of pair of positive integers. 
A pair of positive integers ' + 'signify a range of batches to profile.') + try: + profile_range = [int(i) for i in str(profile_batch).split(',')] + except ValueError: + raise ValueError(profile_batch_error_message) + if len(profile_range) == 1: # single batch + start_batch, stop_batch = profile_range[0], profile_range[0] + if start_batch < 0: + raise ValueError(profile_batch_error_message) + elif len(profile_range) == 2: # (start_batch, stop_batch) + start_batch, stop_batch = profile_range + # [0, 0], [-1, 100], [6, 5] are illegal. + if start_batch <= 0 or start_batch > stop_batch: + raise ValueError(profile_batch_error_message) else: raise ValueError(profile_batch_error_message) - - if self._start_batch < 0 or self._stop_batch < self._start_batch: - raise ValueError(profile_batch_error_message) - - if self._start_batch > 0: - profiler.warmup() # Improve the profiling accuracy. - # True when a trace is running. - self._is_tracing = False - - # Setting `profile_batch=0` disables profiling. - self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0) + return start_batch, stop_batch def on_train_begin(self, logs=None): - self._push_writer(self._train_writer, self._train_step) - - def on_train_end(self, logs=None): - self._pop_writer() - - if self._is_tracing: - self._stop_trace() - - self._close_writers() - self._delete_tmp_write_dir() + self._init_batch_steps() + if self._start_batch == 1: + self._enable_trace() def on_test_begin(self, logs=None): - self._push_writer(self._val_writer, self._val_step) - - def on_test_end(self, logs=None): - self._pop_writer() - - def on_train_batch_begin(self, batch, logs=None): - if not self._should_trace: - return - - if self._epoch == 0 and batch == self._start_batch: - self._start_trace() + self._set_default_writer(self._validation_run_name) def on_train_batch_end(self, batch, logs=None): - """Performs profiling if current batch is in profiler_batches. + """Writes scalar summaries for metrics on every training batch. + + Performs profiling if current batch is in profiler_batches. Arguments: batch: Integer, index of batch within the current epoch. logs: Dict. Metric results for this batch. """ - if not self._should_trace: + # TODO(b/150629188): Make TensorBoard callback not use batch hooks + # by default. + if self.update_freq == 'epoch' and self._start_batch is None: return - if self._is_tracing and batch >= self._stop_batch: - self._stop_trace() + # Don't output batch_size and batch number as TensorBoard summaries + logs = logs or {} + train_batches = self._total_batches_seen[self._train_run_name] + if self.update_freq != 'epoch' and batch % self.update_freq == 0: + self._log_metrics(logs, prefix='batch_', step=train_batches) + + self._increment_step(self._train_run_name) + if self._is_tracing: + control_flow_ops.cond( + math_ops.greater_equal(train_batches, self._stop_batch), + lambda: self._log_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda + else: + control_flow_ops.cond( + math_ops.equal(train_batches, self._start_batch - 1), + lambda: self._enable_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda + + def on_test_batch_end(self, batch, logs=None): + if self.update_freq == 'epoch': + return + self._increment_step(self._validation_run_name) def on_epoch_begin(self, epoch, logs=None): - # Keeps track of epoch for profiling. 
- self._epoch = epoch + self._set_default_writer(self._train_run_name) def on_epoch_end(self, epoch, logs=None): """Runs metrics and histogram summaries at epoch end.""" - self._log_epoch_metrics(epoch, logs) + self._log_metrics(logs, prefix='epoch_', step=epoch) if self.histogram_freq and epoch % self.histogram_freq == 0: self._log_weights(epoch) @@ -1980,57 +2002,124 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): if self.embeddings_freq and epoch % self.embeddings_freq == 0: self._log_embeddings(epoch) - def _start_trace(self): - summary_ops_v2.trace_on(graph=True, profiler=False) - profiler.start(logdir=self._train_dir) + def on_train_end(self, logs=None): + if self._is_tracing: + self._log_trace() + self._close_writers() + + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + summary_state.is_recording = self._prev_summary_recording + summary_state.writer = self._prev_summary_writer + summary_state.step = self._prev_summary_step + + # In case this callback is used via native Keras, _get_distribution_strategy does not exist. + if hasattr(self.model, '_get_distribution_strategy'): + # Safely remove the unneeded temp files. + distributed_file_utils.remove_temp_dirpath( + self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access + + def _enable_trace(self): + """Starts to collect trace graph to TensorBoard. + + Collects both trace and graph in eager mode, and trace only in graph mode. + """ + if context.executing_eagerly(): + # Graph must be traced in eager mode. + summary_ops_v2.trace_on(graph=True, profiler=False) + profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) self._is_tracing = True - def _stop_trace(self, batch=None): - """Logs the trace graph to TensorBoard.""" - if batch is None: - batch = self._stop_batch - with self._train_writer.as_default(): - with summary_ops_v2.always_record_summaries(): - # TODO(b/126388999): Remove step info in the summary name. - summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch) + def _enable_trace_return_true(self): + """Starts to collect trace graph to TensorBoard and returns True. + + Returns: + True. + """ + self._enable_trace() + return True + + def _log_trace(self): + """Logs the trace graph to TensorBoard. + + Logs both trace and graph in eager mode, and trace only in graph mode. + """ profiler.stop() + if context.executing_eagerly(): + # Graph must be traced in eager mode. + with self._get_writer(self._train_run_name).as_default(), \ + summary_ops_v2.always_record_summaries(): + # TODO(b/126388999): Remove step info in the summary name. + step = K.get_value(self._total_batches_seen[self._train_run_name]) + summary_ops_v2.trace_export(name='batch_%d' % step, step=step) self._is_tracing = False - def _log_epoch_metrics(self, epoch, logs): - """Writes epoch metrics out as scalar summaries. + def _log_trace_return_true(self): + """Logs the trace graph to TensorBoard and returns True. + + Returns: + True. + """ + self._log_trace() + return True + + def _log_metrics(self, logs, prefix, step): + """Writes metrics out as custom scalar summaries. Arguments: - epoch: Int. The global step to use for TensorBoard. - logs: Dict. Keys are scalar summary names, values are scalars. + logs: Dict. Keys are scalar summary names, values are NumPy scalars. + prefix: String. The prefix to apply to the scalar summary names. + step: Int. The global step to use for TensorBoard. 
""" - if not logs: - return + if logs is None: + logs = {} - train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')} - val_logs = {k: v for k, v in logs.items() if k.startswith('val_')} + # Group metrics by the name of their associated file writer. Values + # are lists of metrics, as (name, scalar_value) pairs. + logs_by_writer = { + self._train_run_name: [], + self._validation_run_name: [], + } + validation_prefix = 'val_' + for (name, value) in logs.items(): + if name in ('batch', 'size', 'num_steps'): + # Scrub non-metric items. + continue + if name.startswith(validation_prefix): + name = name[len(validation_prefix):] + writer_name = self._validation_run_name + else: + writer_name = self._train_run_name + name = prefix + name # assign batch or epoch prefix + logs_by_writer[writer_name].append((name, value)) - with summary_ops_v2.always_record_summaries(): - if train_logs: - with self._train_writer.as_default(): - for name, value in train_logs.items(): - summary_ops_v2.scalar('epoch_' + name, value, step=epoch) - if val_logs: - with self._val_writer.as_default(): - for name, value in val_logs.items(): - name = name[4:] # Remove 'val_' prefix. - summary_ops_v2.scalar('epoch_' + name, value, step=epoch) + with context.eager_mode(): + with summary_ops_v2.always_record_summaries(): + for writer_name in logs_by_writer: + these_logs = logs_by_writer[writer_name] + if not these_logs: + # Don't create a "validation" events file if we don't + # actually have any validation data. + continue + writer = self._get_writer(writer_name) + with writer.as_default(): + for (name, value) in these_logs: + summary_ops_v2.scalar(name, value, step=step) def _log_weights(self, epoch): """Logs the weights of the Model to TensorBoard.""" - with self._train_writer.as_default(): - with summary_ops_v2.always_record_summaries(): - for layer in self.model.layers: - for weight in layer.weights: - weight_name = weight.name.replace(':', '_') - summary_ops_v2.histogram(weight_name, weight, step=epoch) - if self.write_images: - self._log_weight_as_image(weight, weight_name, epoch) - self._train_writer.flush() + writer = self._get_writer(self._train_run_name) + with context.eager_mode(), \ + writer.as_default(), \ + summary_ops_v2.always_record_summaries(): + for layer in self.model.layers: + for weight in layer.weights: + weight_name = weight.name.replace(':', '_') + with ops.init_scope(): + weight = K.get_value(weight) + summary_ops_v2.histogram(weight_name, weight, step=epoch) + if self.write_images: + self._log_weight_as_image(weight, weight_name, epoch) + writer.flush() def _log_weight_as_image(self, weight, weight_name, epoch): """Logs a weight as a TensorBoard image.""" @@ -2061,9 +2150,6 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): 'keras_embedding.ckpt-{}'.format(epoch)) self.model.save_weights(embeddings_ckpt) - def _implements_train_batch_hooks(self): - return not (self._start_batch == 0 and self._stop_batch == 0) - @keras_export('keras.callbacks.ReduceLROnPlateau') class ReduceLROnPlateau(Callback): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 54f71402177..eb62d0b29ee 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -2079,19 +2079,17 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): model.fit( np.zeros((64, 1)), np.zeros((64, 1)), - batch_size=32, callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)], ) # 
Verifies trace exists in the first train_dir. - self.assertIsNotNone(self._get_trace_file(logdir=self.logdir)) + self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) model.fit( np.zeros((64, 1)), np.zeros((64, 1)), - batch_size=32, callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)], ) # Verifies trace exists in the second train_dir. - self.assertIsNotNone(self._get_trace_file(logdir=self.logdir)) + self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) def test_TensorBoard_autoTrace_profileBatchRange(self): model = self._get_seq_model() diff --git a/tensorflow/python/keras/callbacks_v1.py b/tensorflow/python/keras/callbacks_v1.py index 09af890b76c..524e039f597 100644 --- a/tensorflow/python/keras/callbacks_v1.py +++ b/tensorflow/python/keras/callbacks_v1.py @@ -39,7 +39,7 @@ from tensorflow.python.util.tf_export import keras_export @keras_export(v1=['keras.callbacks.TensorBoard']) -class TensorBoard(callbacks.TensorBoard): +class TensorBoard(callbacks.Callback): # pylint: disable=line-too-long """Enable visualizations for TensorBoard. @@ -127,8 +127,7 @@ class TensorBoard(callbacks.TensorBoard): embeddings_data=None, update_freq='epoch', profile_batch=2): - # Don't call super's init since it is an eager-only version. - callbacks.Callback.__init__(self) + super(TensorBoard, self).__init__() self.log_dir = log_dir self.histogram_freq = histogram_freq if self.histogram_freq and context.executing_eagerly(): @@ -343,21 +342,6 @@ class TensorBoard(callbacks.TensorBoard): self.writer.add_summary(summary, step) self.writer.flush() - def on_train_batch_begin(self, batch, logs=None): - if (not self._is_profiling and - self._total_batches_seen == self._profile_batch - 1): - profiler.start(self.log_dir) - self._is_profiling = True - - def on_train_batch_end(self, batch, logs=None): - return self.on_batch_end(batch, logs) - - def on_test_begin(self, logs=None): - pass - - def on_test_end(self, logs=None): - pass - def on_batch_end(self, batch, logs=None): """Writes scalar summaries for metrics on every training batch. 
@@ -374,13 +358,18 @@ class TensorBoard(callbacks.TensorBoard): self._write_custom_summaries(self._total_batches_seen, batch_logs) self._samples_seen_at_last_write = self._samples_seen self._total_batches_seen += 1 - if self._is_profiling: profiler.stop() self._is_profiling = False + elif (not self._is_profiling and + self._total_batches_seen == self._profile_batch - 1): + profiler.start(self.log_dir) + self._is_profiling = True def on_train_begin(self, logs=None): - pass + if self._profile_batch == 1: + profiler.start(self.log_dir) + self._is_profiling = True def on_epoch_begin(self, epoch, logs=None): """Add histogram op to Model eval_function callbacks, reset batch count.""" diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 21361f680da..7dcf10a506c 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import copy -import itertools from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import distribute_coordinator_context as dc_context @@ -29,7 +28,6 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import monitoring -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import callbacks as callbacks_module from tensorflow.python.keras import optimizers @@ -45,8 +43,6 @@ from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops import array_ops from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops import summary_ops_v2 -from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.profiler import trace @@ -165,9 +161,6 @@ class Model(network.Network, version_utils.ModelVersionSelector): Checkout [guide](https://www.tensorflow.org/guide/keras/overview) for additional details. """ - _TF_MODULE_IGNORED_PROPERTIES = frozenset( - itertools.chain(('_train_counter', '_test_counter', '_predict_counter'), - network.Network._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access def __init__(self, *args, **kwargs): super(Model, self).__init__(*args, **kwargs) @@ -193,18 +186,6 @@ class Model(network.Network, version_utils.ModelVersionSelector): self.compiled_loss = None self.compiled_metrics = None - self._init_batch_counters() - - @trackable.no_automatic_dependency_tracking - def _init_batch_counters(self): - # Untracked Variables, used to keep track of mini-batches seen in `fit`, - # `evaluate`, and `predict`. - agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA - self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg) - self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg) - self._predict_counter = variables.Variable( - 0, dtype='int64', aggregation=agg) - def get_weights(self): """Retrieves the weights of the model. 
@@ -518,18 +499,11 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.train_function def train_function(iterator): - """Runs one call to `self.train_function`.""" - - def run_step(data): - outputs = self.train_step(data) - self._train_counter.assign_add(1) - return outputs - data = next(iterator) - outputs = self.distribute_strategy.run(run_step, args=(data,)) + outputs = self.distribute_strategy.run( + self.train_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='first') - write_scalar_summaries(outputs, step=self._train_counter) return outputs if not self.run_eagerly: @@ -788,7 +762,6 @@ class Model(network.Network, version_utils.ModelVersionSelector): self.stop_training = False train_function = self.make_train_function() - self._train_counter.assign(0) callbacks.on_train_begin() # Handle fault-tolerance for multi-worker. # TODO(omalleyt): Fix the ordering issues that mean this has to @@ -899,15 +872,9 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.test_function def test_function(iterator): - """Runs one call to `self.test_function`.""" - - def run_step(data): - outputs = self.test_step(data) - self._test_counter.assign_add(1) - return outputs - data = next(iterator) - outputs = self.distribute_strategy.run(run_step, args=(data,)) + outputs = self.distribute_strategy.run( + self.test_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='first') return outputs @@ -1036,7 +1003,6 @@ class Model(network.Network, version_utils.ModelVersionSelector): steps=data_handler.inferred_steps) test_function = self.make_test_function() - self._test_counter.assign(0) callbacks.on_test_begin() for _, iterator in data_handler.enumerate_epochs(): # Single epoch. self.reset_metrics() @@ -1109,15 +1075,9 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.predict_function def predict_function(iterator): - """Runs one call to `self.predict_function`.""" - - def run_step(data): - outputs = self.predict_step(data) - self._predict_counter.assign_add(1) - return outputs - data = next(iterator) - outputs = self.distribute_strategy.run(run_step, args=(data,)) + outputs = self.distribute_strategy.run( + self.predict_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='concat') return outputs @@ -1232,7 +1192,6 @@ class Model(network.Network, version_utils.ModelVersionSelector): steps=data_handler.inferred_steps) predict_function = self.make_predict_function() - self._predict_counter.assign(0) callbacks.on_predict_begin() for _, iterator in data_handler.enumerate_epochs(): # Single epoch. 
with data_handler.catch_stop_iteration(): @@ -1775,13 +1734,3 @@ def _minimize(tape, optimizer, loss, trainable_variables): all_reduce_sum_gradients=False) else: optimizer.apply_gradients(zip(gradients, trainable_variables)) - - -def _is_scalar(x): - return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0 - - -def write_scalar_summaries(logs, step): - for name, value in logs.items(): - if _is_scalar(value): - summary_ops_v2.scalar('batch_' + name, value, step=step) diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 710f9bf3497..1c0fea91337 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -162,9 +162,6 @@ class Model(training_lib.Model): self._v1_compile_was_called = False - def _init_batch_counters(self): - pass # Batch counters should not be created in legacy graph mode. - @trackable.no_automatic_dependency_tracking def _set_strategy(self, strategy): self._compile_time_distribution_strategy = strategy diff --git a/tensorflow/python/keras/tests/model_subclassing_test.py b/tensorflow/python/keras/tests/model_subclassing_test.py index 5af1148f4f0..761f720cea5 100644 --- a/tensorflow/python/keras/tests/model_subclassing_test.py +++ b/tensorflow/python/keras/tests/model_subclassing_test.py @@ -737,21 +737,6 @@ class CustomCallSignatureTests(test.TestCase, parameterized.TestCase): self.assertLen(new_model.variables, 1) self.assertLen(new_model.layers, 1) - def test_batch_counters_not_in_variables(self): - - class MyModel(keras.Model): - - def __init__(self): - super(MyModel, self).__init__() - self.layer = keras.layers.Dense(4) - - def call(self, obs): - return self.layer(obs) - - model = MyModel() - model(np.ones((10, 10))) - self.assertLen(model.variables, 2) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/utils/version_utils.py b/tensorflow/python/keras/utils/version_utils.py index 377f370430c..cf485e1080d 100644 --- a/tensorflow/python/keras/utils/version_utils.py +++ b/tensorflow/python/keras/utils/version_utils.py @@ -36,13 +36,6 @@ base_layer = lazy_loader.LazyLoader( base_layer_v1 = lazy_loader.LazyLoader( "base_layer_v1", globals(), "tensorflow.python.keras.engine.base_layer_v1") -callbacks = lazy_loader.LazyLoader( - "callbacks", globals(), - "tensorflow.python.keras.callbacks") -callbacks_v1 = lazy_loader.LazyLoader( - "callbacks_v1", globals(), - "tensorflow.python.keras.callbacks_v1") - # pylint: enable=g-inconsistent-quotes @@ -65,21 +58,6 @@ class LayerVersionSelector(object): return super(LayerVersionSelector, cls).__new__(cls) -class TensorBoardVersionSelector(object): - """Chooses between Keras v1 and v2 TensorBoard callback class.""" - - def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument - eager_enabled = ops.executing_eagerly_outside_functions() - start_cls = cls - cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard, - eager_enabled) - if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard: - # Since the v2 class is not a subclass of the v1 class, __init__ has to - # be called manually. 
- return cls(*args, **kwargs) - return super(TensorBoardVersionSelector, cls).__new__(cls) - - def swap_class(cls, v2_cls, v1_cls, eager_enabled): """Swaps in v2_cls or v1_cls depending on graph mode.""" if cls == object: diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt index 2e0c6c97826..4504633d4a1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -1,9 +1,7 @@ path: "tensorflow.keras.callbacks.TensorBoard" tf_class { is_instance: "" - is_instance: "" is_instance: "" - is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt index 51d6901e936..24385e2722a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.keras.callbacks.TensorBoard" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" member_method { name: "__init__" From 6e62523b641568b336d46687272cb0fb9455a461 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 17 Mar 2020 12:37:19 -0700 Subject: [PATCH 091/492] [TF] Expose C API in libtensorflow_framework. While at it, expose the associated header files from tensorflow/c/ in pip package. Note, we expose the subset of C API that doesn't require tensorflow/cc linkage; specifically the core operations that exclude building while loops and gradient ops, and also excluding the experimental API. The experimental API can also be added in the future, by factoring it into "core" and "non-core" targets. Similarly for the C eager API. 
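As a rough, illustrative sketch (not part of this change): a client that sticks to the
"core" graph-construction and metadata calls — and therefore avoids the pieces that need
tensorflow/cc linkage, such as TF_NewWhile / TF_AddGradients and the experimental API —
could look like the following. The build/link wiring (how exactly one links against
libtensorflow_framework.so) is assumed here and not prescribed by this patch.

    /* Illustrative consumer of the core C API only. */
    #include <stdio.h>
    #include "tensorflow/c/c_api.h"

    int main(void) {
      printf("TensorFlow C library version: %s\n", TF_Version());

      TF_Status* status = TF_NewStatus();
      TF_Graph* graph = TF_NewGraph();

      /* Build a single Placeholder node using only core graph-construction calls. */
      TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "input");
      TF_SetAttrType(desc, "dtype", TF_FLOAT);
      TF_Operation* op = TF_FinishOperation(desc, status);
      if (TF_GetCode(status) != TF_OK) {
        fprintf(stderr, "FinishOperation failed: %s\n", TF_Message(status));
      } else {
        printf("Created op '%s' of type '%s'\n",
               TF_OperationName(op), TF_OperationOpType(op));
      }

      TF_DeleteGraph(graph);
      TF_DeleteStatus(status);
      return 0;
    }

Whether such a program includes c_api.h or the new c_core_api.h, and which shared object
it links against, depends on how downstream packages pick up the headers exposed here;
the snippet only demonstrates that the calls above fall within the "core" subset.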
PiperOrigin-RevId: 301430667 Change-Id: I5ae7f3cedfe9dc72184d39ef1147193450c3d92e --- tensorflow/BUILD | 1 + tensorflow/c/BUILD | 76 +- tensorflow/c/c_api.cc | 2083 +---------------------- tensorflow/c/c_api.h | 1423 +--------------- tensorflow/c/c_api_internal.h | 8 +- tensorflow/c/c_core_api.cc | 2193 +++++++++++++++++++++++++ tensorflow/c/c_core_api.h | 1456 ++++++++++++++++ tensorflow/c/eager/BUILD | 2 +- tensorflow/c/eager/c_api.cc | 2 +- tensorflow/c/eager/c_api.h | 2 +- tensorflow/tools/pip_package/BUILD | 1 + tensorflow/tools/pip_package/setup.py | 1 + 12 files changed, 3739 insertions(+), 3509 deletions(-) create mode 100644 tensorflow/c/c_core_api.cc create mode 100644 tensorflow/c/c_core_api.h diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 55406a5686a..005acff27f7 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -644,6 +644,7 @@ tf_cc_shared_object( "//tensorflow/core:lib_internal_impl", "//tensorflow/core/profiler:profiler_impl", "//tensorflow/stream_executor:stream_executor_impl", + "//tensorflow/c:c_core_api_no_xla", "//tensorflow:tf_framework_version_script.lds", ] + tf_additional_binary_deps(), ) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index c5574793b74..248bb826c28 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -23,6 +23,7 @@ filegroup( srcs = [ "c_api.h", "c_api_experimental.h", + "c_core_api.h", "tf_attrtype.h", "tf_datatype.h", "tf_file_statistics.h", @@ -73,6 +74,7 @@ tf_cuda_library( hdrs = [ "c_api.h", "c_api_internal.h", + "c_core_api.h", "tf_datatype.h", "tf_tensor.h", ], @@ -116,10 +118,41 @@ cc_library( visibility = ["//visibility:public"], ) +tf_cuda_library( + name = "c_core_api", + hdrs = [ + "c_core_api.h", + "tf_attrtype.h", + "tf_datatype.h", + "tf_file_statistics.h", + "tf_status.h", + "tf_tensor.h", + ], + copts = tf_copts(), + visibility = [ + "//visibility:public", + ], + deps = [ + ":c_core_api_no_xla", + ":c_api_internal", + ":tf_attrtype", + ":tf_status_internal", + ":tf_file_statistics", + ":tf_tensor_internal", + ] + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], + }), +) + tf_cuda_library( name = "c_api", hdrs = [ "c_api.h", + "c_core_api.h", "tf_attrtype.h", "tf_datatype.h", "tf_file_statistics.h", @@ -129,6 +162,7 @@ tf_cuda_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":c_core_api", ":c_api_no_xla", ":c_api_internal", ":tf_attrtype", @@ -144,11 +178,48 @@ tf_cuda_library( }), ) +tf_cuda_library( + name = "c_core_api_no_xla", + srcs = [ + "c_api_function.cc", + "c_core_api.cc", + ], + hdrs = [ + "c_core_api.h", + ], + copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":c_api_internal", + ":tf_attrtype", + ":tf_datatype", + ":tf_status_internal", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + ":tf_status", + ":tf_tensor", + "@com_google_absl//absl/strings", + "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:op_gen_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:server_lib", + ], + }), + alwayslink = 1, +) + tf_cuda_library( name = "c_api_no_xla", srcs = [ "c_api.cc", - "c_api_function.cc", ], hdrs = [ "c_api.h", @@ -159,6 +230,7 @@ 
tf_cuda_library( "//third_party/llvm/llvm-project:__subpackages__", ], deps = [ + ":c_core_api_no_xla", ":c_api_internal", ":tf_attrtype", ":tf_datatype", @@ -184,8 +256,6 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/kernels:logging_ops", ], }), alwayslink = 1, diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index bc1fbd3fcf5..3a110e4c9f2 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -29,9 +29,6 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/while_loop.h" -#include "tensorflow/cc/saved_model/loader.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/framework/logging.h" #include "tensorflow/core/framework/op_gen_lib.h" #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/c/c_api_internal.h" @@ -99,566 +96,14 @@ using tensorflow::TensorBuffer; using tensorflow::TensorId; using tensorflow::TensorShape; using tensorflow::TensorShapeProto; +using tensorflow::ToTensorId; using tensorflow::VersionDef; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; using tensorflow::strings::StrCat; -extern "C" { - -// -------------------------------------------------------------------------- -const char* TF_Version() { return TF_VERSION_STRING; } - -// -------------------------------------------------------------------------- - -// -------------------------------------------------------------------------- -TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } -void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } - -void TF_SetTarget(TF_SessionOptions* options, const char* target) { - options->options.target = target; -} - -void TF_SetConfig(TF_SessionOptions* options, const void* proto, - size_t proto_len, TF_Status* status) { - if (!options->options.config.ParseFromArray(proto, proto_len)) { - status->status = InvalidArgument("Unparseable ConfigProto"); - } -} -// -------------------------------------------------------------------------- -TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; } - -TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { - void* copy = tensorflow::port::Malloc(proto_len); - memcpy(copy, proto, proto_len); - - TF_Buffer* buf = new TF_Buffer; - buf->data = copy; - buf->length = proto_len; - buf->data_deallocator = [](void* data, size_t length) { - tensorflow::port::Free(data); - }; - return buf; -} - -void TF_DeleteBuffer(TF_Buffer* buffer) { - if (buffer == nullptr) return; - if (buffer->data_deallocator != nullptr) { - (*buffer->data_deallocator)(const_cast(buffer->data), - buffer->length); - } - delete buffer; -} - -TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } - -// -------------------------------------------------------------------------- - -TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, - TF_Status* status) { - Session* session; - status->status = NewSession(opt->options, &session); - if (status->status.ok()) { - return new TF_DeprecatedSession({session}); - } else { - DCHECK_EQ(nullptr, session); - return nullptr; - } -} - -void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { - status->status = 
s->session->Close(); -} - -void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { - status->status = Status::OK(); - if (s == nullptr) return; - delete s->session; - delete s; -} - -void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto, - size_t proto_len, TF_Status* status) { - GraphDef g; - if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { - status->status = InvalidArgument("Invalid GraphDef"); - return; - } - status->status = s->session->Extend(g); -} - -} // end extern "C" - -// Reset helper for converting character arrays to string vectors. -static void TF_Reset_Helper(const TF_SessionOptions* opt, - const char** containers, int ncontainers, - TF_Status* status) { - std::vector container_names(ncontainers); - for (int i = 0; i < ncontainers; ++i) { - container_names[i] = containers[i]; - } - - status->status = Reset(opt->options, container_names); -} - -extern "C" { - -void TF_Reset(const TF_SessionOptions* opt, const char** containers, - int ncontainers, TF_Status* status) { - TF_Reset_Helper(opt, containers, ncontainers, status); -} - -} // end extern "C" - -namespace tensorflow { - - -Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, - TF_Buffer* out) { - if (out->data != nullptr) { - return InvalidArgument("Passing non-empty TF_Buffer is invalid."); - } - const size_t proto_size = in.ByteSizeLong(); - void* buf = port::Malloc(proto_size); - if (buf == nullptr) { - return tensorflow::errors::ResourceExhausted( - "Failed to allocate memory to serialize message of type '", - in.GetTypeName(), "' and size ", proto_size); - } - if (!in.SerializeWithCachedSizesToArray(static_cast(buf))) { - port::Free(buf); - return InvalidArgument("Unable to serialize ", in.GetTypeName(), - " protocol buffer, perhaps the serialized size (", - proto_size, " bytes) is too large?"); - } - out->data = buf; - out->length = proto_size; - out->data_deallocator = [](void* data, size_t length) { port::Free(data); }; - return Status::OK(); -} - -void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type) { - // If any session has already run this node_id, mark this session as - // unrunnable. - for (auto it : graph->sessions) { - mutex_lock session_lock(it.first->mu); - if (it.first->last_num_graph_nodes > op.node.id()) { - it.second = strings::StrCat( - "Operation '", op.node.DebugString(), "' was changed by ", - mutation_type, - " after it was run by a session. This mutation will have no effect, " - "and will trigger an error in the future. Either don't modify " - "nodes after running them or create a new session."); - } - } -} - namespace { - -// Helper method that creates a shape handle for a shape described by dims. 
-tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( - tensorflow::shape_inference::InferenceContext* ic, int num_dims, - const int64_t* dims) { - if (num_dims != -1) { - std::vector dim_vec; - dim_vec.reserve(num_dims); - for (int i = 0; i < num_dims; ++i) { - dim_vec.push_back(ic->MakeDim(dims[i])); - } - return ic->MakeShape(dim_vec); - } else { - return ic->UnknownShape(); - } -} - -} // namespace - -void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, - int num_shapes_and_types, - const int64_t** shapes, - const int* ranks, - const TF_DataType* types, - TF_Status* status) { - Node* node = &output.oper->node; - - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - if (ic == nullptr) { - status->status = - InvalidArgument("Node ", node->name(), " was not found in the graph"); - return; - } - - auto shape_and_type_vec = - std::vector( - num_shapes_and_types); - for (int i = 0; i < num_shapes_and_types; ++i) { - tensorflow::shape_inference::ShapeHandle shape_handle = - ShapeHandleFromDims(ic, ranks[i], shapes[i]); - shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType( - shape_handle, static_cast(types[i])); - } - - ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec); -} - -// Helpers for loading a TensorFlow plugin (a .so file). -Status LoadLibrary(const char* library_filename, void** result, - const void** buf, size_t* len); - -// TODO(josh11b,mrry): Change Session to be able to use a Graph* -// directly, instead of requiring us to serialize to a GraphDef and -// call Session::Extend(). -bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { - if (session->graph != nullptr) { - // Take the graph lock before the session lock to avoid deadlock. This is - // safe since session->graph does not change. - session->graph->mu.lock(); - mutex_lock session_lock(session->mu); - const Graph& graph = session->graph->graph; - - const string& mutation_warning = session->graph->sessions[session]; - if (!mutation_warning.empty()) { - // TODO(b/74949947): turn this back into an error status - LOG(WARNING) << mutation_warning; - session->graph->sessions[session].clear(); - } - - const auto num_nodes = graph.num_node_ids(); - if (session->last_num_graph_nodes < num_nodes) { - // TODO(nolivia): check this on a subset of the graph instead of all of - // it. - status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; - } - - GraphDef graph_def; - *graph_def.mutable_versions() = graph.versions(); - // Fill graph_def with nodes with ids in the range - // [session->last_num_graph_nodes, num_nodes), that is the nodes - // added since the last TF_SessionRun() call. - for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { - Node* const node = graph.FindNodeId(id); - if (node != nullptr && node->IsOp()) { - NodeDef* const node_def = graph_def.add_node(); - *node_def = node->def(); - } - } - *graph_def.mutable_library() = graph.flib_def().ToProto(); - session->graph->mu.unlock(); - status->status = session->session->Extend(std::move(graph_def)); - if (!status->status.ok()) { - // Contract is we always delete input_values[i]. - return false; - } - // Note: session->session is not modified if Extend() fails, so - // we only set last_num_graph_nodes if it succeeds. 
- session->last_num_graph_nodes = num_nodes; - } else { - session->graph->mu.unlock(); - } - } - return true; -} - -} // namespace tensorflow - -static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, - TF_Status* status) { - status->status = Status::OK(); - for (int i = 0; i < noutputs; ++i) { - c_outputs[i] = nullptr; - } -} - -static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, - std::vector>* input_pairs, - TF_Status* status) { - const int ninputs = input_pairs->size(); - for (int i = 0; i < ninputs; ++i) { - status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); - if (!status->status.ok()) return false; - } - return true; -} - -// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to -// result in a zero-sized tensor. -static TF_Tensor* EmptyTensor(TF_DataType dtype, - const tensorflow::TensorShape& shape) { - static char empty; - tensorflow::int64 nelems = 1; - std::vector dims; - for (int i = 0; i < shape.dims(); ++i) { - dims.push_back(shape.dim_size(i)); - nelems *= shape.dim_size(i); - } - CHECK_EQ(nelems, 0); - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - return TF_NewTensor( - dtype, reinterpret_cast(dims.data()), shape.dims(), - reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); -} - -static void TF_Run_Helper( - Session* session, const char* handle, const TF_Buffer* run_options, - // Input tensors - const std::vector>& input_pairs, - // Output tensors - const std::vector& output_tensor_names, TF_Tensor** c_outputs, - // Target nodes - const std::vector& target_oper_names, TF_Buffer* run_metadata, - TF_Status* status) { - const int noutputs = output_tensor_names.size(); - std::vector outputs(noutputs); - Status result; - - if (handle == nullptr) { - RunOptions run_options_proto; - if (run_options != nullptr && !run_options_proto.ParseFromArray( - run_options->data, run_options->length)) { - status->status = InvalidArgument("Unparseable RunOptions proto"); - return; - } - if (run_metadata != nullptr && run_metadata->data != nullptr) { - status->status = - InvalidArgument("Passing non-empty run_metadata is invalid."); - return; - } - - RunMetadata run_metadata_proto; - result = session->Run(run_options_proto, input_pairs, output_tensor_names, - target_oper_names, &outputs, &run_metadata_proto); - - // Serialize back to upstream client, who now owns the new buffer - if (run_metadata != nullptr) { - status->status = MessageToBuffer(run_metadata_proto, run_metadata); - if (!status->status.ok()) return; - } - } else { - // NOTE(zongheng): PRun does not support RunOptions yet. 
- result = session->PRun(handle, input_pairs, output_tensor_names, &outputs); - } - if (!result.ok()) { - status->status = result; - return; - } - - // Store results in c_outputs[] - for (int i = 0; i < noutputs; ++i) { - const Tensor& src = outputs[i]; - if (!src.IsInitialized() || src.NumElements() == 0) { - c_outputs[i] = - EmptyTensor(static_cast(src.dtype()), src.shape()); - continue; - } - c_outputs[i] = TF_TensorFromTensor(src, &status->status); - if (!status->status.ok()) return; - } -} - -extern "C" { - -void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, - // Input tensors - const char** c_input_names, TF_Tensor** c_inputs, int ninputs, - // Output tensors - const char** c_output_names, TF_Tensor** c_outputs, int noutputs, - // Target nodes - const char** c_target_oper_names, int ntargets, - TF_Buffer* run_metadata, TF_Status* status) { - TF_Run_Setup(noutputs, c_outputs, status); - std::vector> input_pairs(ninputs); - if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; - for (int i = 0; i < ninputs; ++i) { - input_pairs[i].first = c_input_names[i]; - } - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = c_output_names[i]; - } - std::vector target_oper_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_oper_names[i] = c_target_oper_names[i]; - } - TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names, - c_outputs, target_oper_names, run_metadata, status); -} - -void TF_PRunSetup(TF_DeprecatedSession* s, - // Input names - const char** c_input_names, int ninputs, - // Output names - const char** c_output_names, int noutputs, - // Target nodes - const char** c_target_oper_names, int ntargets, - const char** handle, TF_Status* status) { - *handle = nullptr; - - std::vector input_names(ninputs); - std::vector output_names(noutputs); - std::vector target_oper_names(ntargets); - for (int i = 0; i < ninputs; ++i) { - input_names[i] = c_input_names[i]; - } - for (int i = 0; i < noutputs; ++i) { - output_names[i] = c_output_names[i]; - } - for (int i = 0; i < ntargets; ++i) { - target_oper_names[i] = c_target_oper_names[i]; - } - string new_handle; - status->status = s->session->PRunSetup(input_names, output_names, - target_oper_names, &new_handle); - if (status->status.ok()) { - char* buf = new char[new_handle.size() + 1]; - memcpy(buf, new_handle.c_str(), new_handle.size() + 1); - *handle = buf; - } -} - -void TF_PRun(TF_DeprecatedSession* s, const char* handle, - // Input tensors - const char** c_input_names, TF_Tensor** c_inputs, int ninputs, - // Output tensors - const char** c_output_names, TF_Tensor** c_outputs, int noutputs, - // Target nodes - const char** c_target_oper_names, int ntargets, - TF_Status* status) { - TF_Run_Setup(noutputs, c_outputs, status); - std::vector> input_pairs(ninputs); - if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; - for (int i = 0; i < ninputs; ++i) { - input_pairs[i].first = c_input_names[i]; - } - - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = c_output_names[i]; - } - std::vector target_oper_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_oper_names[i] = c_target_oper_names[i]; - } - TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names, - c_outputs, target_oper_names, nullptr, status); -} - -TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { - TF_Library* lib_handle = new TF_Library; - status->status = tensorflow::LoadLibrary( - 
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, - &lib_handle->op_list.length); - if (!status->status.ok()) { - delete lib_handle; - return nullptr; - } - return lib_handle; -} - -TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } - -void TF_DeleteLibraryHandle(TF_Library* lib_handle) { - if (lib_handle == nullptr) return; - tensorflow::port::Free(const_cast(lib_handle->op_list.data)); - delete lib_handle; -} - -TF_Buffer* TF_GetAllOpList() { - std::vector op_defs; - tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs); - tensorflow::OpList op_list; - for (const auto& op : op_defs) { - *(op_list.add_op()) = op; - } - TF_Buffer* ret = TF_NewBuffer(); - TF_CHECK_OK(MessageToBuffer(op_list, ret)); - return ret; -} - -// -------------------------------------------------------------------------- -// ListDevices & SessionListDevices API - -void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; } - -TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) { - TF_DeviceList* response = new TF_DeviceList; - status->status = session->session->ListDevices(&response->response); - return response; -} - -TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session, - TF_Status* status) { - TF_DeviceList* response = new TF_DeviceList; - status->status = session->session->ListDevices(&response->response); - return response; -} - -int TF_DeviceListCount(const TF_DeviceList* list) { - return list->response.size(); -} - -#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \ - return_type method_name(const TF_DeviceList* list, const int index, \ - TF_Status* status) { \ - if (list == nullptr) { \ - status->status = InvalidArgument("list is null!"); \ - return err_val; \ - } \ - if (index < 0 || index >= list->response.size()) { \ - status->status = InvalidArgument("index out of bounds"); \ - return err_val; \ - } \ - status->status = Status::OK(); \ - return list->response[index].accessor; \ - } - -TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr); -TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(), - nullptr); -TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1); -TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0); - -#undef TF_DEVICELIST_METHOD - -} // end extern "C" - -// -------------------------------------------------------------------------- -// New Graph and Session API - -// Helper functions ----------------------------------------------------------- - -namespace { - -TF_Operation* ToOperation(Node* node) { - return static_cast(static_cast(node)); -} - -string OutputName(const TF_Output& output) { - return StrCat(output.oper->node.name(), ":", output.index); -} - -const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, - const char* attr_name, - TF_Status* status) { - const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); - if (attr == nullptr) { - status->status = InvalidArgument("Operation '", oper->node.name(), - "' has no attr named '", attr_name, "'."); - } - return attr; -} - -TensorId ToTensorId(const TF_Output& output) { - return TensorId(output.oper->node.name(), output.index); -} - #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) std::vector OutputsFromTFOutputs(TF_Output* tf_outputs, int n) { @@ -681,1134 +126,8 @@ void TFOutputsFromOutputs(const std::vector& outputs, } // namespace -// Shape functions 
----------------------------------------------------------- - -void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, - const int64_t* dims, const int num_dims, - TF_Status* status) { - Node* node = &output.oper->node; - - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - if (ic == nullptr) { - status->status = - InvalidArgument("Node ", node->name(), " was not found in the graph"); - return; - } - tensorflow::shape_inference::ShapeHandle new_shape = - tensorflow::ShapeHandleFromDims(ic, num_dims, dims); - status->status = graph->refiner.SetShape(node, output.index, new_shape); -} - -int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, - TF_Status* status) { - Node* node = &output.oper->node; - - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - if (ic == nullptr) { - status->status = - InvalidArgument("Node ", node->name(), " was not found in the graph"); - return -1; - } - - tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); - - // Unknown rank means the number of dimensions is -1. - if (!ic->RankKnown(shape)) { - return -1; - } - - return ic->Rank(shape); -} - -void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims, - int num_dims, TF_Status* status) { - Node* node = &output.oper->node; - - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - if (ic == nullptr) { - status->status = - InvalidArgument("Node ", node->name(), " was not found in the graph"); - return; - } - - tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); - - int rank = -1; - if (ic->RankKnown(shape)) { - rank = ic->Rank(shape); - } - - if (num_dims != rank) { - status->status = InvalidArgument("Expected rank is ", num_dims, - " but actual rank is ", rank); - return; - } - - if (num_dims == 0) { - // Output shape is a scalar. - return; - } - - // Rank is greater than 0, so fill in the values, if known, and - // -1 for unknown values. 
- for (int i = 0; i < num_dims; ++i) { - auto dim = ic->Dim(shape, i); - tensorflow::int64 value = -1; - if (ic->ValueKnown(dim)) { - value = ic->Value(dim); - } - dims[i] = value; - } -} - -// TF_OperationDescription functions ------------------------------------------ - extern "C" { -static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, - const char* op_type, - const char* oper_name) - TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - return new TF_OperationDescription(graph, op_type, oper_name); -} - -TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type, - const char* oper_name) { - mutex_lock l(graph->mu); - return TF_NewOperationLocked(graph, op_type, oper_name); -} - -void TF_SetDevice(TF_OperationDescription* desc, const char* device) { - desc->node_builder.Device(device); -} - -void TF_AddInput(TF_OperationDescription* desc, TF_Output input) { - desc->node_builder.Input(&input.oper->node, input.index); -} - -void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs, - int num_inputs) { - std::vector input_list; - input_list.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - input_list.emplace_back(&inputs[i].oper->node, inputs[i].index); - } - desc->node_builder.Input(input_list); -} - -void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { - desc->node_builder.ControlInput(&input->node); -} - -void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { - desc->colocation_constraints.emplace( - StrCat(tensorflow::kColocationGroupPrefix, op->node.name())); -} - -void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, - const void* value, size_t length) { - tensorflow::StringPiece s(static_cast(value), length); - desc->node_builder.Attr(attr_name, s); -} - -void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, - const void* const* values, const size_t* lengths, - int num_values) { - if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { - desc->colocation_constraints.clear(); - for (int i = 0; i < num_values; ++i) { - desc->colocation_constraints.emplace(static_cast(values[i]), - lengths[i]); - } - } else { - std::vector v; - v.reserve(num_values); - for (int i = 0; i < num_values; ++i) { - v.emplace_back(static_cast(values[i]), lengths[i]); - } - desc->node_builder.Attr(attr_name, v); - } -} - -void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, - int64_t value) { - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - desc->node_builder.Attr(attr_name, static_cast(value)); -} - -void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name, - const int64_t* values, int num_values) { - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - desc->node_builder.Attr( - attr_name, - ArraySlice( - reinterpret_cast(values), num_values)); -} - -void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name, - float value) { - desc->node_builder.Attr(attr_name, value); -} - -void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name, - const float* values, int num_values) { - desc->node_builder.Attr(attr_name, - ArraySlice(values, num_values)); -} - -void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, - unsigned char value) { - desc->node_builder.Attr(attr_name, static_cast(value)); -} - -void TF_SetAttrBoolList(TF_OperationDescription* desc, const 
char* attr_name, - const unsigned char* values, int num_values) { - std::unique_ptr b(new bool[num_values]); - for (int i = 0; i < num_values; ++i) { - b[i] = values[i]; - } - desc->node_builder.Attr(attr_name, - ArraySlice(b.get(), num_values)); -} - -void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, - TF_DataType value) { - desc->node_builder.Attr(attr_name, static_cast(value)); -} - -void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, - const TF_DataType* values, int num_values) { - desc->node_builder.Attr( - attr_name, ArraySlice( - reinterpret_cast(values), num_values)); -} - -void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name, - const char* placeholder) { - tensorflow::AttrValue attr_value; - attr_value.set_placeholder(placeholder); - desc->node_builder.Attr(attr_name, attr_value); -} - -void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, - const char* value, size_t length) { - tensorflow::NameAttrList func_name; - func_name.set_name(string(value, value + length)); - desc->node_builder.Attr(attr_name, func_name); -} - -void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, - const int64_t* dims, int num_dims) { - PartialTensorShape shape; - if (num_dims >= 0) { - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - shape = PartialTensorShape(ArraySlice( - reinterpret_cast(dims), num_dims)); - } - desc->node_builder.Attr(attr_name, shape); -} - -void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name, - const int64_t* const* dims, const int* num_dims, - int num_shapes) { - std::vector shapes; - shapes.reserve(num_shapes); - for (int i = 0; i < num_shapes; ++i) { - if (num_dims[i] < 0) { - shapes.emplace_back(); - } else { - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - shapes.emplace_back(ArraySlice( - reinterpret_cast(dims[i]), num_dims[i])); - } - } - desc->node_builder.Attr(attr_name, shapes); -} - -void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, - const char* attr_name, const void* proto, - size_t proto_len, TF_Status* status) { - // shape.ParseFromArray takes an int as length, this function takes size_t, - // make sure there is no information loss. 
- if (proto_len > std::numeric_limits::max()) { - status->status = InvalidArgument( - "proto_len (", proto_len, - " bytes) is too large to be parsed by the protocol buffer library"); - return; - } - TensorShapeProto shape; - if (shape.ParseFromArray(proto, static_cast(proto_len))) { - desc->node_builder.Attr(attr_name, shape); - status->status = Status::OK(); - } else { - status->status = InvalidArgument("Unparseable TensorShapeProto"); - } -} - -void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, - const char* attr_name, - const void* const* protos, - const size_t* proto_lens, int num_shapes, - TF_Status* status) { - std::vector shapes; - shapes.resize(num_shapes); - for (int i = 0; i < num_shapes; ++i) { - if (proto_lens[i] > std::numeric_limits::max()) { - status->status = InvalidArgument( - "length of element ", i, " in the list (", proto_lens[i], - " bytes) is too large to be parsed by the protocol buffer library"); - return; - } - if (!shapes[i].ParseFromArray(protos[i], static_cast(proto_lens[i]))) { - status->status = - InvalidArgument("Unparseable TensorShapeProto at index ", i); - return; - } - } - desc->node_builder.Attr(attr_name, shapes); - status->status = Status::OK(); -} - -void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, - TF_Tensor* value, TF_Status* status) { - Tensor t; - status->status = TF_TensorToTensor(value, &t); - if (status->status.ok()) desc->node_builder.Attr(attr_name, t); -} - -void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, - TF_Tensor* const* values, int num_values, - TF_Status* status) { - status->status = Status::OK(); - std::vector t; - t.reserve(num_values); - - for (int i = 0; i < num_values && status->status.ok(); ++i) { - Tensor v; - status->status = TF_TensorToTensor(values[i], &v); - t.emplace_back(v); - } - - if (status->status.ok()) desc->node_builder.Attr(attr_name, t); -} - -void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, - const void* proto, size_t proto_len, - TF_Status* status) { - tensorflow::AttrValue attr_value; - if (!attr_value.ParseFromArray(proto, proto_len)) { - status->status = InvalidArgument("Unparseable AttrValue proto"); - return; - } - - if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { - if (attr_value.value_case() != tensorflow::AttrValue::kList && - attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) { - status->status = - InvalidArgument("Expected \"list\" field for \"", - tensorflow::kColocationAttrName, "\" attribute"); - return; - } - desc->colocation_constraints.clear(); - for (const string& location : attr_value.list().s()) { - desc->colocation_constraints.insert(location); - } - } else { - desc->node_builder.Attr(attr_name, std::move(attr_value)); - } - - status->status = Status::OK(); -} - -static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, - TF_Status* status) - TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { - Node* ret = nullptr; - - if (desc->graph->name_map.count(desc->node_builder.node_name())) { - status->status = InvalidArgument("Duplicate node name in graph: '", - desc->node_builder.node_name(), "'"); - } else { - if (!desc->colocation_constraints.empty()) { - desc->node_builder.Attr( - tensorflow::kColocationAttrName, - std::vector(desc->colocation_constraints.begin(), - desc->colocation_constraints.end())); - } - status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret, - /*consume=*/true); - - if (status->status.ok()) { - // Run shape 
inference function for newly added node. - status->status = desc->graph->refiner.AddNode(ret); - } - if (status->status.ok()) { - // Add the node to the name-to-node mapping. - desc->graph->name_map[ret->name()] = ret; - } else if (ret != nullptr) { - desc->graph->graph.RemoveNode(ret); - ret = nullptr; - } - } - - delete desc; - - return ToOperation(ret); -} - -TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, - TF_Status* status) { - mutex_lock l(desc->graph->mu); - return TF_FinishOperationLocked(desc, status); -} - -// TF_Operation functions -// ---------------------------------------------------------- - -const char* TF_OperationName(TF_Operation* oper) { - return oper->node.name().c_str(); -} - -const char* TF_OperationOpType(TF_Operation* oper) { - return oper->node.type_string().c_str(); -} - -const char* TF_OperationDevice(TF_Operation* oper) { - return oper->node.requested_device().c_str(); -} - -int TF_OperationNumOutputs(TF_Operation* oper) { - return oper->node.num_outputs(); -} - -TF_DataType TF_OperationOutputType(TF_Output oper_out) { - return static_cast( - oper_out.oper->node.output_type(oper_out.index)); -} - -int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, - TF_Status* status) { - NameRangeMap name_ranges; - status->status = - NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); - if (!status->status.ok()) return -1; - auto iter = name_ranges.find(arg_name); - if (iter == name_ranges.end()) { - status->status = InvalidArgument("Output arg '", arg_name, "' not found"); - return -1; - } - return iter->second.second - iter->second.first; -} - -int TF_OperationNumInputs(TF_Operation* oper) { - return oper->node.num_inputs(); -} - -TF_DataType TF_OperationInputType(TF_Input oper_in) { - return static_cast(oper_in.oper->node.input_type(oper_in.index)); -} - -int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, - TF_Status* status) { - NameRangeMap name_ranges; - status->status = - NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); - if (!status->status.ok()) return -1; - auto iter = name_ranges.find(arg_name); - if (iter == name_ranges.end()) { - status->status = InvalidArgument("Input arg '", arg_name, "' not found"); - return -1; - } - return iter->second.second - iter->second.first; -} - -TF_Output TF_OperationInput(TF_Input oper_in) { - const tensorflow::Edge* edge; - Status s = oper_in.oper->node.input_edge(oper_in.index, &edge); - if (!s.ok()) { - return {nullptr, -1}; - } - - return {ToOperation(edge->src()), edge->src_output()}; -} - -void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs, - int max_inputs) { - for (auto* edge : oper->node.in_edges()) { - if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) { - inputs[edge->dst_input()] = {ToOperation(edge->src()), - edge->src_output()}; - } - } -} - -int TF_OperationOutputNumConsumers(TF_Output oper_out) { - int count = 0; - for (const auto* edge : oper_out.oper->node.out_edges()) { - if (edge->src_output() == oper_out.index) { - ++count; - } - } - return count; -} - -int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, - int max_consumers) { - int count = 0; - for (const auto* edge : oper_out.oper->node.out_edges()) { - if (edge->src_output() == oper_out.index) { - if (count < max_consumers) { - consumers[count] = {ToOperation(edge->dst()), edge->dst_input()}; - } - ++count; - } - } - return count; -} - -int TF_OperationNumControlInputs(TF_Operation* oper) { - int count = 
0; - for (const auto* edge : oper->node.in_edges()) { - if (edge->IsControlEdge() && !edge->src()->IsSource()) { - ++count; - } - } - return count; -} - -int TF_OperationGetControlInputs(TF_Operation* oper, - TF_Operation** control_inputs, - int max_control_inputs) { - int count = 0; - for (const auto* edge : oper->node.in_edges()) { - if (edge->IsControlEdge() && !edge->src()->IsSource()) { - if (count < max_control_inputs) { - control_inputs[count] = ToOperation(edge->src()); - } - ++count; - } - } - return count; -} - -int TF_OperationNumControlOutputs(TF_Operation* oper) { - int count = 0; - for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge() && !edge->dst()->IsSink()) { - ++count; - } - } - return count; -} - -int TF_OperationGetControlOutputs(TF_Operation* oper, - TF_Operation** control_outputs, - int max_control_outputs) { - int count = 0; - for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge() && !edge->dst()->IsSink()) { - if (count < max_control_outputs) { - control_outputs[count] = ToOperation(edge->dst()); - } - ++count; - } - } - return count; -} - -TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, - const char* attr_name, - TF_Status* status) { - TF_AttrMetadata metadata; - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return metadata; - switch (attr->value_case()) { -#define SINGLE_CASE(kK, attr_type, size_expr) \ - case tensorflow::AttrValue::kK: \ - metadata.is_list = 0; \ - metadata.list_size = -1; \ - metadata.type = attr_type; \ - metadata.total_size = size_expr; \ - break; - - SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); - SINGLE_CASE(kI, TF_ATTR_INT, -1); - SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); - SINGLE_CASE(kB, TF_ATTR_BOOL, -1); - SINGLE_CASE(kType, TF_ATTR_TYPE, -1); - SINGLE_CASE(kShape, TF_ATTR_SHAPE, - attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); - SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); -#undef SINGLE_CASE - - case tensorflow::AttrValue::kList: - metadata.is_list = 1; - metadata.list_size = 0; - metadata.total_size = -1; -#define LIST_CASE(field, attr_type, ...) \ - if (attr->list().field##_size() > 0) { \ - metadata.type = attr_type; \ - metadata.list_size = attr->list().field##_size(); \ - __VA_ARGS__; \ - break; \ - } - - LIST_CASE( - s, TF_ATTR_STRING, metadata.total_size = 0; - for (int i = 0; i < attr->list().s_size(); - ++i) { metadata.total_size += attr->list().s(i).size(); }); - LIST_CASE(i, TF_ATTR_INT); - LIST_CASE(f, TF_ATTR_FLOAT); - LIST_CASE(b, TF_ATTR_BOOL); - LIST_CASE(type, TF_ATTR_TYPE); - LIST_CASE( - shape, TF_ATTR_SHAPE, metadata.total_size = 0; - for (int i = 0; i < attr->list().shape_size(); ++i) { - const auto& s = attr->list().shape(i); - metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); - }); - LIST_CASE(tensor, TF_ATTR_TENSOR); - LIST_CASE(tensor, TF_ATTR_FUNC); -#undef LIST_CASE - // All lists empty, determine the type from the OpDef. 
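For reference, a minimal usage sketch of the TF_OperationGetAttrMetadata / TF_OperationGetAttrString pairing implemented above (illustrative only, not part of this change; it assumes a valid TF_Operation* obtained elsewhere and the status helpers from tensorflow/c/tf_status.h):

#include <stdlib.h>
#include "tensorflow/c/c_api.h"

/* Reads a single (non-list) string attribute by first querying its byte
 * size via TF_OperationGetAttrMetadata. Returns a heap-allocated,
 * NUL-terminated copy that the caller frees, or NULL on failure. */
static char* ReadStringAttr(TF_Operation* op, const char* attr_name) {
  TF_Status* status = TF_NewStatus();
  TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, attr_name, status);
  char* value = NULL;
  if (TF_GetCode(status) == TF_OK && !m.is_list && m.type == TF_ATTR_STRING) {
    value = (char*)malloc((size_t)m.total_size + 1);
    TF_OperationGetAttrString(op, attr_name, value, (size_t)m.total_size, status);
    value[m.total_size] = '\0';
  }
  TF_DeleteStatus(status);
  return value;
}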
- if (metadata.list_size == 0) { - for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { - const auto& a = oper->node.op_def().attr(i); - if (a.name() != attr_name) continue; - const string& typestr = a.type(); - if (typestr == "list(string)") { - metadata.type = TF_ATTR_STRING; - } else if (typestr == "list(int)") { - metadata.type = TF_ATTR_INT; - } else if (typestr == "list(float)") { - metadata.type = TF_ATTR_FLOAT; - } else if (typestr == "list(bool)") { - metadata.type = TF_ATTR_BOOL; - } else if (typestr == "list(type)") { - metadata.type = TF_ATTR_TYPE; - } else if (typestr == "list(shape)") { - metadata.type = TF_ATTR_SHAPE; - } else if (typestr == "list(tensor)") { - metadata.type = TF_ATTR_TENSOR; - } else if (typestr == "list(func)") { - metadata.type = TF_ATTR_FUNC; - } else { - status->status = InvalidArgument( - "Attribute '", attr_name, - "' has an empty value of an unrecognized type '", typestr, "'"); - return metadata; - } - } - } - break; - - case tensorflow::AttrValue::kPlaceholder: - metadata.is_list = 0; - metadata.list_size = -1; - metadata.type = TF_ATTR_PLACEHOLDER; - metadata.total_size = -1; - break; - - case tensorflow::AttrValue::kFunc: - metadata.is_list = 0; - metadata.list_size = -1; - metadata.type = TF_ATTR_FUNC; - metadata.total_size = -1; - break; - - case tensorflow::AttrValue::VALUE_NOT_SET: - status->status = - InvalidArgument("Attribute '", attr_name, "' has no value set"); - break; - } - return metadata; -} - -void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, - void* value, size_t max_length, - TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - if (attr->value_case() != tensorflow::AttrValue::kS) { - status->status = - InvalidArgument("Attribute '", attr_name, "' is not a string"); - return; - } - if (max_length <= 0) { - return; - } - const auto& s = attr->s(); - std::memcpy(value, s.data(), std::min(s.length(), max_length)); -} - -void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, - void** values, size_t* lengths, - int max_values, void* storage, - size_t storage_size, TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - if (attr->value_case() != tensorflow::AttrValue::kList) { - status->status = - InvalidArgument("Value for '", attr_name, "' is not a list"); - return; - } - const auto len = std::min(max_values, attr->list().s_size()); - char* p = static_cast(storage); - for (int i = 0; i < len; ++i) { - const string& s = attr->list().s(i); - values[i] = p; - lengths[i] = s.size(); - if ((p + s.size()) > (static_cast(storage) + storage_size)) { - status->status = InvalidArgument( - "Not enough storage to hold the requested list of strings"); - return; - } - memcpy(values[i], s.data(), s.size()); - p += s.size(); - } -} - -#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ - void func(TF_Operation* oper, const char* attr_name, c_type* value, \ - TF_Status* status) { \ - cpp_type v; \ - status->status = \ - tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \ - *value = static_cast(v); \ - } \ - void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ - int max_values, TF_Status* status) { \ - const auto* attr = GetAttrValue(oper, attr_name, status); \ - if (!status->status.ok()) return; \ - if (attr->value_case() != tensorflow::AttrValue::kList) { \ - status->status = \ - InvalidArgument("Value for '", attr_name, "' is not 
a list."); \ - return; \ - } \ - const auto len = std::min(max_values, attr->list().list_field##_size()); \ - for (int i = 0; i < len; ++i) { \ - values[i] = static_cast(attr->list().list_field(i)); \ - } \ - } -DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); -DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); -DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b); -DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); -#undef DEFINE_GETATTR - -void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, - int64_t* value, int num_dims, TF_Status* status) { - PartialTensorShape shape; - status->status = - tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); - if (!status->status.ok()) return; - auto len = std::min(shape.dims(), num_dims); - for (int i = 0; i < len; ++i) { - value[i] = shape.dim_size(i); - } -} - -void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, - int64_t** dims, int* num_dims, int num_shapes, - int64_t* storage, int storage_size, - TF_Status* status) { - std::vector shapes; - status->status = - tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); - if (!status->status.ok()) return; - auto len = std::min(static_cast(shapes.size()), num_shapes); - int64_t* p = storage; - int storage_left = storage_size; - for (int i = 0; i < len; ++i) { - // shapes[i].dims() == -1 for shapes with an unknown rank. - int64_t n = shapes[i].dims(); - num_dims[i] = n; - dims[i] = p; - if (n < 0) { - continue; - } - if (storage_left < n) { - status->status = InvalidArgument( - "Not enough storage to hold the requested list of shapes"); - return; - } - storage_left -= n; - for (int j = 0; j < n; ++j, ++p) { - *p = shapes[i].dim_size(j); - } - } -} - -void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, - const char* attr_name, - TF_Buffer* value, TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - if (attr->value_case() != tensorflow::AttrValue::kShape) { - status->status = - InvalidArgument("Value for '", attr_name, "' is not a shape."); - return; - } - status->status = MessageToBuffer(attr->shape(), value); -} - -void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, - const char* attr_name, - TF_Buffer** values, int max_values, - TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - if (attr->value_case() != tensorflow::AttrValue::kList) { - status->status = - InvalidArgument("Value for '", attr_name, "' is not a list"); - return; - } - const auto len = std::min(max_values, attr->list().shape_size()); - for (int i = 0; i < len; ++i) { - values[i] = TF_NewBuffer(); - status->status = MessageToBuffer(attr->list().shape(i), values[i]); - if (!status->status.ok()) { - // Delete everything allocated to far, the operation has failed. 
- for (int j = 0; j <= i; ++j) { - TF_DeleteBuffer(values[j]); - } - return; - } - } -} - -void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, - TF_Tensor** value, TF_Status* status) { - *value = nullptr; - Tensor t; - status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); - if (!status->status.ok()) return; - *value = TF_TensorFromTensor(t, &status->status); -} - -void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, - TF_Tensor** values, int max_values, - TF_Status* status) { - std::vector ts; - status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); - if (!status->status.ok()) return; - const auto len = std::min(max_values, static_cast(ts.size())); - for (int i = 0; i < len; ++i) { - values[i] = TF_TensorFromTensor(ts[i], &status->status); - } -} - -void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, - TF_Buffer* output_attr_value, - TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - status->status = MessageToBuffer(*attr, output_attr_value); -} - -void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, - TF_Status* status) { - status->status = MessageToBuffer(oper->node.def(), output_node_def); -} - -// TF_Graph functions --------------------------------------------------------- - -TF_Graph::TF_Graph() - : graph(tensorflow::OpRegistry::Global()), - refiner(graph.versions().producer(), graph.op_registry()), - delete_requested(false), - parent(nullptr), - parent_inputs(nullptr) { - // Tell the shape refiner to also run shape inference on functions. - refiner.set_function_library_for_shape_inference(&graph.flib_def()); -} - -TF_Graph* TF_NewGraph() { return new TF_Graph; } - -void TF_DeleteGraph(TF_Graph* g) { - if (g == nullptr) return; - g->mu.lock(); - g->delete_requested = true; - const bool del = g->sessions.empty(); - g->mu.unlock(); - if (del) delete g; -} - -TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) { - mutex_lock l(graph->mu); - auto iter = graph->name_map.find(oper_name); - if (iter == graph->name_map.end()) { - return nullptr; - } else { - return ToOperation(iter->second); - } -} - -TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) { - if (*pos == 0) { - // Advance past the first sentinel nodes in every graph (the source & sink). - *pos += 2; - } else { - // Advance to the next node. - *pos += 1; - } - - mutex_lock l(graph->mu); - while (*pos < static_cast(graph->graph.num_node_ids())) { - Node* node = graph->graph.FindNodeId(*pos); - // FindNodeId() returns nullptr for nodes that have been deleted. - // We aren't currently allowing nodes to be deleted, but it is safer - // to still check. - if (node != nullptr) return ToOperation(node); - *pos += 1; - } - - // No more nodes. 
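For reference, the iteration idiom TF_GraphNextOperation is written for, as a short sketch (illustrative only, not part of this change):

#include <stdio.h>
#include "tensorflow/c/c_api.h"

/* Prints the name and op type of every operation currently in `graph`. */
static void DumpGraphOperations(TF_Graph* graph) {
  size_t pos = 0;
  TF_Operation* oper;
  while ((oper = TF_GraphNextOperation(graph, &pos)) != NULL) {
    printf("%s : %s\n", TF_OperationName(oper), TF_OperationOpType(oper));
  }
}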
- return nullptr; -} - -void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, - TF_Status* status) { - GraphDef def; - { - mutex_lock l(graph->mu); - graph->graph.ToGraphDef(&def); - } - status->status = MessageToBuffer(def, output_graph_def); -} - -void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, - TF_Buffer* output_op_def, TF_Status* status) { - const OpDef* op_def; - { - mutex_lock l(graph->mu); - status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); - if (!status->status.ok()) return; - } - status->status = MessageToBuffer(*op_def, output_op_def); -} - -void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def, - TF_Status* status) { - VersionDef versions; - { - mutex_lock l(graph->mu); - versions = graph->graph.versions(); - } - status->status = MessageToBuffer(versions, output_version_def); -} - -TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { - return new TF_ImportGraphDefOptions; -} -void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) { - delete opts; -} -void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, - const char* prefix) { - opts->opts.prefix = prefix; -} -void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts, - const char* device) { - opts->opts.default_device = device; -} - -void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, - unsigned char uniquify_names) { - opts->opts.uniquify_names = uniquify_names; -} - -void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, - unsigned char uniquify_prefix) { - opts->opts.uniquify_prefix = uniquify_prefix; -} - -void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, - const char* src_name, - int src_index, TF_Output dst) { - opts->tensor_id_data.push_back(src_name); - const string& src_name_str = opts->tensor_id_data.back(); - // We don't need to store dst's name in tensor_id_data, since `dst` must - // outlive the ImportGraphDef call. 
- opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst); -} - -void TF_ImportGraphDefOptionsRemapControlDependency( - TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) { - opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] = - TensorId(dst->node.name(), tensorflow::Graph::kControlSlot); -} - -extern void TF_ImportGraphDefOptionsAddControlDependency( - TF_ImportGraphDefOptions* opts, TF_Operation* oper) { - opts->opts.control_dependencies.push_back(oper->node.name()); -} - -void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts, - const char* oper_name, int index) { - opts->tensor_id_data.push_back(oper_name); - const string& oper_name_str = opts->tensor_id_data.back(); - opts->opts.return_tensors.emplace_back(oper_name_str, index); -} - -int TF_ImportGraphDefOptionsNumReturnOutputs( - const TF_ImportGraphDefOptions* opts) { - return opts->opts.return_tensors.size(); -} - -void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts, - const char* oper_name) { - opts->opts.return_nodes.push_back(oper_name); -} - -int TF_ImportGraphDefOptionsNumReturnOperations( - const TF_ImportGraphDefOptions* opts) { - return opts->opts.return_nodes.size(); -} - -void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results, - int* num_outputs, - TF_Output** outputs) { - *num_outputs = results->return_tensors.size(); - *outputs = results->return_tensors.data(); -} - -void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, - int* num_opers, - TF_Operation*** opers) { - *num_opers = results->return_nodes.size(); - *opers = results->return_nodes.data(); -} - -void TF_ImportGraphDefResultsMissingUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, - const char*** src_names, int** src_indexes) { - *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); - *src_names = results->missing_unused_key_names.data(); - *src_indexes = results->missing_unused_key_indexes.data(); -} - -void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { - delete results; -} - -static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, - const TF_ImportGraphDefOptions* opts, - TF_ImportGraphDefResults* tf_results, - TF_Status* status) - TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - const int last_node_id = graph->graph.num_node_ids(); - tensorflow::ImportGraphDefResults results; - status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, - &graph->refiner, &results); - if (!status->status.ok()) return; - - // Add new nodes to name_map - for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { - auto* node = graph->graph.FindNodeId(i); - if (node != nullptr) graph->name_map[node->name()] = node; - } - - // Populate return_tensors - DCHECK(tf_results->return_tensors.empty()); - tf_results->return_tensors.resize(results.return_tensors.size()); - for (int i = 0; i < results.return_tensors.size(); ++i) { - tf_results->return_tensors[i].oper = - ToOperation(results.return_tensors[i].first); - tf_results->return_tensors[i].index = results.return_tensors[i].second; - } - - // Populate return_nodes - DCHECK(tf_results->return_nodes.empty()); - tf_results->return_nodes.resize(results.return_nodes.size()); - for (int i = 0; i < results.return_nodes.size(); ++i) { - tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); - } - - // Populate missing unused map keys - 
DCHECK(tf_results->missing_unused_key_names.empty()); - DCHECK(tf_results->missing_unused_key_indexes.empty()); - DCHECK(tf_results->missing_unused_key_names_data.empty()); - - size_t size = results.missing_unused_input_map_keys.size(); - tf_results->missing_unused_key_names.resize(size); - tf_results->missing_unused_key_indexes.resize(size); - - for (int i = 0; i < size; ++i) { - TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.emplace_back(id.first); - tf_results->missing_unused_key_names[i] = - tf_results->missing_unused_key_names_data.back().c_str(); - tf_results->missing_unused_key_indexes[i] = id.second; - } -} - -TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, TF_Status* status) { - GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, - graph_def->length)) { - status->status = InvalidArgument("Invalid GraphDef"); - return nullptr; - } - auto results = new TF_ImportGraphDefResults(); - mutex_lock l(graph->mu); - GraphImportGraphDefLocked(graph, def, options, results, status); - if (!status->status.ok()) { - delete results; - return nullptr; - } - return results; -} - -void TF_GraphImportGraphDefWithReturnOutputs( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, - int num_return_outputs, TF_Status* status) { - if (num_return_outputs != options->opts.return_tensors.size()) { - status->status = InvalidArgument("Expected 'num_return_outputs' to be ", - options->opts.return_tensors.size(), - ", got ", num_return_outputs); - return; - } - if (num_return_outputs > 0 && return_outputs == nullptr) { - status->status = InvalidArgument( - "'return_outputs' must be preallocated to length ", num_return_outputs); - return; - } - GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, - graph_def->length)) { - status->status = InvalidArgument("Invalid GraphDef"); - return; - } - TF_ImportGraphDefResults results; - mutex_lock l(graph->mu); - GraphImportGraphDefLocked(graph, def, options, &results, status); - DCHECK_EQ(results.return_tensors.size(), num_return_outputs); - memcpy(return_outputs, results.return_tensors.data(), - num_return_outputs * sizeof(TF_Output)); -} - -void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, - TF_Status* status) { - TF_ImportGraphDefResults* results = - TF_GraphImportGraphDefWithResults(graph, graph_def, options, status); - TF_DeleteImportGraphDefResults(results); -} - // While loop functions ------------------------------------------------------- namespace { @@ -2161,404 +480,4 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } -// TF_Session functions ---------------------------------------------- - -TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g) - : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {} - -TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, - TF_Status* status) { - Session* session; - status->status = NewSession(opt->options, &session); - if (status->status.ok()) { - TF_Session* new_session = new TF_Session(session, graph); - if (graph != nullptr) { - mutex_lock l(graph->mu); - graph->sessions[new_session] = ""; - } - return new_session; - } else { - DCHECK_EQ(nullptr, session); 
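For reference, the simplest of the import entry points above, TF_GraphImportGraphDef, in a short usage sketch (illustrative only, not part of this change; `data`/`len` are assumed to hold a serialized GraphDef):

#include "tensorflow/c/c_api.h"

/* Imports a serialized GraphDef into `graph`, prefixing imported node
 * names with "imported". Returns 1 on success, 0 on failure. */
static int ImportSerializedGraphDef(TF_Graph* graph, const void* data,
                                    size_t len, TF_Status* status) {
  TF_Buffer* graph_def = TF_NewBufferFromString(data, len);
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
  TF_GraphImportGraphDef(graph, graph_def, opts, status);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);
  return TF_GetCode(status) == TF_OK;
}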
- return nullptr; - } -} - -TF_Session* TF_LoadSessionFromSavedModel( - const TF_SessionOptions* session_options, const TF_Buffer* run_options, - const char* export_dir, const char* const* tags, int tags_len, - TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) { -// TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring -// that the tensorflow/cc/saved_model:loader build target is mobile friendly. -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Loading a SavedModel is not supported on mobile. File a bug at " - "https://github.com/tensorflow/tensorflow/issues if this feature is " - "important to you"); - return nullptr; -#else - mutex_lock l(graph->mu); - if (!graph->name_map.empty()) { - status->status = InvalidArgument("Graph is non-empty."); - return nullptr; - } - - RunOptions run_options_proto; - if (run_options != nullptr && !run_options_proto.ParseFromArray( - run_options->data, run_options->length)) { - status->status = InvalidArgument("Unparseable RunOptions proto"); - return nullptr; - } - - std::unordered_set tag_set; - for (int i = 0; i < tags_len; i++) { - tag_set.insert(string(tags[i])); - } - - tensorflow::SavedModelBundle bundle; - status->status = - tensorflow::LoadSavedModel(session_options->options, run_options_proto, - export_dir, tag_set, &bundle); - if (!status->status.ok()) return nullptr; - - // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session - // extends using GraphDefs. The Graph instance is different, but equivalent - // to the one used to create the session. - // - // TODO(jhseu): When Session is modified to take Graphs instead of - // GraphDefs, return the Graph generated in LoadSavedModel(). - TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions(); - TF_ImportGraphDefResults results; - GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), - import_opts, &results, status); - TF_DeleteImportGraphDefOptions(import_opts); - if (!status->status.ok()) return nullptr; - - if (meta_graph_def != nullptr) { - status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def); - if (!status->status.ok()) return nullptr; - } - - TF_Session* session = new TF_Session(bundle.session.release(), graph); - - graph->sessions[session] = ""; - session->last_num_graph_nodes = graph->graph.num_node_ids(); - return session; -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -void TF_CloseSession(TF_Session* s, TF_Status* status) { - status->status = s->session->Close(); -} - -void TF_DeleteSession(TF_Session* s, TF_Status* status) { - status->status = Status::OK(); - if (s == nullptr) return; - TF_Graph* const graph = s->graph; - if (graph != nullptr) { - graph->mu.lock(); - graph->sessions.erase(s); - const bool del = graph->delete_requested && graph->sessions.empty(); - graph->mu.unlock(); - if (del) delete graph; - } - delete s->session; - delete s; -} - -void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, - const TF_Output* inputs, TF_Tensor* const* input_values, - int ninputs, const TF_Output* outputs, - TF_Tensor** output_values, int noutputs, - const TF_Operation* const* target_opers, int ntargets, - TF_Buffer* run_metadata, TF_Status* status) { - // TODO(josh11b,mrry): Change Session to be able to use a Graph* - // directly, instead of requiring us to serialize to a GraphDef and - // call Session::Extend(). 
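For reference, the session lifecycle around TF_SessionRun as a short sketch (illustrative only, not part of this change; the fetched node name "out" is a placeholder, and the caller releases the returned tensor with TF_DeleteTensor):

#include <stddef.h>
#include "tensorflow/c/c_api.h"

/* Creates a session on `graph`, fetches output 0 of the node "out" with
 * no feeds or targets, then closes and deletes the session. */
static TF_Tensor* RunSingleFetch(TF_Graph* graph, TF_Status* status) {
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* session = TF_NewSession(graph, opts, status);
  TF_DeleteSessionOptions(opts);
  if (TF_GetCode(status) != TF_OK) return NULL;

  TF_Output fetch = {TF_GraphOperationByName(graph, "out"), 0};
  TF_Tensor* result = NULL;
  TF_SessionRun(session, /*run_options=*/NULL,
                /*inputs=*/NULL, /*input_values=*/NULL, 0,
                &fetch, &result, 1,
                /*target_opers=*/NULL, 0,
                /*run_metadata=*/NULL, status);

  TF_CloseSession(session, status);
  TF_DeleteSession(session, status);
  return result;
}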
- if (session->extend_before_run && - !ExtendSessionGraphHelper(session, status)) { - return; - } - - TF_Run_Setup(noutputs, output_values, status); - - // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector> input_pairs(ninputs); - if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; - for (int i = 0; i < ninputs; ++i) { - input_pairs[i].first = OutputName(inputs[i]); - } - - // Convert from TF_Output to string names. - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = OutputName(outputs[i]); - } - - // Convert from TF_Operation* to string names. - std::vector target_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_names[i] = target_opers[i]->node.name(); - } - - // Actually run. - TF_Run_Helper(session->session, nullptr, run_options, input_pairs, - output_names, output_values, target_names, run_metadata, - status); -} - -void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, - int ninputs, const TF_Output* outputs, int noutputs, - const TF_Operation* const* target_opers, int ntargets, - const char** handle, TF_Status* status) { - *handle = nullptr; - - if (session->extend_before_run && - !ExtendSessionGraphHelper(session, status)) { - return; - } - - std::vector input_names(ninputs); - for (int i = 0; i < ninputs; ++i) { - input_names[i] = OutputName(inputs[i]); - } - - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = OutputName(outputs[i]); - } - - std::vector target_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_names[i] = target_opers[i]->node.name(); - } - - string new_handle; - status->status = session->session->PRunSetup(input_names, output_names, - target_names, &new_handle); - if (status->status.ok()) { - char* buf = new char[new_handle.size() + 1]; - memcpy(buf, new_handle.c_str(), new_handle.size() + 1); - *handle = buf; - } -} - -void TF_DeletePRunHandle(const char* handle) { - delete[] handle; - // TODO(suharshs): Free up any resources held by the partial run state. -} - -void TF_SessionPRun(TF_Session* session, const char* handle, - const TF_Output* inputs, TF_Tensor* const* input_values, - int ninputs, const TF_Output* outputs, - TF_Tensor** output_values, int noutputs, - const TF_Operation* const* target_opers, int ntargets, - TF_Status* status) { - // TODO(josh11b,mrry): Change Session to be able to use a Graph* - // directly, instead of requiring us to serialize to a GraphDef and - // call Session::Extend(). - if (session->extend_before_run && - !ExtendSessionGraphHelper(session, status)) { - return; - } - - TF_Run_Setup(noutputs, output_values, status); - - // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector> input_pairs(ninputs); - if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; - for (int i = 0; i < ninputs; ++i) { - input_pairs[i].first = OutputName(inputs[i]); - } - - // Convert from TF_Output to string names. - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = OutputName(outputs[i]); - } - - // Convert from TF_Operation* to string names. 
- std::vector target_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_names[i] = target_opers[i]->node.name(); - } - - TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names, - output_values, target_names, nullptr, status); -} - -unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, - TF_Tensor** result, TF_Status* status) { - *result = nullptr; - mutex_lock l(graph->mu); - OutputTensor tensor(&output.oper->node, output.index); - bool evaluated; - Tensor result_tensor; - status->status = EvaluateConstantTensor( - tensor, graph->refiner, *graph->graph.op_registry(), - graph->graph.versions().producer(), &evaluated, &result_tensor); - if (evaluated) { - DCHECK(status->status.ok()); - *result = TF_TensorFromTensor(result_tensor, &status->status); - if (!status->status.ok()) evaluated = false; - } - return evaluated; -} - -TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { - tensorflow::OpList op_list; - if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { - status->status = InvalidArgument("Unparseable OpList"); - return nullptr; - } - status->status = Status::OK(); - return new TF_ApiDefMap(op_list); -} - -void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } - -void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, - size_t text_len, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "ApiDefMap is not supported on mobile."); -#else - mutex_lock l(api_def_map->lock); - if (api_def_map->update_docs_called) { - status->status = FailedPrecondition( - "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " - "called."); - return; - } - string api_def_text(text, text_len); - status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, - size_t name_len, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "ApiDefMap is not supported on mobile."); - return nullptr; -#else - mutex_lock l(api_def_map->lock); - if (!api_def_map->update_docs_called) { - api_def_map->api_def_map.UpdateDocs(); - api_def_map->update_docs_called = true; - } - string name_str(name, name_len); - const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); - if (api_def == nullptr) { - return nullptr; - } - - TF_Buffer* ret = TF_NewBuffer(); - status->status = MessageToBuffer(*api_def, ret); - if (!status->status.ok()) { - TF_DeleteBuffer(ret); - return nullptr; - } - return ret; -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { - tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); - TF_Buffer* ret = TF_NewBuffer(); - status->status = MessageToBuffer(kernel_list, ret); - if (!status->status.ok()) { - TF_DeleteBuffer(ret); - return nullptr; - } - return ret; -} - -TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { - tensorflow::KernelList kernel_list = - tensorflow::GetRegisteredKernelsForOp(name); - TF_Buffer* ret = TF_NewBuffer(); - status->status = MessageToBuffer(kernel_list, ret); - if (!status->status.ok()) { - TF_DeleteBuffer(ret); - return nullptr; - } - return ret; -} - -// TF_Server functions 
---------------------------------------------- - -#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -TF_Server::TF_Server(std::unique_ptr server) - : target(server->target()), server(std::move(server)) {} -#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) - -TF_Server* TF_NewServer(const void* proto, size_t proto_len, - TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported on mobile"); - return nullptr; -#else - tensorflow::ServerDef server_def; - if (!server_def.ParseFromArray(proto, static_cast(proto_len))) { - status->status = InvalidArgument( - "Could not parse provided bytes into a ServerDef protocol buffer"); - return nullptr; - } - - std::unique_ptr out_server; - status->status = tensorflow::NewServer(server_def, &out_server); - if (!status->status.ok()) return nullptr; - - return new TF_Server(std::move(out_server)); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -void TF_ServerStart(TF_Server* server, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported on mobile"); -#else - status->status = server->server->Start(); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -void TF_ServerStop(TF_Server* server, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported on mobile"); -#else - status->status = server->server->Stop(); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -void TF_ServerJoin(TF_Server* server, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported on mobile"); -#else - status->status = server->server->Join(); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -const char* TF_ServerTarget(TF_Server* server) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - return nullptr; -#else - return server->target.c_str(); -#endif -} - -void TF_DeleteServer(TF_Server* server) { -#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) - delete server; -#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -} - -void TF_RegisterLogListener(void (*listener)(const char*)) { -#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) - tensorflow::logging::RegisterListener(listener); -#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -} - } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 0c413f6ebae..f9942239eec 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -19,884 +19,21 @@ limitations under the License. #include #include +#include "tensorflow/c/c_core_api.h" #include "tensorflow/c/tf_attrtype.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" // -------------------------------------------------------------------------- -// C API for TensorFlow. +// Non-core C API for TensorFlow. // -// The API leans towards simplicity and uniformity instead of convenience -// since most usage will be by language specific wrappers. -// -// Conventions: -// * We use the prefix TF_ for everything in the API. 
-// * Objects are always passed around as pointers to opaque structs -// and these structs are allocated/deallocated via the API. -// * TF_Status holds error information. It is an object type -// and therefore is passed around as a pointer to an opaque -// struct as mentioned above. -// * Every call that has a TF_Status* argument clears it on success -// and fills it with error info on failure. -// * unsigned char is used for booleans (instead of the 'bool' type). -// In C++ bool is a keyword while in C99 bool is a macro defined -// in stdbool.h. It is possible for the two to be inconsistent. -// For example, neither the C99 nor the C++11 standard force a byte -// size on the bool type, so the macro defined in stdbool.h could -// be inconsistent with the bool keyword in C++. Thus, the use -// of stdbool.h is avoided and unsigned char is used instead. -// * size_t is used to represent byte sizes of objects that are -// materialized in the address space of the calling process. -// * int is used as an index into arrays. -// * Deletion functions are safe to call on nullptr. -// -// Questions left to address: -// * Might at some point need a way for callers to provide their own Env. -// * Maybe add TF_TensorShape that encapsulates dimension info. -// -// Design decisions made: -// * Backing store for tensor memory has an associated deallocation -// function. This deallocation function will point to client code -// for tensors populated by the client. So the client can do things -// like shadowing a numpy array. -// * We do not provide TF_OK since it is not strictly necessary and we -// are not optimizing for convenience. -// * We make assumption that one session has one graph. This should be -// fine since we have the ability to run sub-graphs. -// * We could allow NULL for some arguments (e.g., NULL options arg). -// However since convenience is not a primary goal, we don't do this. -// * Devices are not in this API. Instead, they are created/used internally -// and the API just provides high level controls over the number of -// devices of each type. - -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - +// This file contains the non-core C API for TensorFlow. Most of the +// API documentation and functionality resides in c_core_api.h. #ifdef __cplusplus extern "C" { #endif -// -------------------------------------------------------------------------- -// TF_Version returns a string describing version information of the -// TensorFlow library. TensorFlow using semantic versioning. -TF_CAPI_EXPORT extern const char* TF_Version(void); - -// -------------------------------------------------------------------------- -// TF_Buffer holds a pointer to a block of data and its associated length. -// Typically, the data consists of a serialized protocol buffer, but other data -// may also be held in a buffer. -// -// By default, TF_Buffer itself does not do any memory management of the -// pointed-to block. 
If need be, users of this struct should specify how to -// deallocate the block by setting the `data_deallocator` function pointer. -typedef struct TF_Buffer { - const void* data; - size_t length; - void (*data_deallocator)(void* data, size_t length); -} TF_Buffer; - -// Makes a copy of the input and sets an appropriate deallocator. Useful for -// passing in read-only, input protobufs. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, - size_t proto_len); - -// Useful for passing *out* a protobuf. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); - -TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); - -TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); - -// -------------------------------------------------------------------------- -// TF_SessionOptions holds options that can be passed during session creation. -typedef struct TF_SessionOptions TF_SessionOptions; - -// Return a new options object. -TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); - -// Set the target in TF_SessionOptions.options. -// target can be empty, a single entry, or a comma separated list of entries. -// Each entry is in one of the following formats : -// "local" -// ip:port -// host:port -TF_CAPI_EXPORT extern void TF_SetTarget(TF_SessionOptions* options, - const char* target); - -// Set the config in TF_SessionOptions.options. -// config should be a serialized tensorflow.ConfigProto proto. -// If config was not parsed successfully as a ConfigProto, record the -// error information in *status. -TF_CAPI_EXPORT extern void TF_SetConfig(TF_SessionOptions* options, - const void* proto, size_t proto_len, - TF_Status* status); - -// Destroy an options object. -TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); - -// TODO(jeff,sanjay): -// - export functions to set Config fields - -// -------------------------------------------------------------------------- -// The new graph construction API, still under development. - -// Represents a computation graph. Graphs may be shared between sessions. -// Graphs are thread-safe when used as directed below. -typedef struct TF_Graph TF_Graph; - -// Return a new graph object. -TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); - -// Destroy an options object. Graph will be deleted once no more -// TFSession's are referencing it. -TF_CAPI_EXPORT extern void TF_DeleteGraph(TF_Graph*); - -// Operation being built. The underlying graph must outlive this. -typedef struct TF_OperationDescription TF_OperationDescription; - -// Operation that has been added to the graph. Valid until the graph is -// deleted -- in particular adding a new operation to the graph does not -// invalidate old TF_Operation* pointers. -typedef struct TF_Operation TF_Operation; - -// Represents a specific input of an operation. -typedef struct TF_Input { - TF_Operation* oper; - int index; // The index of the input within oper. -} TF_Input; - -// Represents a specific output of an operation. -typedef struct TF_Output { - TF_Operation* oper; - int index; // The index of the output within oper. -} TF_Output; - -// TF_Function is a grouping of operations with defined inputs and outputs. -// Once created and added to graphs, functions can be invoked by creating an -// operation whose operation type matches the function name. -typedef struct TF_Function TF_Function; - -// Function definition options. 
TODO(iga): Define and implement -typedef struct TF_FunctionOptions TF_FunctionOptions; - -// Sets the shape of the Tensor referenced by `output` in `graph` to -// the shape described by `dims` and `num_dims`. -// -// If the number of dimensions is unknown, `num_dims` must be set to -// -1 and `dims` can be null. If a dimension is unknown, the -// corresponding entry in the `dims` array must be -1. -// -// This does not overwrite the existing shape associated with `output`, -// but merges the input shape with the existing shape. For example, -// setting a shape of [-1, 2] with an existing shape [2, -1] would set -// a final shape of [2, 2] based on shape merging semantics. -// -// Returns an error into `status` if: -// * `output` is not in `graph`. -// * An invalid shape is being set (e.g., the shape being set -// is incompatible with the existing shape). -TF_CAPI_EXPORT extern void TF_GraphSetTensorShape(TF_Graph* graph, - TF_Output output, - const int64_t* dims, - const int num_dims, - TF_Status* status); - -// Returns the number of dimensions of the Tensor referenced by `output` -// in `graph`. -// -// If the number of dimensions in the shape is unknown, returns -1. -// -// Returns an error into `status` if: -// * `output` is not in `graph`. -TF_CAPI_EXPORT extern int TF_GraphGetTensorNumDims(TF_Graph* graph, - TF_Output output, - TF_Status* status); - -// Returns the shape of the Tensor referenced by `output` in `graph` -// into `dims`. `dims` must be an array large enough to hold `num_dims` -// entries (e.g., the return value of TF_GraphGetTensorNumDims). -// -// If the number of dimensions in the shape is unknown or the shape is -// a scalar, `dims` will remain untouched. Otherwise, each element of -// `dims` will be set corresponding to the size of the dimension. An -// unknown dimension is represented by `-1`. -// -// Returns an error into `status` if: -// * `output` is not in `graph`. -// * `num_dims` does not match the actual number of dimensions. -TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, - TF_Output output, - int64_t* dims, int num_dims, - TF_Status* status); - -// Operation will only be added to *graph when TF_FinishOperation() is -// called (assuming TF_FinishOperation() does not return an error). -// *graph must not be deleted until after TF_FinishOperation() is -// called. -TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperation( - TF_Graph* graph, const char* op_type, const char* oper_name); - -// Specify the device for `desc`. Defaults to empty, meaning unconstrained. -TF_CAPI_EXPORT extern void TF_SetDevice(TF_OperationDescription* desc, - const char* device); - -// The calls to TF_AddInput and TF_AddInputList must match (in number, -// order, and type) the op declaration. For example, the "Concat" op -// has registration: -// REGISTER_OP("Concat") -// .Input("concat_dim: int32") -// .Input("values: N * T") -// .Output("output: T") -// .Attr("N: int >= 2") -// .Attr("T: type"); -// that defines two inputs, "concat_dim" and "values" (in that order). 
-// You must use TF_AddInput() for the first input (since it takes a -// single tensor), and TF_AddInputList() for the second input (since -// it takes a list, even if you were to pass a list with a single -// tensor), as in: -// TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "c"); -// TF_Output concat_dim_input = {...}; -// TF_AddInput(desc, concat_dim_input); -// TF_Output values_inputs[5] = {{...}, ..., {...}}; -// TF_AddInputList(desc, values_inputs, 5); - -// For inputs that take a single tensor. -TF_CAPI_EXPORT extern void TF_AddInput(TF_OperationDescription* desc, - TF_Output input); - -// For inputs that take a list of tensors. -// inputs must point to TF_Output[num_inputs]. -TF_CAPI_EXPORT extern void TF_AddInputList(TF_OperationDescription* desc, - const TF_Output* inputs, - int num_inputs); - -// Call once per control input to `desc`. -TF_CAPI_EXPORT extern void TF_AddControlInput(TF_OperationDescription* desc, - TF_Operation* input); - -// Request that `desc` be co-located on the device where `op` -// is placed. -// -// Use of this is discouraged since the implementation of device placement is -// subject to change. Primarily intended for internal libraries -TF_CAPI_EXPORT extern void TF_ColocateWith(TF_OperationDescription* desc, - TF_Operation* op); - -// Call some TF_SetAttr*() function for every attr that is not -// inferred from an input and doesn't have a default value you wish to -// keep. - -// `value` must point to a string of length `length` bytes. -TF_CAPI_EXPORT extern void TF_SetAttrString(TF_OperationDescription* desc, - const char* attr_name, - const void* value, size_t length); -// `values` and `lengths` each must have lengths `num_values`. -// `values[i]` must point to a string of length `lengths[i]` bytes. -TF_CAPI_EXPORT extern void TF_SetAttrStringList(TF_OperationDescription* desc, - const char* attr_name, - const void* const* values, - const size_t* lengths, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrInt(TF_OperationDescription* desc, - const char* attr_name, int64_t value); -TF_CAPI_EXPORT extern void TF_SetAttrIntList(TF_OperationDescription* desc, - const char* attr_name, - const int64_t* values, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrFloat(TF_OperationDescription* desc, - const char* attr_name, float value); -TF_CAPI_EXPORT extern void TF_SetAttrFloatList(TF_OperationDescription* desc, - const char* attr_name, - const float* values, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrBool(TF_OperationDescription* desc, - const char* attr_name, - unsigned char value); -TF_CAPI_EXPORT extern void TF_SetAttrBoolList(TF_OperationDescription* desc, - const char* attr_name, - const unsigned char* values, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrType(TF_OperationDescription* desc, - const char* attr_name, - TF_DataType value); -TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, - const char* attr_name, - const TF_DataType* values, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc, - const char* attr_name, - const char* placeholder); - -// Set a 'func' attribute to the specified name. -// `value` must point to a string of length `length` bytes. -TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, - const char* attr_name, - const char* value, size_t length); - -// Set `num_dims` to -1 to represent "unknown rank". Otherwise, -// `dims` points to an array of length `num_dims`. 
`dims[i]` must be -// >= -1, with -1 meaning "unknown dimension". -TF_CAPI_EXPORT extern void TF_SetAttrShape(TF_OperationDescription* desc, - const char* attr_name, - const int64_t* dims, int num_dims); -// `dims` and `num_dims` must point to arrays of length `num_shapes`. -// Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise, -// `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]` -// must be >= -1, with -1 meaning "unknown dimension". -TF_CAPI_EXPORT extern void TF_SetAttrShapeList(TF_OperationDescription* desc, - const char* attr_name, - const int64_t* const* dims, - const int* num_dims, - int num_shapes); -// `proto` must point to an array of `proto_len` bytes representing a -// binary-serialized TensorShapeProto. -TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProto( - TF_OperationDescription* desc, const char* attr_name, const void* proto, - size_t proto_len, TF_Status* status); -// `protos` and `proto_lens` must point to arrays of length `num_shapes`. -// `protos[i]` must point to an array of `proto_lens[i]` bytes -// representing a binary-serialized TensorShapeProto. -TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProtoList( - TF_OperationDescription* desc, const char* attr_name, - const void* const* protos, const size_t* proto_lens, int num_shapes, - TF_Status* status); - -TF_CAPI_EXPORT extern void TF_SetAttrTensor(TF_OperationDescription* desc, - const char* attr_name, - TF_Tensor* value, - TF_Status* status); -TF_CAPI_EXPORT extern void TF_SetAttrTensorList(TF_OperationDescription* desc, - const char* attr_name, - TF_Tensor* const* values, - int num_values, - TF_Status* status); - -// `proto` should point to a sequence of bytes of length `proto_len` -// representing a binary serialization of an AttrValue protocol -// buffer. -TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, - const char* attr_name, - const void* proto, - size_t proto_len, - TF_Status* status); - -// If this function succeeds: -// * *status is set to an OK value, -// * a TF_Operation is added to the graph, -// * a non-null value pointing to the added operation is returned -- -// this value is valid until the underlying graph is deleted. -// Otherwise: -// * *status is set to a non-OK value, -// * the graph is not modified, -// * a null value is returned. -// In either case, it deletes `desc`. -TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperation( - TF_OperationDescription* desc, TF_Status* status); - -// TF_Operation functions. Operations are immutable once created, so -// these are all query functions. 
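For reference, the attr-setting and TF_FinishOperation contract above as a tiny builder (illustrative only, not part of this change; `value` is a caller-created TF_Tensor* and TF_TensorType comes from tensorflow/c/tf_tensor.h):

#include "tensorflow/c/c_api.h"

/* Adds a Const node named `name` holding `value`. The returned
 * TF_Operation* is owned by the graph; NULL is returned on error. */
static TF_Operation* AddConstNode(TF_Graph* graph, TF_Tensor* value,
                                  const char* name, TF_Status* status) {
  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
  TF_SetAttrTensor(desc, "value", value, status);
  TF_SetAttrType(desc, "dtype", TF_TensorType(value));
  /* TF_FinishOperation deletes `desc` whether or not it succeeds. */
  return TF_FinishOperation(desc, status);
}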
- -TF_CAPI_EXPORT extern const char* TF_OperationName(TF_Operation* oper); -TF_CAPI_EXPORT extern const char* TF_OperationOpType(TF_Operation* oper); -TF_CAPI_EXPORT extern const char* TF_OperationDevice(TF_Operation* oper); - -TF_CAPI_EXPORT extern int TF_OperationNumOutputs(TF_Operation* oper); -TF_CAPI_EXPORT extern TF_DataType TF_OperationOutputType(TF_Output oper_out); -TF_CAPI_EXPORT extern int TF_OperationOutputListLength(TF_Operation* oper, - const char* arg_name, - TF_Status* status); - -TF_CAPI_EXPORT extern int TF_OperationNumInputs(TF_Operation* oper); -TF_CAPI_EXPORT extern TF_DataType TF_OperationInputType(TF_Input oper_in); -TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper, - const char* arg_name, - TF_Status* status); - -// In this code: -// TF_Output producer = TF_OperationInput(consumer); -// There is an edge from producer.oper's output (given by -// producer.index) to consumer.oper's input (given by consumer.index). -TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in); - -// Get list of all inputs of a specific operation. `inputs` must point to -// an array of length at least `max_inputs` (ideally set to -// TF_OperationNumInputs(oper)). Beware that a concurrent -// modification of the graph can increase the number of inputs of -// an operation. -TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper, - TF_Output* inputs, - int max_inputs); - -// Get the number of current consumers of a specific output of an -// operation. Note that this number can change when new operations -// are added to the graph. -TF_CAPI_EXPORT extern int TF_OperationOutputNumConsumers(TF_Output oper_out); - -// Get list of all current consumers of a specific output of an -// operation. `consumers` must point to an array of length at least -// `max_consumers` (ideally set to -// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent -// modification of the graph can increase the number of consumers of -// an operation. Returns the number of output consumers (should match -// TF_OperationOutputNumConsumers(oper_out)). -TF_CAPI_EXPORT extern int TF_OperationOutputConsumers(TF_Output oper_out, - TF_Input* consumers, - int max_consumers); - -// Get the number of control inputs to an operation. -TF_CAPI_EXPORT extern int TF_OperationNumControlInputs(TF_Operation* oper); - -// Get list of all control inputs to an operation. `control_inputs` must -// point to an array of length `max_control_inputs` (ideally set to -// TF_OperationNumControlInputs(oper)). Returns the number of control -// inputs (should match TF_OperationNumControlInputs(oper)). -TF_CAPI_EXPORT extern int TF_OperationGetControlInputs( - TF_Operation* oper, TF_Operation** control_inputs, int max_control_inputs); - -// Get the number of operations that have `*oper` as a control input. -// Note that this number can change when new operations are added to -// the graph. -TF_CAPI_EXPORT extern int TF_OperationNumControlOutputs(TF_Operation* oper); - -// Get the list of operations that have `*oper` as a control input. -// `control_outputs` must point to an array of length at least -// `max_control_outputs` (ideally set to -// TF_OperationNumControlOutputs(oper)). Beware that a concurrent -// modification of the graph can increase the number of control -// outputs. Returns the number of control outputs (should match -// TF_OperationNumControlOutputs(oper)). 
-TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( - TF_Operation* oper, TF_Operation** control_outputs, - int max_control_outputs); - -// TF_AttrMetadata describes the value of an attribute on an operation. -typedef struct TF_AttrMetadata { - // A boolean: 1 if the attribute value is a list, 0 otherwise. - unsigned char is_list; - - // Length of the list if is_list is true. Undefined otherwise. - int64_t list_size; - - // Type of elements of the list if is_list != 0. - // Type of the single value stored in the attribute if is_list == 0. - TF_AttrType type; - - // Total size the attribute value. - // The units of total_size depend on is_list and type. - // (1) If type == TF_ATTR_STRING and is_list == 0 - // then total_size is the byte size of the string - // valued attribute. - // (2) If type == TF_ATTR_STRING and is_list == 1 - // then total_size is the cumulative byte size - // of all the strings in the list. - // (3) If type == TF_ATTR_SHAPE and is_list == 0 - // then total_size is the number of dimensions - // of the shape valued attribute, or -1 - // if its rank is unknown. - // (4) If type == TF_ATTR_SHAPE and is_list == 1 - // then total_size is the cumulative number - // of dimensions of all shapes in the list. - // (5) Otherwise, total_size is undefined. - int64_t total_size; -} TF_AttrMetadata; - -// Returns metadata about the value of the attribute `attr_name` of `oper`. -TF_CAPI_EXPORT extern TF_AttrMetadata TF_OperationGetAttrMetadata( - TF_Operation* oper, const char* attr_name, TF_Status* status); - -// Fills in `value` with the value of the attribute `attr_name`. `value` must -// point to an array of length at least `max_length` (ideally set to -// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrString(TF_Operation* oper, - const char* attr_name, - void* value, - size_t max_length, - TF_Status* status); - -// Get the list of strings in the value of the attribute `attr_name`. Fills in -// `values` and `lengths`, each of which must point to an array of length at -// least `max_values`. -// -// The elements of values will point to addresses in `storage` which must be at -// least `storage_size` bytes in length. Ideally, max_values would be set to -// TF_AttrMetadata.list_size and `storage` would be at least -// TF_AttrMetadata.total_size, obtained from TF_OperationGetAttrMetadata(oper, -// attr_name). -// -// Fails if storage_size is too small to hold the requested number of strings. -TF_CAPI_EXPORT extern void TF_OperationGetAttrStringList( - TF_Operation* oper, const char* attr_name, void** values, size_t* lengths, - int max_values, void* storage, size_t storage_size, TF_Status* status); - -TF_CAPI_EXPORT extern void TF_OperationGetAttrInt(TF_Operation* oper, - const char* attr_name, - int64_t* value, - TF_Status* status); - -// Fills in `values` with the value of the attribute `attr_name` of `oper`. -// `values` must point to an array of length at least `max_values` (ideally set -// TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrIntList(TF_Operation* oper, - const char* attr_name, - int64_t* values, - int max_values, - TF_Status* status); - -TF_CAPI_EXPORT extern void TF_OperationGetAttrFloat(TF_Operation* oper, - const char* attr_name, - float* value, - TF_Status* status); - -// Fills in `values` with the value of the attribute `attr_name` of `oper`. 
-// `values` must point to an array of length at least `max_values` (ideally set -// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrFloatList(TF_Operation* oper, - const char* attr_name, - float* values, - int max_values, - TF_Status* status); - -TF_CAPI_EXPORT extern void TF_OperationGetAttrBool(TF_Operation* oper, - const char* attr_name, - unsigned char* value, - TF_Status* status); - -// Fills in `values` with the value of the attribute `attr_name` of `oper`. -// `values` must point to an array of length at least `max_values` (ideally set -// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrBoolList(TF_Operation* oper, - const char* attr_name, - unsigned char* values, - int max_values, - TF_Status* status); - -TF_CAPI_EXPORT extern void TF_OperationGetAttrType(TF_Operation* oper, - const char* attr_name, - TF_DataType* value, - TF_Status* status); - -// Fills in `values` with the value of the attribute `attr_name` of `oper`. -// `values` must point to an array of length at least `max_values` (ideally set -// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrTypeList(TF_Operation* oper, - const char* attr_name, - TF_DataType* values, - int max_values, - TF_Status* status); - -// Fills in `value` with the value of the attribute `attr_name` of `oper`. -// `values` must point to an array of length at least `num_dims` (ideally set to -// TF_Attr_Meta.size from TF_OperationGetAttrMetadata(oper, attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrShape(TF_Operation* oper, - const char* attr_name, - int64_t* value, - int num_dims, - TF_Status* status); - -// Fills in `dims` with the list of shapes in the attribute `attr_name` of -// `oper` and `num_dims` with the corresponding number of dimensions. On return, -// for every i where `num_dims[i]` > 0, `dims[i]` will be an array of -// `num_dims[i]` elements. A value of -1 for `num_dims[i]` indicates that the -// i-th shape in the list is unknown. -// -// The elements of `dims` will point to addresses in `storage` which must be -// large enough to hold at least `storage_size` int64_ts. Ideally, `num_shapes` -// would be set to TF_AttrMetadata.list_size and `storage_size` would be set to -// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, -// attr_name). -// -// Fails if storage_size is insufficient to hold the requested shapes. -TF_CAPI_EXPORT extern void TF_OperationGetAttrShapeList( - TF_Operation* oper, const char* attr_name, int64_t** dims, int* num_dims, - int num_shapes, int64_t* storage, int storage_size, TF_Status* status); - -// Sets `value` to the binary-serialized TensorShapeProto of the value of -// `attr_name` attribute of `oper`'. -TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProto( - TF_Operation* oper, const char* attr_name, TF_Buffer* value, - TF_Status* status); - -// Fills in `values` with binary-serialized TensorShapeProto values of the -// attribute `attr_name` of `oper`. `values` must point to an array of length at -// least `num_values` (ideally set to TF_AttrMetadata.list_size from -// TF_OperationGetAttrMetadata(oper, attr_name)). 
-TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProtoList( - TF_Operation* oper, const char* attr_name, TF_Buffer** values, - int max_values, TF_Status* status); - -// Gets the TF_Tensor valued attribute of `attr_name` of `oper`. -// -// Allocates a new TF_Tensor which the caller is expected to take -// ownership of (and can deallocate using TF_DeleteTensor). -TF_CAPI_EXPORT extern void TF_OperationGetAttrTensor(TF_Operation* oper, - const char* attr_name, - TF_Tensor** value, - TF_Status* status); - -// Fills in `values` with the TF_Tensor values of the attribute `attr_name` of -// `oper`. `values` must point to an array of TF_Tensor* of length at least -// `max_values` (ideally set to TF_AttrMetadata.list_size from -// TF_OperationGetAttrMetadata(oper, attr_name)). -// -// The caller takes ownership of all the non-null TF_Tensor* entries in `values` -// (which can be deleted using TF_DeleteTensor(values[i])). -TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorList(TF_Operation* oper, - const char* attr_name, - TF_Tensor** values, - int max_values, - TF_Status* status); - -// Sets `output_attr_value` to the binary-serialized AttrValue proto -// representation of the value of the `attr_name` attr of `oper`. -TF_CAPI_EXPORT extern void TF_OperationGetAttrValueProto( - TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, - TF_Status* status); - -// Returns the operation in the graph with `oper_name`. Returns nullptr if -// no operation found. -TF_CAPI_EXPORT extern TF_Operation* TF_GraphOperationByName( - TF_Graph* graph, const char* oper_name); - -// Iterate through the operations of a graph. To use: -// size_t pos = 0; -// TF_Operation* oper; -// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { -// DoSomethingWithOperation(oper); -// } -TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, - size_t* pos); - -// Write out a serialized representation of `graph` (as a GraphDef protocol -// message) to `output_graph_def` (allocated by TF_NewBuffer()). -// `output_graph_def`'s underlying buffer will be freed when TF_DeleteBuffer() -// is called. -// -// May fail on very large graphs in the future. -TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, - TF_Buffer* output_graph_def, - TF_Status* status); - -// Returns the serialized OpDef proto with name `op_name`, or a bad status if no -// such op exists. This can return OpDefs of functions copied into the graph. -TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph, - const char* op_name, - TF_Buffer* output_op_def, - TF_Status* status); - -// Returns the serialized VersionDef proto for this graph. -TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, - TF_Buffer* output_version_def, - TF_Status* status); - -// TF_ImportGraphDefOptions holds options that can be passed to -// TF_GraphImportGraphDef. -typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; - -TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( - void); -TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( - TF_ImportGraphDefOptions* opts); - -// Set the prefix to be prepended to the names of nodes in `graph_def` that will -// be imported into `graph`. `prefix` is copied and has no lifetime -// requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( - TF_ImportGraphDefOptions* opts, const char* prefix); - -// Set the execution device for nodes in `graph_def`. 
-// Only applies to nodes where a device was not already explicitly specified. -// `device` is copied and has no lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetDefaultDevice( - TF_ImportGraphDefOptions* opts, const char* device); - -// Set whether to uniquify imported operation names. If true, imported operation -// names will be modified if their name already exists in the graph. If false, -// conflicting names will be treated as an error. Note that this option has no -// effect if a prefix is set, since the prefix will guarantee all names are -// unique. Defaults to false. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( - TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); - -// If true, the specified prefix will be modified if it already exists as an -// operation name or prefix in the graph. If false, a conflicting prefix will be -// treated as an error. This option has no effect if no prefix is specified. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( - TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); - -// Set any imported nodes with input `src_name:src_index` to have that input -// replaced with `dst`. `src_name` refers to a node in the graph to be imported, -// `dst` references a node already existing in the graph being imported into. -// `src_name` is copied and has no lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( - TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, - TF_Output dst); - -// Set any imported nodes with control input `src_name` to have that input -// replaced with `dst`. `src_name` refers to a node in the graph to be imported, -// `dst` references an operation already existing in the graph being imported -// into. `src_name` is copied and has no lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( - TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); - -// Cause the imported graph to have a control dependency on `oper`. `oper` -// should exist in the graph being imported into. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( - TF_ImportGraphDefOptions* opts, TF_Operation* oper); - -// Add an output in `graph_def` to be returned via the `return_outputs` output -// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input -// mapping, the corresponding existing tensor in `graph` will be returned. -// `oper_name` is copied and has no lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( - TF_ImportGraphDefOptions* opts, const char* oper_name, int index); - -// Returns the number of return outputs added via -// TF_ImportGraphDefOptionsAddReturnOutput(). -TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( - const TF_ImportGraphDefOptions* opts); - -// Add an operation in `graph_def` to be returned via the `return_opers` output -// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no -// lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( - TF_ImportGraphDefOptions* opts, const char* oper_name); - -// Returns the number of return operations added via -// TF_ImportGraphDefOptionsAddReturnOperation(). 
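Putting the import options together, a hedged sketch of importing a serialized GraphDef with a prefix and one requested return output; `graph` and `graph_def` are assumed to exist already, and the node name "out" is illustrative only:

  TF_Status* status = TF_NewStatus();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
  TF_ImportGraphDefOptionsAddReturnOutput(opts, "out", 0);

  TF_Output returned[1];
  TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, returned, 1, status);
  /* On success, returned[0] refers to the imported "imported/out:0" in `graph`. */

  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteStatus(status);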
-TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations( - const TF_ImportGraphDefOptions* opts); - -// TF_ImportGraphDefResults holds results that are generated by -// TF_GraphImportGraphDefWithResults(). -typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults; - -// Fetches the return outputs requested via -// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is -// returned in `num_outputs`. The array of return outputs is returned in -// `outputs`. `*outputs` is owned by and has the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs( - TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs); - -// Fetches the return operations requested via -// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched -// operations is returned in `num_opers`. The array of return operations is -// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( - TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); - -// Fetches any input mappings requested via -// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef -// and weren't used as input to any node in the imported graph def. The number -// of fetched mappings is returned in `num_missing_unused_input_mappings`. The -// array of each mapping's source node name is returned in `src_names`, and the -// array of each mapping's source index is returned in `src_indexes`. -// -// `*src_names`, `*src_indexes`, and the memory backing each string in -// `src_names` are owned by and have the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, - const char*** src_names, int** src_indexes); - -// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). -TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults( - TF_ImportGraphDefResults* results); - -// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and -// a bad status on error. Otherwise, returns a populated -// TF_ImportGraphDefResults instance. The returned instance must be deleted via -// TF_DeleteImportGraphDefResults(). -TF_CAPI_EXPORT extern TF_ImportGraphDefResults* -TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, - TF_Status* status); - -// Import the graph serialized in `graph_def` into `graph`. -// Convenience function for when only return outputs are needed. -// -// `num_return_outputs` must be the number of return outputs added (i.e. the -// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If -// `num_return_outputs` is non-zero, `return_outputs` must be of length -// `num_return_outputs`. Otherwise it can be null. -TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, - int num_return_outputs, TF_Status* status); - -// Import the graph serialized in `graph_def` into `graph`. -// Convenience function for when no results are needed. -TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, TF_Status* status); - -// Adds a copy of function `func` and optionally its gradient function `grad` -// to `g`. 
Once `func`/`grad` is added to `g`, it can be called by creating -// an operation using the function's name. -// Any changes to `func`/`grad` (including deleting it) done after this method -// returns, won't affect the copy of `func`/`grad` in `g`. -// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no -// effect on them, but can establish the function->gradient relationship -// between them if `func` does not already have a gradient. If `func` already -// has a gradient different from `grad`, an error is returned. -// -// `func` must not be null. -// If `grad` is null and `func` is not in `g`, `func` is added without a -// gradient. -// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop. -// `grad` must have appropriate signature as described in the doc of -// GradientDef in tensorflow/core/framework/function.proto. -// -// If successful, status is set to OK and `func` and `grad` are added to `g`. -// Otherwise, status is set to the encountered error and `g` is unmodified. -TF_CAPI_EXPORT extern void TF_GraphCopyFunction(TF_Graph* g, - const TF_Function* func, - const TF_Function* grad, - TF_Status* status); - -// Returns the number of TF_Functions registered in `g`. -TF_CAPI_EXPORT extern int TF_GraphNumFunctions(TF_Graph* g); - -// Fills in `funcs` with the TF_Function* registered in `g`. -// `funcs` must point to an array of TF_Function* of length at least -// `max_func`. In usual usage, max_func should be set to the result of -// TF_GraphNumFunctions(g). In this case, all the functions registered in -// `g` will be returned. Else, an unspecified subset. -// -// If successful, returns the number of TF_Function* successfully set in -// `funcs` and sets status to OK. The caller takes ownership of -// all the returned TF_Functions. They must be deleted with TF_DeleteFunction. -// On error, returns 0, sets status to the encountered error, and the contents -// of funcs will be undefined. -TF_CAPI_EXPORT extern int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, - int max_func, TF_Status* status); - -// Note: The following function may fail on very large protos in the future. - -TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, - TF_Buffer* output_node_def, - TF_Status* status); - typedef struct TF_WhileParams { // The number of inputs to the while loop, i.e. the number of loop variables. // This is the size of cond_inputs, body_inputs, and body_outputs. @@ -1012,558 +149,6 @@ TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* dx, TF_Status* status, TF_Output* dy); -// Create a TF_Function from a TF_Graph -// -// Params: -// fn_body - the graph whose operations (or subset of whose operations) will be -// converted to TF_Function. -// fn_name - the name of the new TF_Function. Should match the operation -// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. -// If `append_hash_to_fn_name` is false, `fn_name` must be distinct -// from other function and operation names (at least those -// registered in graphs where this function will be used). -// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name -// of the function will be `fn_name` appended with -// '_'. -// If set to 0, the function's name will be `fn_name`. -// num_opers - `num_opers` contains the number of elements in the `opers` array -// or a special value of -1 meaning that no array is given. 
-// The distinction between an empty array of operations and no -// array of operations is necessary to distinguish the case of -// creating a function with no body (e.g. identity or permutation) -// and the case of creating a function whose body contains all -// the nodes in the graph (except for the automatic skipping, see -// below). -// opers - Array of operations to become the body of the function or null. -// - If no array is given (`num_opers` = -1), all the -// operations in `fn_body` will become part of the function -// except operations referenced in `inputs`. These operations -// must have a single output (these operations are typically -// placeholders created for the sole purpose of representing -// an input. We can relax this constraint if there are -// compelling use cases). -// - If an array is given (`num_opers` >= 0), all operations -// in it will become part of the function. In particular, no -// automatic skipping of dummy input operations is performed. -// ninputs - number of elements in `inputs` array -// inputs - array of TF_Outputs that specify the inputs to the function. -// If `ninputs` is zero (the function takes no inputs), `inputs` -// can be null. The names used for function inputs are normalized -// names of the operations (usually placeholders) pointed to by -// `inputs`. These operation names should start with a letter. -// Normalization will convert all letters to lowercase and -// non-alphanumeric characters to '_' to make resulting names match -// the "[a-z][a-z0-9_]*" pattern for operation argument names. -// `inputs` cannot contain the same tensor twice. -// noutputs - number of elements in `outputs` array -// outputs - array of TF_Outputs that specify the outputs of the function. -// If `noutputs` is zero (the function returns no outputs), `outputs` -// can be null. `outputs` can contain the same tensor more than once. -// output_names - The names of the function's outputs. `output_names` array -// must either have the same length as `outputs` -// (i.e. `noutputs`) or be null. In the former case, -// the names should match the regular expression for ArgDef -// names - "[a-z][a-z0-9_]*". In the latter case, -// names for outputs will be generated automatically. -// opts - various options for the function, e.g. XLA's inlining control. -// description - optional human-readable description of this function. -// status - Set to OK on success and an appropriate error on failure. -// -// Note that when the same TF_Output is listed as both an input and an output, -// the corresponding function's output will equal to this input, -// instead of the original node's output. -// -// Callers must also satisfy the following constraints: -// - `inputs` cannot refer to TF_Outputs within a control flow context. For -// example, one cannot use the output of "switch" node as input. -// - `inputs` and `outputs` cannot have reference types. Reference types are -// not exposed through C API and are being replaced with Resources. We support -// reference types inside function's body to support legacy code. Do not -// use them in new code. -// - Every node in the function's body must have all of its inputs (including -// control inputs). In other words, for every node in the body, each input -// must be either listed in `inputs` or must come from another node in -// the body. In particular, it is an error to have a control edge going from -// a node outside of the body into a node in the body. 
This applies to control -// edges going from nodes referenced in `inputs` to nodes in the body when -// the former nodes are not in the body (automatically skipped or not -// included in explicitly specified body). -// -// Returns: -// On success, a newly created TF_Function instance. It must be deleted by -// calling TF_DeleteFunction. -// -// On failure, null. -TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( - const TF_Graph* fn_body, const char* fn_name, - unsigned char append_hash_to_fn_name, int num_opers, - const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, - int noutputs, const TF_Output* outputs, const char* const* output_names, - const TF_FunctionOptions* opts, const char* description, TF_Status* status); - -// Similar to TF_GraphToFunction but allows specifying control outputs of the -// function. -// -// The arguments of TF_GraphToFunction have the same meaning, but the new -// arguments are as follows: -// -// ncontrol_outputs: Number of control outputs of the function. -// control_outputs: vector of TF_Operation objects to be marked as control -// outputs of the function. Operations marked as control outputs are -// guaranteed to execute. -// control_output_names: Optional. If not nullptr, vector of strings, one -// per control output, with their names to be added to the function's -// OpDef. -TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( - const TF_Graph* fn_body, const char* fn_name, - unsigned char append_hash_to_fn_name, int num_opers, - const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, - int noutputs, const TF_Output* outputs, const char* const* output_names, - int ncontrol_outputs, const TF_Operation* const* control_outputs, - const char* const* control_output_names, const TF_FunctionOptions* opts, - const char* description, TF_Status* status); - -// Returns the name of the graph function. -// The return value points to memory that is only usable until the next -// mutation to *func. -TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); - -// Write out a serialized representation of `func` (as a FunctionDef protocol -// message) to `output_func_def` (allocated by TF_NewBuffer()). -// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() -// is called. -// -// May fail on very large graphs in the future. -TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, - TF_Buffer* output_func_def, - TF_Status* status); - -// Construct and return the function whose FunctionDef representation is -// serialized in `proto`. `proto_len` must equal the number of bytes -// pointed to by `proto`. -// Returns: -// On success, a newly created TF_Function instance. It must be deleted by -// calling TF_DeleteFunction. -// -// On failure, null. -TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( - const void* proto, size_t proto_len, TF_Status* status); - -// Sets function attribute named `attr_name` to value stored in `proto`. -// If this attribute is already set to another value, it is overridden. -// `proto` should point to a sequence of bytes of length `proto_len` -// representing a binary serialization of an AttrValue protocol -// buffer. -TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func, - const char* attr_name, - const void* proto, - size_t proto_len, - TF_Status* status); - -// Sets `output_attr_value` to the binary-serialized AttrValue proto -// representation of the value of the `attr_name` attr of `func`. 
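For orientation, a minimal sketch of turning a graph into a function, assuming `fn_body` already contains a placeholder feeding an op whose first output is the result; `placeholder_op` and `result_op` are hypothetical names:

  TF_Output inputs[1]  = {{placeholder_op, 0}};
  TF_Output outputs[1] = {{result_op, 0}};
  TF_Status* status = TF_NewStatus();

  TF_Function* fn = TF_GraphToFunction(
      fn_body, "my_fn", /*append_hash_to_fn_name=*/0,
      /*num_opers=*/-1, /*opers=*/NULL,        /* whole body except input placeholders */
      1, inputs, 1, outputs, /*output_names=*/NULL,
      /*opts=*/NULL, "example function", status);
  if (TF_GetCode(status) == TF_OK) {
    /* e.g. TF_GraphCopyFunction(other_graph, fn, NULL, status); */
    TF_DeleteFunction(fn);
  }
  TF_DeleteStatus(status);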
-// If `attr_name` attribute is not present, status is set to an error. -TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( - TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value, - TF_Status* status); - -// Frees the memory used by the `func` struct. -// TF_DeleteFunction is a noop if `func` is null. -// Deleting a function does not remove it from any graphs it was copied to. -TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); - -// Attempts to evaluate `output`. This will only be possible if `output` doesn't -// depend on any graph inputs (this function is safe to call if this isn't the -// case though). -// -// If the evaluation is successful, this function returns true and `output`s -// value is returned in `result`. Otherwise returns false. An error status is -// returned if something is wrong with the graph or input. Note that this may -// return false even if no error status is set. -TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph, - TF_Output output, - TF_Tensor** result, - TF_Status* status); - -// TODO(josh11b): Register OpDef, available to all operations added -// to this graph. - -// -------------------------------------------------------------------------- -// API for driving Graph execution. - -typedef struct TF_Session TF_Session; - -// Return a new execution session with the associated graph, or NULL on -// error. Does not take ownership of any input parameters. -// -// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be -// kept alive for the lifetime of the returned TF_Session. New nodes can still -// be added to `graph` after this call. -TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, - const TF_SessionOptions* opts, - TF_Status* status); - -// This function creates a new TF_Session (which is created on success) using -// `session_options`, and then initializes state (restoring tensors and other -// assets) using `run_options`. -// -// Any NULL and non-NULL value combinations for (`run_options, `meta_graph_def`) -// are valid. -// -// - `export_dir` must be set to the path of the exported SavedModel. -// - `tags` must include the set of tags used to identify one MetaGraphDef in -// the SavedModel. -// - `graph` must be a graph newly allocated with TF_NewGraph(). -// -// If successful, populates `graph` with the contents of the Graph and -// `meta_graph_def` with the MetaGraphDef of the loaded model. -TF_CAPI_EXPORT extern TF_Session* TF_LoadSessionFromSavedModel( - const TF_SessionOptions* session_options, const TF_Buffer* run_options, - const char* export_dir, const char* const* tags, int tags_len, - TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status); - -// Close a session. -// -// Contacts any other processes associated with the session, if applicable. -// May not be called after TF_DeleteSession(). -TF_CAPI_EXPORT extern void TF_CloseSession(TF_Session*, TF_Status* status); - -// Destroy a session object. -// -// Even if error information is recorded in *status, this call discards all -// local resources associated with the session. The session may not be used -// during or after this call (and the session drops its reference to the -// corresponding graph). -TF_CAPI_EXPORT extern void TF_DeleteSession(TF_Session*, TF_Status* status); - -// Run the graph associated with the session starting with the supplied inputs -// (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). 
-// -// Any NULL and non-NULL value combinations for (`run_options`, -// `run_metadata`) are valid. -// -// - `run_options` may be NULL, in which case it will be ignored; or -// non-NULL, in which case it must point to a `TF_Buffer` containing the -// serialized representation of a `RunOptions` protocol buffer. -// - `run_metadata` may be NULL, in which case it will be ignored; or -// non-NULL, in which case it must point to an empty, freshly allocated -// `TF_Buffer` that may be updated to contain the serialized representation -// of a `RunMetadata` protocol buffer. -// -// The caller retains ownership of `input_values` (which can be deleted using -// TF_DeleteTensor). The caller also retains ownership of `run_options` and/or -// `run_metadata` (when not NULL) and should manually call TF_DeleteBuffer on -// them. -// -// On success, the tensors corresponding to outputs[0,noutputs-1] are placed in -// output_values[]. Ownership of the elements of output_values[] is transferred -// to the caller, which must eventually call TF_DeleteTensor on them. -// -// On failure, output_values[] contains NULLs. -TF_CAPI_EXPORT extern void TF_SessionRun( - TF_Session* session, - // RunOptions - const TF_Buffer* run_options, - // Input tensors - const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, - // Output tensors - const TF_Output* outputs, TF_Tensor** output_values, int noutputs, - // Target operations - const TF_Operation* const* target_opers, int ntargets, - // RunMetadata - TF_Buffer* run_metadata, - // Output status - TF_Status*); - -// Set up the graph with the intended feeds (inputs) and fetches (outputs) for a -// sequence of partial run calls. -// -// On success, returns a handle that is used for subsequent PRun calls. The -// handle should be deleted with TF_DeletePRunHandle when it is no longer -// needed. -// -// On failure, out_status contains a tensorflow::Status with an error -// message. *handle is set to nullptr. -TF_CAPI_EXPORT extern void TF_SessionPRunSetup( - TF_Session*, - // Input names - const TF_Output* inputs, int ninputs, - // Output names - const TF_Output* outputs, int noutputs, - // Target operations - const TF_Operation* const* target_opers, int ntargets, - // Output handle - const char** handle, - // Output status - TF_Status*); - -// Continue to run the graph with additional feeds and fetches. The -// execution state is uniquely identified by the handle. -TF_CAPI_EXPORT extern void TF_SessionPRun( - TF_Session*, const char* handle, - // Input tensors - const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, - // Output tensors - const TF_Output* outputs, TF_Tensor** output_values, int noutputs, - // Target operations - const TF_Operation* const* target_opers, int ntargets, - // Output status - TF_Status*); - -// Deletes a handle allocated by TF_SessionPRunSetup. -// Once called, no more calls to TF_SessionPRun should be made. -TF_CAPI_EXPORT extern void TF_DeletePRunHandle(const char* handle); - -// -------------------------------------------------------------------------- -// The deprecated session API. Please switch to the above instead of -// TF_ExtendGraph(). This deprecated API can be removed at any time without -// notice. 
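A minimal single-feed / single-fetch call, assuming `session` was created with TF_NewSession, `input_op` and `output_op` are operations in its graph, and `feed_val` is a caller-owned TF_Tensor (all hypothetical names):

  TF_Output feed  = {input_op, 0};
  TF_Output fetch = {output_op, 0};
  TF_Tensor* fetched = NULL;
  TF_Status* status = TF_NewStatus();

  TF_SessionRun(session, /*run_options=*/NULL,
                &feed, &feed_val, 1,        /* inputs stay owned by the caller */
                &fetch, &fetched, 1,        /* outputs become owned by the caller */
                /*target_opers=*/NULL, 0,
                /*run_metadata=*/NULL, status);
  if (TF_GetCode(status) == TF_OK) {
    /* ... consume fetched ... */
    TF_DeleteTensor(fetched);
  }
  TF_DeleteStatus(status);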
- -typedef struct TF_DeprecatedSession TF_DeprecatedSession; - -TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession( - const TF_SessionOptions*, TF_Status* status); -TF_CAPI_EXPORT extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, - TF_Status* status); -TF_CAPI_EXPORT extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, - TF_Status* status); -TF_CAPI_EXPORT extern void TF_Reset(const TF_SessionOptions* opt, - const char** containers, int ncontainers, - TF_Status* status); -// Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and -// add the nodes in that GraphDef to the graph for the session. -// -// Prefer use of TF_Session and TF_GraphImportGraphDef over this. -TF_CAPI_EXPORT extern void TF_ExtendGraph(TF_DeprecatedSession*, - const void* proto, size_t proto_len, - TF_Status*); - -// See TF_SessionRun() above. -TF_CAPI_EXPORT extern void TF_Run(TF_DeprecatedSession*, - const TF_Buffer* run_options, - const char** input_names, TF_Tensor** inputs, - int ninputs, const char** output_names, - TF_Tensor** outputs, int noutputs, - const char** target_oper_names, int ntargets, - TF_Buffer* run_metadata, TF_Status*); - -// See TF_SessionPRunSetup() above. -TF_CAPI_EXPORT extern void TF_PRunSetup(TF_DeprecatedSession*, - const char** input_names, int ninputs, - const char** output_names, int noutputs, - const char** target_oper_names, - int ntargets, const char** handle, - TF_Status*); - -// See TF_SessionPRun above. -TF_CAPI_EXPORT extern void TF_PRun(TF_DeprecatedSession*, const char* handle, - const char** input_names, TF_Tensor** inputs, - int ninputs, const char** output_names, - TF_Tensor** outputs, int noutputs, - const char** target_oper_names, int ntargets, - TF_Status*); - -typedef struct TF_DeviceList TF_DeviceList; - -// Lists all devices in a TF_Session. -// -// Caller takes ownership of the returned TF_DeviceList* which must eventually -// be freed with a call to TF_DeleteDeviceList. -TF_CAPI_EXPORT extern TF_DeviceList* TF_SessionListDevices(TF_Session* session, - TF_Status* status); - -// Lists all devices in a TF_Session. -// -// Caller takes ownership of the returned TF_DeviceList* which must eventually -// be freed with a call to TF_DeleteDeviceList. -TF_CAPI_EXPORT extern TF_DeviceList* TF_DeprecatedSessionListDevices( - TF_DeprecatedSession* session, TF_Status* status); - -// Deallocates the device list. -TF_CAPI_EXPORT extern void TF_DeleteDeviceList(TF_DeviceList* list); - -// Counts the number of elements in the device list. -TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); - -// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) -// The return value will be a pointer to a null terminated string. The caller -// must not modify or delete the string. It will be deallocated upon a call to -// TF_DeleteDeviceList. -// -// If index is out of bounds, an error code will be set in the status object, -// and a null pointer will be returned. -TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, - int index, - TF_Status* status); - -// Retrieves the type of the device at the given index. -// -// The caller must not modify or delete the string. It will be deallocated upon -// a call to TF_DeleteDeviceList. -// -// If index is out of bounds, an error code will be set in the status object, -// and a null pointer will be returned. 
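A quick sketch of enumerating the devices visible to a session (hedged; error handling mostly elided):

  TF_Status* status = TF_NewStatus();
  TF_DeviceList* devices = TF_SessionListDevices(session, status);
  if (TF_GetCode(status) == TF_OK) {
    int n = TF_DeviceListCount(devices);
    for (int i = 0; i < n; ++i) {
      const char* name = TF_DeviceListName(devices, i, status);
      /* e.g. "/job:localhost/replica:0/task:0/device:CPU:0";
         valid only until the list is deleted. */
    }
    TF_DeleteDeviceList(devices);
  }
  TF_DeleteStatus(status);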
-TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, - int index, - TF_Status* status); - -// Retrieve the amount of memory associated with a given device. -// -// If index is out of bounds, an error code will be set in the status object, -// and -1 will be returned. -TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( - const TF_DeviceList* list, int index, TF_Status* status); - -// Retrieve the incarnation number of a given device. -// -// If index is out of bounds, an error code will be set in the status object, -// and 0 will be returned. -TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation( - const TF_DeviceList* list, int index, TF_Status* status); - -// -------------------------------------------------------------------------- -// Load plugins containing custom ops and kernels - -// TF_Library holds information about dynamically loaded TensorFlow plugins. -typedef struct TF_Library TF_Library; - -// Load the library specified by library_filename and register the ops and -// kernels present in that library. -// -// Pass "library_filename" to a platform-specific mechanism for dynamically -// loading a library. The rules for determining the exact location of the -// library are platform-specific and are not documented here. -// -// On success, place OK in status and return the newly created library handle. -// The caller owns the library handle. -// -// On failure, place an error status in status and return NULL. -TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename, - TF_Status* status); - -// Get the OpList of OpDefs defined in the library pointed by lib_handle. -// -// Returns a TF_Buffer. The memory pointed to by the result is owned by -// lib_handle. The data in the buffer will be the serialized OpList proto for -// ops defined in the library. -TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); - -// Frees the memory associated with the library handle. -// Does NOT unload the library. -TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); - -// Get the OpList of all OpDefs defined in this address space. -// Returns a TF_Buffer, ownership of which is transferred to the caller -// (and can be freed using TF_DeleteBuffer). -// -// The data in the buffer will be the serialized OpList proto for ops registered -// in this address space. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); - -// TF_ApiDefMap encapsulates a collection of API definitions for an operation. -// -// This object maps the name of a TensorFlow operation to a description of the -// API to generate for it, as defined by the ApiDef protocol buffer ( -// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) -// -// The ApiDef messages are typically used to generate convenience wrapper -// functions for TensorFlow operations in various language bindings. -typedef struct TF_ApiDefMap TF_ApiDefMap; - -// Creates a new TF_ApiDefMap instance. -// -// Params: -// op_list_buffer - TF_Buffer instance containing serialized OpList -// protocol buffer. (See -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto -// for the OpList proto definition). -// status - Set to OK on success and an appropriate error on failure. -TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, - TF_Status* status); - -// Deallocates a TF_ApiDefMap. -TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap); - -// Add ApiDefs to the map. 
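A hedged sketch of building an ApiDef map from every op registered in the current process and looking one op up; the op name "MatMul" is just an example:

  TF_Status* status = TF_NewStatus();
  TF_Buffer* all_ops = TF_GetAllOpList();               /* caller owns this buffer */
  TF_ApiDefMap* api_map = TF_NewApiDefMap(all_ops, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_Buffer* api_def = TF_ApiDefMapGet(api_map, "MatMul", 6, status);
    /* On success, api_def holds a serialized ApiDef proto for MatMul. */
    TF_DeleteBuffer(api_def);
    TF_DeleteApiDefMap(api_map);
  }
  TF_DeleteBuffer(all_ops);
  TF_DeleteStatus(status);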
-// -// `text` corresponds to a text representation of an ApiDefs protocol message. -// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). -// -// The provided ApiDefs will be merged with existing ones in the map, with -// precedence given to the newly added version in case of conflicts with -// previous calls to TF_ApiDefMapPut. -TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, - const char* text, size_t text_len, - TF_Status* status); - -// Returns a serialized ApiDef protocol buffer for the TensorFlow operation -// named `name`. -TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, - const char* name, - size_t name_len, - TF_Status* status); - -// -------------------------------------------------------------------------- -// Kernel definition information. - -// Returns a serialized KernelList protocol buffer containing KernelDefs for all -// registered kernels. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); - -// Returns a serialized KernelList protocol buffer containing KernelDefs for all -// kernels registered for the operation named `name`. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( - const char* name, TF_Status* status); - -// -------------------------------------------------------------------------- -// In-process TensorFlow server functionality, for use in distributed training. -// A Server instance encapsulates a set of devices and a Session target that -// can participate in distributed training. A server belongs to a cluster -// (specified by a ClusterSpec), and corresponds to a particular task in a -// named job. The server can communicate with any other server in the same -// cluster. - -// In-process TensorFlow server. -typedef struct TF_Server TF_Server; - -// Creates a new in-process TensorFlow server configured using a serialized -// ServerDef protocol buffer provided via `proto` and `proto_len`. -// -// The server will not serve any requests until TF_ServerStart is invoked. -// The server will stop serving requests once TF_ServerStop or -// TF_DeleteServer is invoked. -TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, - size_t proto_len, - TF_Status* status); - -// Starts an in-process TensorFlow server. -TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); - -// Stops an in-process TensorFlow server. -TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); - -// Blocks until the server has been successfully stopped (via TF_ServerStop or -// TF_ServerClose). -TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); - -// Returns the target string that can be provided to TF_SetTarget() to connect -// a TF_Session to `server`. -// -// The returned string is valid only until TF_DeleteServer is invoked. -TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); - -// Destroy an in-process TensorFlow server, frees memory. If server is running -// it will be stopped and joined. -TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); - -// Register a listener method that processes printed messages. -// -// If any listeners are registered, the print operator will call all listeners -// with the printed messages and immediately return without writing to the -// logs. 
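A hedged lifecycle sketch for the in-process server, assuming `server_def` and `server_def_len` describe a serialized ServerDef proto:

  TF_Status* status = TF_NewStatus();
  TF_Server* server = TF_NewServer(server_def, server_def_len, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_ServerStart(server, status);
    /* TF_ServerTarget(server) yields the target string to pass to TF_SetTarget(). */
    TF_ServerStop(server, status);
    TF_ServerJoin(server, status);
    TF_DeleteServer(server);
  }
  TF_DeleteStatus(status);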
-TF_CAPI_EXPORT extern void TF_RegisterLogListener( - void (*listener)(const char*)); - #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 32880378c2b..11fb7705625 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_C_C_API_INTERNAL_H_ #define TENSORFLOW_C_C_API_INTERNAL_H_ -#include "tensorflow/c/c_api.h" - #include #include #include #include #include +#include "tensorflow/c/c_core_api.h" + // clang-format off // Required for IS_MOBILE_PLATFORM #include "tensorflow/core/platform/platform.h" @@ -217,6 +217,10 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) std::string getTF_OutputDebugString(TF_Output node); +TF_Operation* ToOperation(Node* node); + +TensorId ToTensorId(const TF_Output& output); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_core_api.cc b/tensorflow/c/c_core_api.cc new file mode 100644 index 00000000000..67daaef08ac --- /dev/null +++ b/tensorflow/c/c_core_api.cc @@ -0,0 +1,2193 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api.h" + +#include +#include +#include +#include + +#include "absl/strings/match.h" +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" // NOLINT + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/framework/logging.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eval_const_tensor.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" +#include 
"tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/version.h" + +// The implementation below is at the top level instead of the +// brain namespace because we are defining 'extern "C"' functions. +using tensorflow::AllocationDescription; +using tensorflow::DataType; +using tensorflow::ExtendSessionGraphHelper; +using tensorflow::Graph; +using tensorflow::GraphDef; +using tensorflow::mutex_lock; +using tensorflow::NameRangeMap; +using tensorflow::NameRangesForNode; +using tensorflow::NewSession; +using tensorflow::Node; +using tensorflow::NodeBuilder; +using tensorflow::NodeDef; +using tensorflow::OpDef; +using tensorflow::OpRegistry; +using tensorflow::OutputTensor; +using tensorflow::PartialTensorShape; +using tensorflow::RunMetadata; +using tensorflow::RunOptions; +using tensorflow::Session; +using tensorflow::Status; +using tensorflow::string; +using tensorflow::Tensor; +using tensorflow::TensorBuffer; +using tensorflow::TensorId; +using tensorflow::TensorShape; +using tensorflow::TensorShapeProto; +using tensorflow::ToTensorId; +using tensorflow::VersionDef; +using tensorflow::errors::FailedPrecondition; +using tensorflow::errors::InvalidArgument; +using tensorflow::gtl::ArraySlice; +using tensorflow::strings::StrCat; + +extern "C" { + +// -------------------------------------------------------------------------- +const char* TF_Version() { return TF_VERSION_STRING; } + +// -------------------------------------------------------------------------- + +// -------------------------------------------------------------------------- +TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } +void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } + +void TF_SetTarget(TF_SessionOptions* options, const char* target) { + options->options.target = target; +} + +void TF_SetConfig(TF_SessionOptions* options, const void* proto, + size_t proto_len, TF_Status* status) { + if (!options->options.config.ParseFromArray(proto, proto_len)) { + status->status = InvalidArgument("Unparseable ConfigProto"); + } +} +// -------------------------------------------------------------------------- +TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; } + +TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { + void* copy = tensorflow::port::Malloc(proto_len); + memcpy(copy, proto, proto_len); + + TF_Buffer* buf = new TF_Buffer; + buf->data = copy; + buf->length = proto_len; + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; + return buf; +} + +void TF_DeleteBuffer(TF_Buffer* buffer) { + if (buffer == nullptr) return; + if (buffer->data_deallocator != nullptr) { + (*buffer->data_deallocator)(const_cast(buffer->data), + buffer->length); + } + delete buffer; +} + +TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } + +// -------------------------------------------------------------------------- + 
+TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, + TF_Status* status) { + Session* session; + status->status = NewSession(opt->options, &session); + if (status->status.ok()) { + return new TF_DeprecatedSession({session}); + } else { + DCHECK_EQ(nullptr, session); + return nullptr; + } +} + +void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { + status->status = s->session->Close(); +} + +void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { + status->status = Status::OK(); + if (s == nullptr) return; + delete s->session; + delete s; +} + +void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto, + size_t proto_len, TF_Status* status) { + GraphDef g; + if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { + status->status = InvalidArgument("Invalid GraphDef"); + return; + } + status->status = s->session->Extend(g); +} + +} // end extern "C" + +// Reset helper for converting character arrays to string vectors. +static void TF_Reset_Helper(const TF_SessionOptions* opt, + const char** containers, int ncontainers, + TF_Status* status) { + std::vector container_names(ncontainers); + for (int i = 0; i < ncontainers; ++i) { + container_names[i] = containers[i]; + } + + status->status = Reset(opt->options, container_names); +} + +extern "C" { + +void TF_Reset(const TF_SessionOptions* opt, const char** containers, + int ncontainers, TF_Status* status) { + TF_Reset_Helper(opt, containers, ncontainers, status); +} + +} // end extern "C" + +namespace tensorflow { + + +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out) { + if (out->data != nullptr) { + return InvalidArgument("Passing non-empty TF_Buffer is invalid."); + } + const size_t proto_size = in.ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return tensorflow::errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '", + in.GetTypeName(), "' and size ", proto_size); + } + if (!in.SerializeWithCachedSizesToArray(static_cast(buf))) { + port::Free(buf); + return InvalidArgument("Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", + proto_size, " bytes) is too large?"); + } + out->data = buf; + out->length = proto_size; + out->data_deallocator = [](void* data, size_t length) { port::Free(data); }; + return Status::OK(); +} + +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) { + // If any session has already run this node_id, mark this session as + // unrunnable. + for (auto it : graph->sessions) { + mutex_lock session_lock(it.first->mu); + if (it.first->last_num_graph_nodes > op.node.id()) { + it.second = strings::StrCat( + "Operation '", op.node.DebugString(), "' was changed by ", + mutation_type, + " after it was run by a session. This mutation will have no effect, " + "and will trigger an error in the future. Either don't modify " + "nodes after running them or create a new session."); + } + } +} + +namespace { + +// Helper method that creates a shape handle for a shape described by dims. 
+tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( + tensorflow::shape_inference::InferenceContext* ic, int num_dims, + const int64_t* dims) { + if (num_dims != -1) { + std::vector dim_vec; + dim_vec.reserve(num_dims); + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + return ic->MakeShape(dim_vec); + } else { + return ic->UnknownShape(); + } +} + +} // namespace + +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + + auto shape_and_type_vec = + std::vector( + num_shapes_and_types); + for (int i = 0; i < num_shapes_and_types; ++i) { + tensorflow::shape_inference::ShapeHandle shape_handle = + ShapeHandleFromDims(ic, ranks[i], shapes[i]); + shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType( + shape_handle, static_cast(types[i])); + } + + ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec); +} + +// Helpers for loading a TensorFlow plugin (a .so file). +Status LoadLibrary(const char* library_filename, void** result, + const void** buf, size_t* len); + +// TODO(josh11b,mrry): Change Session to be able to use a Graph* +// directly, instead of requiring us to serialize to a GraphDef and +// call Session::Extend(). +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { + if (session->graph != nullptr) { + // Take the graph lock before the session lock to avoid deadlock. This is + // safe since session->graph does not change. + session->graph->mu.lock(); + mutex_lock session_lock(session->mu); + const Graph& graph = session->graph->graph; + + const string& mutation_warning = session->graph->sessions[session]; + if (!mutation_warning.empty()) { + // TODO(b/74949947): turn this back into an error status + LOG(WARNING) << mutation_warning; + session->graph->sessions[session].clear(); + } + + const auto num_nodes = graph.num_node_ids(); + if (session->last_num_graph_nodes < num_nodes) { + // TODO(nolivia): check this on a subset of the graph instead of all of + // it. + status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + + GraphDef graph_def; + *graph_def.mutable_versions() = graph.versions(); + // Fill graph_def with nodes with ids in the range + // [session->last_num_graph_nodes, num_nodes), that is the nodes + // added since the last TF_SessionRun() call. + for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { + Node* const node = graph.FindNodeId(id); + if (node != nullptr && node->IsOp()) { + NodeDef* const node_def = graph_def.add_node(); + *node_def = node->def(); + } + } + *graph_def.mutable_library() = graph.flib_def().ToProto(); + session->graph->mu.unlock(); + status->status = session->session->Extend(std::move(graph_def)); + if (!status->status.ok()) { + // Contract is we always delete input_values[i]. + return false; + } + // Note: session->session is not modified if Extend() fails, so + // we only set last_num_graph_nodes if it succeeds. 
+ session->last_num_graph_nodes = num_nodes; + } else { + session->graph->mu.unlock(); + } + } + return true; +} + +} // namespace tensorflow + +static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, + TF_Status* status) { + status->status = Status::OK(); + for (int i = 0; i < noutputs; ++i) { + c_outputs[i] = nullptr; + } +} + +static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, + std::vector>* input_pairs, + TF_Status* status) { + const int ninputs = input_pairs->size(); + for (int i = 0; i < ninputs; ++i) { + status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); + if (!status->status.ok()) return false; + } + return true; +} + +// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to +// result in a zero-sized tensor. +static TF_Tensor* EmptyTensor(TF_DataType dtype, + const tensorflow::TensorShape& shape) { + static char empty; + tensorflow::int64 nelems = 1; + std::vector dims; + for (int i = 0; i < shape.dims(); ++i) { + dims.push_back(shape.dim_size(i)); + nelems *= shape.dim_size(i); + } + CHECK_EQ(nelems, 0); + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + return TF_NewTensor( + dtype, reinterpret_cast(dims.data()), shape.dims(), + reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); +} + +static void TF_Run_Helper( + Session* session, const char* handle, const TF_Buffer* run_options, + // Input tensors + const std::vector>& input_pairs, + // Output tensors + const std::vector& output_tensor_names, TF_Tensor** c_outputs, + // Target nodes + const std::vector& target_oper_names, TF_Buffer* run_metadata, + TF_Status* status) { + const int noutputs = output_tensor_names.size(); + std::vector outputs(noutputs); + Status result; + + if (handle == nullptr) { + RunOptions run_options_proto; + if (run_options != nullptr && !run_options_proto.ParseFromArray( + run_options->data, run_options->length)) { + status->status = InvalidArgument("Unparseable RunOptions proto"); + return; + } + if (run_metadata != nullptr && run_metadata->data != nullptr) { + status->status = + InvalidArgument("Passing non-empty run_metadata is invalid."); + return; + } + + RunMetadata run_metadata_proto; + result = session->Run(run_options_proto, input_pairs, output_tensor_names, + target_oper_names, &outputs, &run_metadata_proto); + + // Serialize back to upstream client, who now owns the new buffer + if (run_metadata != nullptr) { + status->status = MessageToBuffer(run_metadata_proto, run_metadata); + if (!status->status.ok()) return; + } + } else { + // NOTE(zongheng): PRun does not support RunOptions yet. 
+ result = session->PRun(handle, input_pairs, output_tensor_names, &outputs); + } + if (!result.ok()) { + status->status = result; + return; + } + + // Store results in c_outputs[] + for (int i = 0; i < noutputs; ++i) { + const Tensor& src = outputs[i]; + if (!src.IsInitialized() || src.NumElements() == 0) { + c_outputs[i] = + EmptyTensor(static_cast(src.dtype()), src.shape()); + continue; + } + c_outputs[i] = TF_TensorFromTensor(src, &status->status); + if (!status->status.ok()) return; + } +} + +extern "C" { + +void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, + // Input tensors + const char** c_input_names, TF_Tensor** c_inputs, int ninputs, + // Output tensors + const char** c_output_names, TF_Tensor** c_outputs, int noutputs, + // Target nodes + const char** c_target_oper_names, int ntargets, + TF_Buffer* run_metadata, TF_Status* status) { + TF_Run_Setup(noutputs, c_outputs, status); + std::vector> input_pairs(ninputs); + if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; + for (int i = 0; i < ninputs; ++i) { + input_pairs[i].first = c_input_names[i]; + } + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = c_output_names[i]; + } + std::vector target_oper_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_oper_names[i] = c_target_oper_names[i]; + } + TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names, + c_outputs, target_oper_names, run_metadata, status); +} + +void TF_PRunSetup(TF_DeprecatedSession* s, + // Input names + const char** c_input_names, int ninputs, + // Output names + const char** c_output_names, int noutputs, + // Target nodes + const char** c_target_oper_names, int ntargets, + const char** handle, TF_Status* status) { + *handle = nullptr; + + std::vector input_names(ninputs); + std::vector output_names(noutputs); + std::vector target_oper_names(ntargets); + for (int i = 0; i < ninputs; ++i) { + input_names[i] = c_input_names[i]; + } + for (int i = 0; i < noutputs; ++i) { + output_names[i] = c_output_names[i]; + } + for (int i = 0; i < ntargets; ++i) { + target_oper_names[i] = c_target_oper_names[i]; + } + string new_handle; + status->status = s->session->PRunSetup(input_names, output_names, + target_oper_names, &new_handle); + if (status->status.ok()) { + char* buf = new char[new_handle.size() + 1]; + memcpy(buf, new_handle.c_str(), new_handle.size() + 1); + *handle = buf; + } +} + +void TF_PRun(TF_DeprecatedSession* s, const char* handle, + // Input tensors + const char** c_input_names, TF_Tensor** c_inputs, int ninputs, + // Output tensors + const char** c_output_names, TF_Tensor** c_outputs, int noutputs, + // Target nodes + const char** c_target_oper_names, int ntargets, + TF_Status* status) { + TF_Run_Setup(noutputs, c_outputs, status); + std::vector> input_pairs(ninputs); + if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; + for (int i = 0; i < ninputs; ++i) { + input_pairs[i].first = c_input_names[i]; + } + + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = c_output_names[i]; + } + std::vector target_oper_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_oper_names[i] = c_target_oper_names[i]; + } + TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names, + c_outputs, target_oper_names, nullptr, status); +} + +TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { + TF_Library* lib_handle = new TF_Library; + status->status = tensorflow::LoadLibrary( + 
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, + &lib_handle->op_list.length); + if (!status->status.ok()) { + delete lib_handle; + return nullptr; + } + return lib_handle; +} + +TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } + +void TF_DeleteLibraryHandle(TF_Library* lib_handle) { + if (lib_handle == nullptr) return; + tensorflow::port::Free(const_cast(lib_handle->op_list.data)); + delete lib_handle; +} + +TF_Buffer* TF_GetAllOpList() { + std::vector op_defs; + tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs); + tensorflow::OpList op_list; + for (const auto& op : op_defs) { + *(op_list.add_op()) = op; + } + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(op_list, ret)); + return ret; +} + +// -------------------------------------------------------------------------- +// ListDevices & SessionListDevices API + +void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; } + +TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) { + TF_DeviceList* response = new TF_DeviceList; + status->status = session->session->ListDevices(&response->response); + return response; +} + +TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session, + TF_Status* status) { + TF_DeviceList* response = new TF_DeviceList; + status->status = session->session->ListDevices(&response->response); + return response; +} + +int TF_DeviceListCount(const TF_DeviceList* list) { + return list->response.size(); +} + +#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \ + return_type method_name(const TF_DeviceList* list, const int index, \ + TF_Status* status) { \ + if (list == nullptr) { \ + status->status = InvalidArgument("list is null!"); \ + return err_val; \ + } \ + if (index < 0 || index >= list->response.size()) { \ + status->status = InvalidArgument("index out of bounds"); \ + return err_val; \ + } \ + status->status = Status::OK(); \ + return list->response[index].accessor; \ + } + +TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr); +TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(), + nullptr); +TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1); +TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0); + +#undef TF_DEVICELIST_METHOD + +} // end extern "C" + +// -------------------------------------------------------------------------- +// New Graph and Session API + +// Helper functions ----------------------------------------------------------- + +namespace tensorflow { + +TF_Operation* ToOperation(Node* node) { + return static_cast(static_cast(node)); +} + +TensorId ToTensorId(const TF_Output& output) { + return TensorId(output.oper->node.name(), output.index); +} + +} // namespace tensorflow + +namespace { + +string OutputName(const TF_Output& output) { + return StrCat(output.oper->node.name(), ":", output.index); +} + +const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, + const char* attr_name, + TF_Status* status) { + const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); + if (attr == nullptr) { + status->status = InvalidArgument("Operation '", oper->node.name(), + "' has no attr named '", attr_name, "'."); + } + return attr; +} + +} // namespace + +// Shape functions ----------------------------------------------------------- + +void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, + const int64_t* dims, const int num_dims, + 
TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + tensorflow::shape_inference::ShapeHandle new_shape = + tensorflow::ShapeHandleFromDims(ic, num_dims, dims); + status->status = graph->refiner.SetShape(node, output.index, new_shape); +} + +int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, + TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return -1; + } + + tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); + + // Unknown rank means the number of dimensions is -1. + if (!ic->RankKnown(shape)) { + return -1; + } + + return ic->Rank(shape); +} + +void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims, + int num_dims, TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + + tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); + + int rank = -1; + if (ic->RankKnown(shape)) { + rank = ic->Rank(shape); + } + + if (num_dims != rank) { + status->status = InvalidArgument("Expected rank is ", num_dims, + " but actual rank is ", rank); + return; + } + + if (num_dims == 0) { + // Output shape is a scalar. + return; + } + + // Rank is greater than 0, so fill in the values, if known, and + // -1 for unknown values. 
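The two shape queries above are typically used together: ask for the rank first, then size a buffer for the dimensions. A minimal sketch, assuming `graph` and an operation `op` are already available and skipping most error handling:

    TF_Status* s = TF_NewStatus();
    TF_Output out = {op, 0};
    int rank = TF_GraphGetTensorNumDims(graph, out, s);
    int64_t dims[16];  // a small fixed buffer is enough for this sketch
    if (TF_GetCode(s) == TF_OK && rank >= 0 && rank <= 16) {
      // num_dims must match the rank reported above.
      TF_GraphGetTensorShape(graph, out, dims, rank, s);
      // dims[i] now holds the size of dimension i, or -1 when unknown.
    }
    TF_DeleteStatus(s);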
+ for (int i = 0; i < num_dims; ++i) { + auto dim = ic->Dim(shape, i); + tensorflow::int64 value = -1; + if (ic->ValueKnown(dim)) { + value = ic->Value(dim); + } + dims[i] = value; + } +} + +// TF_OperationDescription functions ------------------------------------------ + +extern "C" { + +static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, + const char* op_type, + const char* oper_name) + TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + return new TF_OperationDescription(graph, op_type, oper_name); +} + +TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type, + const char* oper_name) { + mutex_lock l(graph->mu); + return TF_NewOperationLocked(graph, op_type, oper_name); +} + +void TF_SetDevice(TF_OperationDescription* desc, const char* device) { + desc->node_builder.Device(device); +} + +void TF_AddInput(TF_OperationDescription* desc, TF_Output input) { + desc->node_builder.Input(&input.oper->node, input.index); +} + +void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs, + int num_inputs) { + std::vector input_list; + input_list.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + input_list.emplace_back(&inputs[i].oper->node, inputs[i].index); + } + desc->node_builder.Input(input_list); +} + +void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { + desc->node_builder.ControlInput(&input->node); +} + +void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { + desc->colocation_constraints.emplace( + StrCat(tensorflow::kColocationGroupPrefix, op->node.name())); +} + +void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, + const void* value, size_t length) { + tensorflow::StringPiece s(static_cast(value), length); + desc->node_builder.Attr(attr_name, s); +} + +void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values) { + if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { + desc->colocation_constraints.clear(); + for (int i = 0; i < num_values; ++i) { + desc->colocation_constraints.emplace(static_cast(values[i]), + lengths[i]); + } + } else { + std::vector v; + v.reserve(num_values); + for (int i = 0; i < num_values; ++i) { + v.emplace_back(static_cast(values[i]), lengths[i]); + } + desc->node_builder.Attr(attr_name, v); + } +} + +void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, + int64_t value) { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + desc->node_builder.Attr(attr_name, static_cast(value)); +} + +void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name, + const int64_t* values, int num_values) { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + desc->node_builder.Attr( + attr_name, + ArraySlice( + reinterpret_cast(values), num_values)); +} + +void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name, + float value) { + desc->node_builder.Attr(attr_name, value); +} + +void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name, + const float* values, int num_values) { + desc->node_builder.Attr(attr_name, + ArraySlice(values, num_values)); +} + +void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, + unsigned char value) { + desc->node_builder.Attr(attr_name, static_cast(value)); +} + +void TF_SetAttrBoolList(TF_OperationDescription* desc, const 
char* attr_name, + const unsigned char* values, int num_values) { + std::unique_ptr b(new bool[num_values]); + for (int i = 0; i < num_values; ++i) { + b[i] = values[i]; + } + desc->node_builder.Attr(attr_name, + ArraySlice(b.get(), num_values)); +} + +void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, + TF_DataType value) { + desc->node_builder.Attr(attr_name, static_cast(value)); +} + +void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, + const TF_DataType* values, int num_values) { + desc->node_builder.Attr( + attr_name, ArraySlice( + reinterpret_cast(values), num_values)); +} + +void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name, + const char* placeholder) { + tensorflow::AttrValue attr_value; + attr_value.set_placeholder(placeholder); + desc->node_builder.Attr(attr_name, attr_value); +} + +void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, + const char* value, size_t length) { + tensorflow::NameAttrList func_name; + func_name.set_name(string(value, value + length)); + desc->node_builder.Attr(attr_name, func_name); +} + +void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, + const int64_t* dims, int num_dims) { + PartialTensorShape shape; + if (num_dims >= 0) { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + shape = PartialTensorShape(ArraySlice( + reinterpret_cast(dims), num_dims)); + } + desc->node_builder.Attr(attr_name, shape); +} + +void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name, + const int64_t* const* dims, const int* num_dims, + int num_shapes) { + std::vector shapes; + shapes.reserve(num_shapes); + for (int i = 0; i < num_shapes; ++i) { + if (num_dims[i] < 0) { + shapes.emplace_back(); + } else { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + shapes.emplace_back(ArraySlice( + reinterpret_cast(dims[i]), num_dims[i])); + } + } + desc->node_builder.Attr(attr_name, shapes); +} + +void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, + const char* attr_name, const void* proto, + size_t proto_len, TF_Status* status) { + // shape.ParseFromArray takes an int as length, this function takes size_t, + // make sure there is no information loss. 
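The attribute setters above are usually chained onto a freshly created TF_OperationDescription. A minimal sketch that describes a float Placeholder with a partially known shape, assuming `graph` already exists (the op name "Placeholder" and node name "input" are illustrative; TF_FinishOperation is defined further below):

    TF_Status* s = TF_NewStatus();
    TF_OperationDescription* desc =
        TF_NewOperation(graph, "Placeholder", "input");
    TF_SetAttrType(desc, "dtype", TF_FLOAT);
    int64_t dims[2] = {-1, 128};        // -1 marks an unknown dimension
    TF_SetAttrShape(desc, "shape", dims, 2);
    TF_Operation* op = TF_FinishOperation(desc, s);  // consumes `desc`
    TF_DeleteStatus(s);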
+ if (proto_len > std::numeric_limits::max()) { + status->status = InvalidArgument( + "proto_len (", proto_len, + " bytes) is too large to be parsed by the protocol buffer library"); + return; + } + TensorShapeProto shape; + if (shape.ParseFromArray(proto, static_cast(proto_len))) { + desc->node_builder.Attr(attr_name, shape); + status->status = Status::OK(); + } else { + status->status = InvalidArgument("Unparseable TensorShapeProto"); + } +} + +void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, + const char* attr_name, + const void* const* protos, + const size_t* proto_lens, int num_shapes, + TF_Status* status) { + std::vector shapes; + shapes.resize(num_shapes); + for (int i = 0; i < num_shapes; ++i) { + if (proto_lens[i] > std::numeric_limits::max()) { + status->status = InvalidArgument( + "length of element ", i, " in the list (", proto_lens[i], + " bytes) is too large to be parsed by the protocol buffer library"); + return; + } + if (!shapes[i].ParseFromArray(protos[i], static_cast(proto_lens[i]))) { + status->status = + InvalidArgument("Unparseable TensorShapeProto at index ", i); + return; + } + } + desc->node_builder.Attr(attr_name, shapes); + status->status = Status::OK(); +} + +void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, + TF_Tensor* value, TF_Status* status) { + Tensor t; + status->status = TF_TensorToTensor(value, &t); + if (status->status.ok()) desc->node_builder.Attr(attr_name, t); +} + +void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, + TF_Tensor* const* values, int num_values, + TF_Status* status) { + status->status = Status::OK(); + std::vector t; + t.reserve(num_values); + + for (int i = 0; i < num_values && status->status.ok(); ++i) { + Tensor v; + status->status = TF_TensorToTensor(values[i], &v); + t.emplace_back(v); + } + + if (status->status.ok()) desc->node_builder.Attr(attr_name, t); +} + +void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::AttrValue attr_value; + if (!attr_value.ParseFromArray(proto, proto_len)) { + status->status = InvalidArgument("Unparseable AttrValue proto"); + return; + } + + if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { + if (attr_value.value_case() != tensorflow::AttrValue::kList && + attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) { + status->status = + InvalidArgument("Expected \"list\" field for \"", + tensorflow::kColocationAttrName, "\" attribute"); + return; + } + desc->colocation_constraints.clear(); + for (const string& location : attr_value.list().s()) { + desc->colocation_constraints.insert(location); + } + } else { + desc->node_builder.Attr(attr_name, std::move(attr_value)); + } + + status->status = Status::OK(); +} + +static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, + TF_Status* status) + TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { + Node* ret = nullptr; + + if (desc->graph->name_map.count(desc->node_builder.node_name())) { + status->status = InvalidArgument("Duplicate node name in graph: '", + desc->node_builder.node_name(), "'"); + } else { + if (!desc->colocation_constraints.empty()) { + desc->node_builder.Attr( + tensorflow::kColocationAttrName, + std::vector(desc->colocation_constraints.begin(), + desc->colocation_constraints.end())); + } + status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret, + /*consume=*/true); + + if (status->status.ok()) { + // Run shape 
inference function for newly added node. + status->status = desc->graph->refiner.AddNode(ret); + } + if (status->status.ok()) { + // Add the node to the name-to-node mapping. + desc->graph->name_map[ret->name()] = ret; + } else if (ret != nullptr) { + desc->graph->graph.RemoveNode(ret); + ret = nullptr; + } + } + + delete desc; + + return ToOperation(ret); +} + +TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, + TF_Status* status) { + mutex_lock l(desc->graph->mu); + return TF_FinishOperationLocked(desc, status); +} + +// TF_Operation functions +// ---------------------------------------------------------- + +const char* TF_OperationName(TF_Operation* oper) { + return oper->node.name().c_str(); +} + +const char* TF_OperationOpType(TF_Operation* oper) { + return oper->node.type_string().c_str(); +} + +const char* TF_OperationDevice(TF_Operation* oper) { + return oper->node.requested_device().c_str(); +} + +int TF_OperationNumOutputs(TF_Operation* oper) { + return oper->node.num_outputs(); +} + +TF_DataType TF_OperationOutputType(TF_Output oper_out) { + return static_cast( + oper_out.oper->node.output_type(oper_out.index)); +} + +int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, + TF_Status* status) { + NameRangeMap name_ranges; + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); + if (!status->status.ok()) return -1; + auto iter = name_ranges.find(arg_name); + if (iter == name_ranges.end()) { + status->status = InvalidArgument("Output arg '", arg_name, "' not found"); + return -1; + } + return iter->second.second - iter->second.first; +} + +int TF_OperationNumInputs(TF_Operation* oper) { + return oper->node.num_inputs(); +} + +TF_DataType TF_OperationInputType(TF_Input oper_in) { + return static_cast(oper_in.oper->node.input_type(oper_in.index)); +} + +int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, + TF_Status* status) { + NameRangeMap name_ranges; + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); + if (!status->status.ok()) return -1; + auto iter = name_ranges.find(arg_name); + if (iter == name_ranges.end()) { + status->status = InvalidArgument("Input arg '", arg_name, "' not found"); + return -1; + } + return iter->second.second - iter->second.first; +} + +TF_Output TF_OperationInput(TF_Input oper_in) { + const tensorflow::Edge* edge; + Status s = oper_in.oper->node.input_edge(oper_in.index, &edge); + if (!s.ok()) { + return {nullptr, -1}; + } + + return {ToOperation(edge->src()), edge->src_output()}; +} + +void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs, + int max_inputs) { + for (auto* edge : oper->node.in_edges()) { + if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) { + inputs[edge->dst_input()] = {ToOperation(edge->src()), + edge->src_output()}; + } + } +} + +int TF_OperationOutputNumConsumers(TF_Output oper_out) { + int count = 0; + for (const auto* edge : oper_out.oper->node.out_edges()) { + if (edge->src_output() == oper_out.index) { + ++count; + } + } + return count; +} + +int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, + int max_consumers) { + int count = 0; + for (const auto* edge : oper_out.oper->node.out_edges()) { + if (edge->src_output() == oper_out.index) { + if (count < max_consumers) { + consumers[count] = {ToOperation(edge->dst()), edge->dst_input()}; + } + ++count; + } + } + return count; +} + +int TF_OperationNumControlInputs(TF_Operation* oper) { + int count = 
0; + for (const auto* edge : oper->node.in_edges()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { + ++count; + } + } + return count; +} + +int TF_OperationGetControlInputs(TF_Operation* oper, + TF_Operation** control_inputs, + int max_control_inputs) { + int count = 0; + for (const auto* edge : oper->node.in_edges()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { + if (count < max_control_inputs) { + control_inputs[count] = ToOperation(edge->src()); + } + ++count; + } + } + return count; +} + +int TF_OperationNumControlOutputs(TF_Operation* oper) { + int count = 0; + for (const auto* edge : oper->node.out_edges()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { + ++count; + } + } + return count; +} + +int TF_OperationGetControlOutputs(TF_Operation* oper, + TF_Operation** control_outputs, + int max_control_outputs) { + int count = 0; + for (const auto* edge : oper->node.out_edges()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { + if (count < max_control_outputs) { + control_outputs[count] = ToOperation(edge->dst()); + } + ++count; + } + } + return count; +} + +TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, + const char* attr_name, + TF_Status* status) { + TF_AttrMetadata metadata; + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return metadata; + switch (attr->value_case()) { +#define SINGLE_CASE(kK, attr_type, size_expr) \ + case tensorflow::AttrValue::kK: \ + metadata.is_list = 0; \ + metadata.list_size = -1; \ + metadata.type = attr_type; \ + metadata.total_size = size_expr; \ + break; + + SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); + SINGLE_CASE(kI, TF_ATTR_INT, -1); + SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); + SINGLE_CASE(kB, TF_ATTR_BOOL, -1); + SINGLE_CASE(kType, TF_ATTR_TYPE, -1); + SINGLE_CASE(kShape, TF_ATTR_SHAPE, + attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); + SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); +#undef SINGLE_CASE + + case tensorflow::AttrValue::kList: + metadata.is_list = 1; + metadata.list_size = 0; + metadata.total_size = -1; +#define LIST_CASE(field, attr_type, ...) \ + if (attr->list().field##_size() > 0) { \ + metadata.type = attr_type; \ + metadata.list_size = attr->list().field##_size(); \ + __VA_ARGS__; \ + break; \ + } + + LIST_CASE( + s, TF_ATTR_STRING, metadata.total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { metadata.total_size += attr->list().s(i).size(); }); + LIST_CASE(i, TF_ATTR_INT); + LIST_CASE(f, TF_ATTR_FLOAT); + LIST_CASE(b, TF_ATTR_BOOL); + LIST_CASE(type, TF_ATTR_TYPE); + LIST_CASE( + shape, TF_ATTR_SHAPE, metadata.total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); + }); + LIST_CASE(tensor, TF_ATTR_TENSOR); + LIST_CASE(tensor, TF_ATTR_FUNC); +#undef LIST_CASE + // All lists empty, determine the type from the OpDef. 
+ if (metadata.list_size == 0) { + for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { + const auto& a = oper->node.op_def().attr(i); + if (a.name() != attr_name) continue; + const string& typestr = a.type(); + if (typestr == "list(string)") { + metadata.type = TF_ATTR_STRING; + } else if (typestr == "list(int)") { + metadata.type = TF_ATTR_INT; + } else if (typestr == "list(float)") { + metadata.type = TF_ATTR_FLOAT; + } else if (typestr == "list(bool)") { + metadata.type = TF_ATTR_BOOL; + } else if (typestr == "list(type)") { + metadata.type = TF_ATTR_TYPE; + } else if (typestr == "list(shape)") { + metadata.type = TF_ATTR_SHAPE; + } else if (typestr == "list(tensor)") { + metadata.type = TF_ATTR_TENSOR; + } else if (typestr == "list(func)") { + metadata.type = TF_ATTR_FUNC; + } else { + status->status = InvalidArgument( + "Attribute '", attr_name, + "' has an empty value of an unrecognized type '", typestr, "'"); + return metadata; + } + } + } + break; + + case tensorflow::AttrValue::kPlaceholder: + metadata.is_list = 0; + metadata.list_size = -1; + metadata.type = TF_ATTR_PLACEHOLDER; + metadata.total_size = -1; + break; + + case tensorflow::AttrValue::kFunc: + metadata.is_list = 0; + metadata.list_size = -1; + metadata.type = TF_ATTR_FUNC; + metadata.total_size = -1; + break; + + case tensorflow::AttrValue::VALUE_NOT_SET: + status->status = + InvalidArgument("Attribute '", attr_name, "' has no value set"); + break; + } + return metadata; +} + +void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, + void* value, size_t max_length, + TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + if (attr->value_case() != tensorflow::AttrValue::kS) { + status->status = + InvalidArgument("Attribute '", attr_name, "' is not a string"); + return; + } + if (max_length <= 0) { + return; + } + const auto& s = attr->s(); + std::memcpy(value, s.data(), std::min(s.length(), max_length)); +} + +void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, + void** values, size_t* lengths, + int max_values, void* storage, + size_t storage_size, TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + if (attr->value_case() != tensorflow::AttrValue::kList) { + status->status = + InvalidArgument("Value for '", attr_name, "' is not a list"); + return; + } + const auto len = std::min(max_values, attr->list().s_size()); + char* p = static_cast(storage); + for (int i = 0; i < len; ++i) { + const string& s = attr->list().s(i); + values[i] = p; + lengths[i] = s.size(); + if ((p + s.size()) > (static_cast(storage) + storage_size)) { + status->status = InvalidArgument( + "Not enough storage to hold the requested list of strings"); + return; + } + memcpy(values[i], s.data(), s.size()); + p += s.size(); + } +} + +#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ + void func(TF_Operation* oper, const char* attr_name, c_type* value, \ + TF_Status* status) { \ + cpp_type v; \ + status->status = \ + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \ + *value = static_cast(v); \ + } \ + void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ + int max_values, TF_Status* status) { \ + const auto* attr = GetAttrValue(oper, attr_name, status); \ + if (!status->status.ok()) return; \ + if (attr->value_case() != tensorflow::AttrValue::kList) { \ + status->status = \ + InvalidArgument("Value for '", attr_name, "' is not 
a list."); \ + return; \ + } \ + const auto len = std::min(max_values, attr->list().list_field##_size()); \ + for (int i = 0; i < len; ++i) { \ + values[i] = static_cast(attr->list().list_field(i)); \ + } \ + } +DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); +DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); +DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b); +DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); +#undef DEFINE_GETATTR + +void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, + int64_t* value, int num_dims, TF_Status* status) { + PartialTensorShape shape; + status->status = + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); + if (!status->status.ok()) return; + auto len = std::min(shape.dims(), num_dims); + for (int i = 0; i < len; ++i) { + value[i] = shape.dim_size(i); + } +} + +void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, + int64_t** dims, int* num_dims, int num_shapes, + int64_t* storage, int storage_size, + TF_Status* status) { + std::vector shapes; + status->status = + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); + if (!status->status.ok()) return; + auto len = std::min(static_cast(shapes.size()), num_shapes); + int64_t* p = storage; + int storage_left = storage_size; + for (int i = 0; i < len; ++i) { + // shapes[i].dims() == -1 for shapes with an unknown rank. + int64_t n = shapes[i].dims(); + num_dims[i] = n; + dims[i] = p; + if (n < 0) { + continue; + } + if (storage_left < n) { + status->status = InvalidArgument( + "Not enough storage to hold the requested list of shapes"); + return; + } + storage_left -= n; + for (int j = 0; j < n; ++j, ++p) { + *p = shapes[i].dim_size(j); + } + } +} + +void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, + const char* attr_name, + TF_Buffer* value, TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + if (attr->value_case() != tensorflow::AttrValue::kShape) { + status->status = + InvalidArgument("Value for '", attr_name, "' is not a shape."); + return; + } + status->status = MessageToBuffer(attr->shape(), value); +} + +void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, + const char* attr_name, + TF_Buffer** values, int max_values, + TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + if (attr->value_case() != tensorflow::AttrValue::kList) { + status->status = + InvalidArgument("Value for '", attr_name, "' is not a list"); + return; + } + const auto len = std::min(max_values, attr->list().shape_size()); + for (int i = 0; i < len; ++i) { + values[i] = TF_NewBuffer(); + status->status = MessageToBuffer(attr->list().shape(i), values[i]); + if (!status->status.ok()) { + // Delete everything allocated to far, the operation has failed. 
+ for (int j = 0; j <= i; ++j) { + TF_DeleteBuffer(values[j]); + } + return; + } + } +} + +void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, + TF_Tensor** value, TF_Status* status) { + *value = nullptr; + Tensor t; + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); + if (!status->status.ok()) return; + *value = TF_TensorFromTensor(t, &status->status); +} + +void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, + TF_Tensor** values, int max_values, + TF_Status* status) { + std::vector ts; + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); + if (!status->status.ok()) return; + const auto len = std::min(max_values, static_cast(ts.size())); + for (int i = 0; i < len; ++i) { + values[i] = TF_TensorFromTensor(ts[i], &status->status); + } +} + +void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, + TF_Buffer* output_attr_value, + TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + status->status = MessageToBuffer(*attr, output_attr_value); +} + +void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, + TF_Status* status) { + status->status = MessageToBuffer(oper->node.def(), output_node_def); +} + +// TF_Graph functions --------------------------------------------------------- + +TF_Graph::TF_Graph() + : graph(tensorflow::OpRegistry::Global()), + refiner(graph.versions().producer(), graph.op_registry()), + delete_requested(false), + parent(nullptr), + parent_inputs(nullptr) { + // Tell the shape refiner to also run shape inference on functions. + refiner.set_function_library_for_shape_inference(&graph.flib_def()); +} + +TF_Graph* TF_NewGraph() { return new TF_Graph; } + +void TF_DeleteGraph(TF_Graph* g) { + if (g == nullptr) return; + g->mu.lock(); + g->delete_requested = true; + const bool del = g->sessions.empty(); + g->mu.unlock(); + if (del) delete g; +} + +TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) { + mutex_lock l(graph->mu); + auto iter = graph->name_map.find(oper_name); + if (iter == graph->name_map.end()) { + return nullptr; + } else { + return ToOperation(iter->second); + } +} + +TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) { + if (*pos == 0) { + // Advance past the first sentinel nodes in every graph (the source & sink). + *pos += 2; + } else { + // Advance to the next node. + *pos += 1; + } + + mutex_lock l(graph->mu); + while (*pos < static_cast(graph->graph.num_node_ids())) { + Node* node = graph->graph.FindNodeId(*pos); + // FindNodeId() returns nullptr for nodes that have been deleted. + // We aren't currently allowing nodes to be deleted, but it is safer + // to still check. + if (node != nullptr) return ToOperation(node); + *pos += 1; + } + + // No more nodes. 
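TF_GraphNextOperation above gives a simple way to enumerate a graph. A minimal sketch that prints every operation's name and type, assuming `graph` exists (printf comes from <stdio.h>):

    size_t pos = 0;
    TF_Operation* oper;
    while ((oper = TF_GraphNextOperation(graph, &pos)) != NULL) {
      // TF_OperationName/TF_OperationOpType are defined earlier in this file.
      printf("%s (%s)\n", TF_OperationName(oper), TF_OperationOpType(oper));
    }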
+ return nullptr; +} + +void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, + TF_Status* status) { + GraphDef def; + { + mutex_lock l(graph->mu); + graph->graph.ToGraphDef(&def); + } + status->status = MessageToBuffer(def, output_graph_def); +} + +void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, + TF_Buffer* output_op_def, TF_Status* status) { + const OpDef* op_def; + { + mutex_lock l(graph->mu); + status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); + if (!status->status.ok()) return; + } + status->status = MessageToBuffer(*op_def, output_op_def); +} + +void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def, + TF_Status* status) { + VersionDef versions; + { + mutex_lock l(graph->mu); + versions = graph->graph.versions(); + } + status->status = MessageToBuffer(versions, output_version_def); +} + +TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { + return new TF_ImportGraphDefOptions; +} +void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) { + delete opts; +} +void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, + const char* prefix) { + opts->opts.prefix = prefix; +} +void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts, + const char* device) { + opts->opts.default_device = device; +} + +void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_names) { + opts->opts.uniquify_names = uniquify_names; +} + +void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_prefix) { + opts->opts.uniquify_prefix = uniquify_prefix; +} + +void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, + const char* src_name, + int src_index, TF_Output dst) { + opts->tensor_id_data.push_back(src_name); + const string& src_name_str = opts->tensor_id_data.back(); + // We don't need to store dst's name in tensor_id_data, since `dst` must + // outlive the ImportGraphDef call. 
+ opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst); +} + +void TF_ImportGraphDefOptionsRemapControlDependency( + TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) { + opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] = + TensorId(dst->node.name(), tensorflow::Graph::kControlSlot); +} + +extern void TF_ImportGraphDefOptionsAddControlDependency( + TF_ImportGraphDefOptions* opts, TF_Operation* oper) { + opts->opts.control_dependencies.push_back(oper->node.name()); +} + +void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts, + const char* oper_name, int index) { + opts->tensor_id_data.push_back(oper_name); + const string& oper_name_str = opts->tensor_id_data.back(); + opts->opts.return_tensors.emplace_back(oper_name_str, index); +} + +int TF_ImportGraphDefOptionsNumReturnOutputs( + const TF_ImportGraphDefOptions* opts) { + return opts->opts.return_tensors.size(); +} + +void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts, + const char* oper_name) { + opts->opts.return_nodes.push_back(oper_name); +} + +int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts) { + return opts->opts.return_nodes.size(); +} + +void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results, + int* num_outputs, + TF_Output** outputs) { + *num_outputs = results->return_tensors.size(); + *outputs = results->return_tensors.data(); +} + +void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, + int* num_opers, + TF_Operation*** opers) { + *num_opers = results->return_nodes.size(); + *opers = results->return_nodes.data(); +} + +void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, + const char*** src_names, int** src_indexes) { + *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); + *src_names = results->missing_unused_key_names.data(); + *src_indexes = results->missing_unused_key_indexes.data(); +} + +void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { + delete results; +} + +static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, + const TF_ImportGraphDefOptions* opts, + TF_ImportGraphDefResults* tf_results, + TF_Status* status) + TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + const int last_node_id = graph->graph.num_node_ids(); + tensorflow::ImportGraphDefResults results; + status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, + &graph->refiner, &results); + if (!status->status.ok()) return; + + // Add new nodes to name_map + for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { + auto* node = graph->graph.FindNodeId(i); + if (node != nullptr) graph->name_map[node->name()] = node; + } + + // Populate return_tensors + DCHECK(tf_results->return_tensors.empty()); + tf_results->return_tensors.resize(results.return_tensors.size()); + for (int i = 0; i < results.return_tensors.size(); ++i) { + tf_results->return_tensors[i].oper = + ToOperation(results.return_tensors[i].first); + tf_results->return_tensors[i].index = results.return_tensors[i].second; + } + + // Populate return_nodes + DCHECK(tf_results->return_nodes.empty()); + tf_results->return_nodes.resize(results.return_nodes.size()); + for (int i = 0; i < results.return_nodes.size(); ++i) { + tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); + } + + // Populate missing unused map keys + 
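From the caller's side, the import path above is driven through TF_ImportGraphDefOptions. A minimal sketch that imports a serialized GraphDef under a name prefix; `graph` is assumed to exist, `graph_def_bytes`/`graph_def_len` are placeholders for the serialized proto, and error handling is omitted:

    TF_Status* s = TF_NewStatus();
    TF_Buffer* gdef = TF_NewBufferFromString(graph_def_bytes, graph_def_len);
    TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
    TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
    TF_GraphImportGraphDef(graph, gdef, opts, s);   // defined below
    TF_DeleteImportGraphDefOptions(opts);
    TF_DeleteBuffer(gdef);
    TF_DeleteStatus(s);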
DCHECK(tf_results->missing_unused_key_names.empty()); + DCHECK(tf_results->missing_unused_key_indexes.empty()); + DCHECK(tf_results->missing_unused_key_names_data.empty()); + + size_t size = results.missing_unused_input_map_keys.size(); + tf_results->missing_unused_key_names.resize(size); + tf_results->missing_unused_key_indexes.resize(size); + + for (int i = 0; i < size; ++i) { + TensorId id = results.missing_unused_input_map_keys[i]; + tf_results->missing_unused_key_names_data.emplace_back(id.first); + tf_results->missing_unused_key_names[i] = + tf_results->missing_unused_key_names_data.back().c_str(); + tf_results->missing_unused_key_indexes[i] = id.second; + } +} + +TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status) { + GraphDef def; + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { + status->status = InvalidArgument("Invalid GraphDef"); + return nullptr; + } + auto results = new TF_ImportGraphDefResults(); + mutex_lock l(graph->mu); + GraphImportGraphDefLocked(graph, def, options, results, status); + if (!status->status.ok()) { + delete results; + return nullptr; + } + return results; +} + +void TF_GraphImportGraphDefWithReturnOutputs( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, + int num_return_outputs, TF_Status* status) { + if (num_return_outputs != options->opts.return_tensors.size()) { + status->status = InvalidArgument("Expected 'num_return_outputs' to be ", + options->opts.return_tensors.size(), + ", got ", num_return_outputs); + return; + } + if (num_return_outputs > 0 && return_outputs == nullptr) { + status->status = InvalidArgument( + "'return_outputs' must be preallocated to length ", num_return_outputs); + return; + } + GraphDef def; + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { + status->status = InvalidArgument("Invalid GraphDef"); + return; + } + TF_ImportGraphDefResults results; + mutex_lock l(graph->mu); + GraphImportGraphDefLocked(graph, def, options, &results, status); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); + memcpy(return_outputs, results.return_tensors.data(), + num_return_outputs * sizeof(TF_Output)); +} + +void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, + TF_Status* status) { + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, options, status); + TF_DeleteImportGraphDefResults(results); +} + +// TF_Session functions ---------------------------------------------- + +TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g) + : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {} + +TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, + TF_Status* status) { + Session* session; + status->status = NewSession(opt->options, &session); + if (status->status.ok()) { + TF_Session* new_session = new TF_Session(session, graph); + if (graph != nullptr) { + mutex_lock l(graph->mu); + graph->sessions[new_session] = ""; + } + return new_session; + } else { + DCHECK_EQ(nullptr, session); + return nullptr; + } +} + +TF_Session* TF_LoadSessionFromSavedModel( + const TF_SessionOptions* session_options, const TF_Buffer* run_options, + const char* export_dir, const char* const* tags, int tags_len, + TF_Graph* graph, TF_Buffer* meta_graph_def, 
TF_Status* status) { +// TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring +// that the tensorflow/cc/saved_model:loader build target is mobile friendly. +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Loading a SavedModel is not supported on mobile. File a bug at " + "https://github.com/tensorflow/tensorflow/issues if this feature is " + "important to you"); + return nullptr; +#else + mutex_lock l(graph->mu); + if (!graph->name_map.empty()) { + status->status = InvalidArgument("Graph is non-empty."); + return nullptr; + } + + RunOptions run_options_proto; + if (run_options != nullptr && !run_options_proto.ParseFromArray( + run_options->data, run_options->length)) { + status->status = InvalidArgument("Unparseable RunOptions proto"); + return nullptr; + } + + std::unordered_set tag_set; + for (int i = 0; i < tags_len; i++) { + tag_set.insert(string(tags[i])); + } + + tensorflow::SavedModelBundle bundle; + status->status = + tensorflow::LoadSavedModel(session_options->options, run_options_proto, + export_dir, tag_set, &bundle); + if (!status->status.ok()) return nullptr; + + // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session + // extends using GraphDefs. The Graph instance is different, but equivalent + // to the one used to create the session. + // + // TODO(jhseu): When Session is modified to take Graphs instead of + // GraphDefs, return the Graph generated in LoadSavedModel(). + TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefResults results; + GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), + import_opts, &results, status); + TF_DeleteImportGraphDefOptions(import_opts); + if (!status->status.ok()) return nullptr; + + if (meta_graph_def != nullptr) { + status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def); + if (!status->status.ok()) return nullptr; + } + + TF_Session* session = new TF_Session(bundle.session.release(), graph); + + graph->sessions[session] = ""; + session->last_num_graph_nodes = graph->graph.num_node_ids(); + return session; +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +void TF_CloseSession(TF_Session* s, TF_Status* status) { + status->status = s->session->Close(); +} + +void TF_DeleteSession(TF_Session* s, TF_Status* status) { + status->status = Status::OK(); + if (s == nullptr) return; + TF_Graph* const graph = s->graph; + if (graph != nullptr) { + graph->mu.lock(); + graph->sessions.erase(s); + const bool del = graph->delete_requested && graph->sessions.empty(); + graph->mu.unlock(); + if (del) delete graph; + } + delete s->session; + delete s; +} + +void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, + const TF_Output* inputs, TF_Tensor* const* input_values, + int ninputs, const TF_Output* outputs, + TF_Tensor** output_values, int noutputs, + const TF_Operation* const* target_opers, int ntargets, + TF_Buffer* run_metadata, TF_Status* status) { + // TODO(josh11b,mrry): Change Session to be able to use a Graph* + // directly, instead of requiring us to serialize to a GraphDef and + // call Session::Extend(). + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; + } + + TF_Run_Setup(noutputs, output_values, status); + + // Convert from TF_Output and TF_Tensor to a string and Tensor. 
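A minimal TF_SessionRun call built on the conversion steps below might look like this. It assumes `session`, a feed operation `input_op`, a fetch operation `output_op`, and a populated TF_Tensor* `input_tensor` already exist; error handling is reduced to a single status check:

    TF_Status* s = TF_NewStatus();
    TF_Output feeds[1] = {{input_op, 0}};
    TF_Tensor* feed_values[1] = {input_tensor};
    TF_Output fetches[1] = {{output_op, 0}};
    TF_Tensor* fetch_values[1] = {NULL};
    TF_SessionRun(session, /*run_options=*/NULL,
                  feeds, feed_values, 1,
                  fetches, fetch_values, 1,
                  /*target_opers=*/NULL, 0,
                  /*run_metadata=*/NULL, s);
    if (TF_GetCode(s) == TF_OK) {
      // fetch_values[0] now owns the result; release it when done.
      TF_DeleteTensor(fetch_values[0]);
    }
    TF_DeleteStatus(s);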
+ std::vector> input_pairs(ninputs); + if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; + for (int i = 0; i < ninputs; ++i) { + input_pairs[i].first = OutputName(inputs[i]); + } + + // Convert from TF_Output to string names. + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = OutputName(outputs[i]); + } + + // Convert from TF_Operation* to string names. + std::vector target_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_names[i] = target_opers[i]->node.name(); + } + + // Actually run. + TF_Run_Helper(session->session, nullptr, run_options, input_pairs, + output_names, output_values, target_names, run_metadata, + status); +} + +void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, + int ninputs, const TF_Output* outputs, int noutputs, + const TF_Operation* const* target_opers, int ntargets, + const char** handle, TF_Status* status) { + *handle = nullptr; + + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; + } + + std::vector input_names(ninputs); + for (int i = 0; i < ninputs; ++i) { + input_names[i] = OutputName(inputs[i]); + } + + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = OutputName(outputs[i]); + } + + std::vector target_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_names[i] = target_opers[i]->node.name(); + } + + string new_handle; + status->status = session->session->PRunSetup(input_names, output_names, + target_names, &new_handle); + if (status->status.ok()) { + char* buf = new char[new_handle.size() + 1]; + memcpy(buf, new_handle.c_str(), new_handle.size() + 1); + *handle = buf; + } +} + +void TF_DeletePRunHandle(const char* handle) { + delete[] handle; + // TODO(suharshs): Free up any resources held by the partial run state. +} + +void TF_SessionPRun(TF_Session* session, const char* handle, + const TF_Output* inputs, TF_Tensor* const* input_values, + int ninputs, const TF_Output* outputs, + TF_Tensor** output_values, int noutputs, + const TF_Operation* const* target_opers, int ntargets, + TF_Status* status) { + // TODO(josh11b,mrry): Change Session to be able to use a Graph* + // directly, instead of requiring us to serialize to a GraphDef and + // call Session::Extend(). + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; + } + + TF_Run_Setup(noutputs, output_values, status); + + // Convert from TF_Output and TF_Tensor to a string and Tensor. + std::vector> input_pairs(ninputs); + if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; + for (int i = 0; i < ninputs; ++i) { + input_pairs[i].first = OutputName(inputs[i]); + } + + // Convert from TF_Output to string names. + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = OutputName(outputs[i]); + } + + // Convert from TF_Operation* to string names. 
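The partial-run trio (TF_SessionPRunSetup, TF_SessionPRun, TF_DeletePRunHandle) composes as follows. A minimal sketch, assuming `session`, `input_op`, `output_op`, and a ready TF_Tensor* `value` already exist:

    TF_Status* s = TF_NewStatus();
    TF_Output feed = {input_op, 0};
    TF_Output fetch = {output_op, 0};
    const char* handle = NULL;
    TF_SessionPRunSetup(session, &feed, 1, &fetch, 1, NULL, 0, &handle, s);
    if (TF_GetCode(s) == TF_OK) {
      TF_Tensor* result = NULL;
      TF_SessionPRun(session, handle, &feed, &value, 1, &fetch, &result, 1,
                     NULL, 0, s);
      if (TF_GetCode(s) == TF_OK) TF_DeleteTensor(result);
      TF_DeletePRunHandle(handle);
    }
    TF_DeleteStatus(s);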
+ std::vector target_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_names[i] = target_opers[i]->node.name(); + } + + TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names, + output_values, target_names, nullptr, status); +} + +unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, + TF_Tensor** result, TF_Status* status) { + *result = nullptr; + mutex_lock l(graph->mu); + OutputTensor tensor(&output.oper->node, output.index); + bool evaluated; + Tensor result_tensor; + status->status = EvaluateConstantTensor( + tensor, graph->refiner, *graph->graph.op_registry(), + graph->graph.versions().producer(), &evaluated, &result_tensor); + if (evaluated) { + DCHECK(status->status.ok()); + *result = TF_TensorFromTensor(result_tensor, &status->status); + if (!status->status.ok()) evaluated = false; + } + return evaluated; +} + +TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { + tensorflow::OpList op_list; + if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { + status->status = InvalidArgument("Unparseable OpList"); + return nullptr; + } + status->status = Status::OK(); + return new TF_ApiDefMap(op_list); +} + +void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } + +void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, + size_t text_len, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported on mobile."); +#else + mutex_lock l(api_def_map->lock); + if (api_def_map->update_docs_called) { + status->status = FailedPrecondition( + "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " + "called."); + return; + } + string api_def_text(text, text_len); + status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, + size_t name_len, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported on mobile."); + return nullptr; +#else + mutex_lock l(api_def_map->lock); + if (!api_def_map->update_docs_called) { + api_def_map->api_def_map.UpdateDocs(); + api_def_map->update_docs_called = true; + } + string name_str(name, name_len); + const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); + if (api_def == nullptr) { + return nullptr; + } + + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(*api_def, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { + tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(kernel_list, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +} + +TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { + tensorflow::KernelList kernel_list = + tensorflow::GetRegisteredKernelsForOp(name); + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(kernel_list, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +} + +// TF_Server functions 
---------------------------------------------- + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +TF_Server::TF_Server(std::unique_ptr server) + : target(server->target()), server(std::move(server)) {} +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + +TF_Server* TF_NewServer(const void* proto, size_t proto_len, + TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported on mobile"); + return nullptr; +#else + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, static_cast(proto_len))) { + status->status = InvalidArgument( + "Could not parse provided bytes into a ServerDef protocol buffer"); + return nullptr; + } + + std::unique_ptr out_server; + status->status = tensorflow::NewServer(server_def, &out_server); + if (!status->status.ok()) return nullptr; + + return new TF_Server(std::move(out_server)); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +void TF_ServerStart(TF_Server* server, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported on mobile"); +#else + status->status = server->server->Start(); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +void TF_ServerStop(TF_Server* server, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported on mobile"); +#else + status->status = server->server->Stop(); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +void TF_ServerJoin(TF_Server* server, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported on mobile"); +#else + status->status = server->server->Join(); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +const char* TF_ServerTarget(TF_Server* server) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + return nullptr; +#else + return server->target.c_str(); +#endif +} + +void TF_DeleteServer(TF_Server* server) { +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + delete server; +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +} + +void TF_RegisterLogListener(void (*listener)(const char*)) { +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + tensorflow::logging::RegisterListener(listener); +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +} + +} // end extern "C" diff --git a/tensorflow/c/c_core_api.h b/tensorflow/c/c_core_api.h new file mode 100644 index 00000000000..d3b5447b717 --- /dev/null +++ b/tensorflow/c/c_core_api.h @@ -0,0 +1,1456 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_C_CORE_API_H_
+#define TENSORFLOW_C_C_CORE_API_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "tensorflow/c/tf_attrtype.h"
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_tensor.h"
+
+// --------------------------------------------------------------------------
+// C API for TensorFlow.
+//
+// The API leans towards simplicity and uniformity instead of convenience
+// since most usage will be by language specific wrappers.
+//
+// Conventions:
+// * We use the prefix TF_ for everything in the API.
+// * Objects are always passed around as pointers to opaque structs
+//   and these structs are allocated/deallocated via the API.
+// * TF_Status holds error information. It is an object type
+//   and therefore is passed around as a pointer to an opaque
+//   struct as mentioned above.
+// * Every call that has a TF_Status* argument clears it on success
+//   and fills it with error info on failure.
+// * unsigned char is used for booleans (instead of the 'bool' type).
+//   In C++ bool is a keyword while in C99 bool is a macro defined
+//   in stdbool.h. It is possible for the two to be inconsistent.
+//   For example, neither the C99 nor the C++11 standard force a byte
+//   size on the bool type, so the macro defined in stdbool.h could
+//   be inconsistent with the bool keyword in C++. Thus, the use
+//   of stdbool.h is avoided and unsigned char is used instead.
+// * size_t is used to represent byte sizes of objects that are
+//   materialized in the address space of the calling process.
+// * int is used as an index into arrays.
+// * Deletion functions are safe to call on nullptr.
+//
+// Questions left to address:
+// * Might at some point need a way for callers to provide their own Env.
+// * Maybe add TF_TensorShape that encapsulates dimension info.
+//
+// Design decisions made:
+// * Backing store for tensor memory has an associated deallocation
+//   function. This deallocation function will point to client code
+//   for tensors populated by the client. So the client can do things
+//   like shadowing a numpy array.
+// * We do not provide TF_OK since it is not strictly necessary and we
+//   are not optimizing for convenience.
+// * We make assumption that one session has one graph. This should be
+//   fine since we have the ability to run sub-graphs.
+// * We could allow NULL for some arguments (e.g., NULL options arg).
+//   However since convenience is not a primary goal, we don't do this.
+// * Devices are not in this API. Instead, they are created/used internally
+//   and the API just provides high level controls over the number of
+//   devices of each type.
+
+// Macro to control visibility of exported symbols in the shared library (.so,
+// .dylib, .dll).
+// This duplicates the TF_EXPORT macro definition in
+// tensorflow/core/platform/macros.h in order to keep this .h file independent
+// of any other includes.
+#ifdef SWIG
+#define TF_CAPI_EXPORT
+#else
+#if defined(_WIN32)
+#ifdef TF_COMPILE_LIBRARY
+#define TF_CAPI_EXPORT __declspec(dllexport)
+#else
+#define TF_CAPI_EXPORT __declspec(dllimport)
+#endif  // TF_COMPILE_LIBRARY
+#else
+#define TF_CAPI_EXPORT __attribute__((visibility("default")))
+#endif  // _WIN32
+#endif  // SWIG
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// --------------------------------------------------------------------------
+// TF_Version returns a string describing version information of the
+// TensorFlow library. 
TensorFlow using semantic versioning. +TF_CAPI_EXPORT extern const char* TF_Version(void); + +// -------------------------------------------------------------------------- +// TF_Buffer holds a pointer to a block of data and its associated length. +// Typically, the data consists of a serialized protocol buffer, but other data +// may also be held in a buffer. +// +// By default, TF_Buffer itself does not do any memory management of the +// pointed-to block. If need be, users of this struct should specify how to +// deallocate the block by setting the `data_deallocator` function pointer. +typedef struct TF_Buffer { + const void* data; + size_t length; + void (*data_deallocator)(void* data, size_t length); +} TF_Buffer; + +// Makes a copy of the input and sets an appropriate deallocator. Useful for +// passing in read-only, input protobufs. +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, + size_t proto_len); + +// Useful for passing *out* a protobuf. +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); + +TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); + +TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); + +// -------------------------------------------------------------------------- +// TF_SessionOptions holds options that can be passed during session creation. +typedef struct TF_SessionOptions TF_SessionOptions; + +// Return a new options object. +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); + +// Set the target in TF_SessionOptions.options. +// target can be empty, a single entry, or a comma separated list of entries. +// Each entry is in one of the following formats : +// "local" +// ip:port +// host:port +TF_CAPI_EXPORT extern void TF_SetTarget(TF_SessionOptions* options, + const char* target); + +// Set the config in TF_SessionOptions.options. +// config should be a serialized tensorflow.ConfigProto proto. +// If config was not parsed successfully as a ConfigProto, record the +// error information in *status. +TF_CAPI_EXPORT extern void TF_SetConfig(TF_SessionOptions* options, + const void* proto, size_t proto_len, + TF_Status* status); + +// Destroy an options object. +TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); + +// TODO(jeff,sanjay): +// - export functions to set Config fields + +// -------------------------------------------------------------------------- +// The new graph construction API, still under development. + +// Represents a computation graph. Graphs may be shared between sessions. +// Graphs are thread-safe when used as directed below. +typedef struct TF_Graph TF_Graph; + +// Return a new graph object. +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); + +// Destroy an options object. Graph will be deleted once no more +// TFSession's are referencing it. +TF_CAPI_EXPORT extern void TF_DeleteGraph(TF_Graph*); + +// Operation being built. The underlying graph must outlive this. +typedef struct TF_OperationDescription TF_OperationDescription; + +// Operation that has been added to the graph. Valid until the graph is +// deleted -- in particular adding a new operation to the graph does not +// invalidate old TF_Operation* pointers. +typedef struct TF_Operation TF_Operation; + +// Represents a specific input of an operation. +typedef struct TF_Input { + TF_Operation* oper; + int index; // The index of the input within oper. +} TF_Input; + +// Represents a specific output of an operation. 
+typedef struct TF_Output { + TF_Operation* oper; + int index; // The index of the output within oper. +} TF_Output; + +// TF_Function is a grouping of operations with defined inputs and outputs. +// Once created and added to graphs, functions can be invoked by creating an +// operation whose operation type matches the function name. +typedef struct TF_Function TF_Function; + +// Function definition options. TODO(iga): Define and implement +typedef struct TF_FunctionOptions TF_FunctionOptions; + +// Sets the shape of the Tensor referenced by `output` in `graph` to +// the shape described by `dims` and `num_dims`. +// +// If the number of dimensions is unknown, `num_dims` must be set to +// -1 and `dims` can be null. If a dimension is unknown, the +// corresponding entry in the `dims` array must be -1. +// +// This does not overwrite the existing shape associated with `output`, +// but merges the input shape with the existing shape. For example, +// setting a shape of [-1, 2] with an existing shape [2, -1] would set +// a final shape of [2, 2] based on shape merging semantics. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +// * An invalid shape is being set (e.g., the shape being set +// is incompatible with the existing shape). +TF_CAPI_EXPORT extern void TF_GraphSetTensorShape(TF_Graph* graph, + TF_Output output, + const int64_t* dims, + const int num_dims, + TF_Status* status); + +// Returns the number of dimensions of the Tensor referenced by `output` +// in `graph`. +// +// If the number of dimensions in the shape is unknown, returns -1. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +TF_CAPI_EXPORT extern int TF_GraphGetTensorNumDims(TF_Graph* graph, + TF_Output output, + TF_Status* status); + +// Returns the shape of the Tensor referenced by `output` in `graph` +// into `dims`. `dims` must be an array large enough to hold `num_dims` +// entries (e.g., the return value of TF_GraphGetTensorNumDims). +// +// If the number of dimensions in the shape is unknown or the shape is +// a scalar, `dims` will remain untouched. Otherwise, each element of +// `dims` will be set corresponding to the size of the dimension. An +// unknown dimension is represented by `-1`. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +// * `num_dims` does not match the actual number of dimensions. +TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, + TF_Output output, + int64_t* dims, int num_dims, + TF_Status* status); + +// Operation will only be added to *graph when TF_FinishOperation() is +// called (assuming TF_FinishOperation() does not return an error). +// *graph must not be deleted until after TF_FinishOperation() is +// called. +TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperation( + TF_Graph* graph, const char* op_type, const char* oper_name); + +// Specify the device for `desc`. Defaults to empty, meaning unconstrained. +TF_CAPI_EXPORT extern void TF_SetDevice(TF_OperationDescription* desc, + const char* device); + +// The calls to TF_AddInput and TF_AddInputList must match (in number, +// order, and type) the op declaration. For example, the "Concat" op +// has registration: +// REGISTER_OP("Concat") +// .Input("concat_dim: int32") +// .Input("values: N * T") +// .Output("output: T") +// .Attr("N: int >= 2") +// .Attr("T: type"); +// that defines two inputs, "concat_dim" and "values" (in that order). 
+// You must use TF_AddInput() for the first input (since it takes a +// single tensor), and TF_AddInputList() for the second input (since +// it takes a list, even if you were to pass a list with a single +// tensor), as in: +// TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "c"); +// TF_Output concat_dim_input = {...}; +// TF_AddInput(desc, concat_dim_input); +// TF_Output values_inputs[5] = {{...}, ..., {...}}; +// TF_AddInputList(desc, values_inputs, 5); + +// For inputs that take a single tensor. +TF_CAPI_EXPORT extern void TF_AddInput(TF_OperationDescription* desc, + TF_Output input); + +// For inputs that take a list of tensors. +// inputs must point to TF_Output[num_inputs]. +TF_CAPI_EXPORT extern void TF_AddInputList(TF_OperationDescription* desc, + const TF_Output* inputs, + int num_inputs); + +// Call once per control input to `desc`. +TF_CAPI_EXPORT extern void TF_AddControlInput(TF_OperationDescription* desc, + TF_Operation* input); + +// Request that `desc` be co-located on the device where `op` +// is placed. +// +// Use of this is discouraged since the implementation of device placement is +// subject to change. Primarily intended for internal libraries +TF_CAPI_EXPORT extern void TF_ColocateWith(TF_OperationDescription* desc, + TF_Operation* op); + +// Call some TF_SetAttr*() function for every attr that is not +// inferred from an input and doesn't have a default value you wish to +// keep. + +// `value` must point to a string of length `length` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrString(TF_OperationDescription* desc, + const char* attr_name, + const void* value, size_t length); +// `values` and `lengths` each must have lengths `num_values`. +// `values[i]` must point to a string of length `lengths[i]` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrStringList(TF_OperationDescription* desc, + const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrInt(TF_OperationDescription* desc, + const char* attr_name, int64_t value); +TF_CAPI_EXPORT extern void TF_SetAttrIntList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrFloat(TF_OperationDescription* desc, + const char* attr_name, float value); +TF_CAPI_EXPORT extern void TF_SetAttrFloatList(TF_OperationDescription* desc, + const char* attr_name, + const float* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrBool(TF_OperationDescription* desc, + const char* attr_name, + unsigned char value); +TF_CAPI_EXPORT extern void TF_SetAttrBoolList(TF_OperationDescription* desc, + const char* attr_name, + const unsigned char* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrType(TF_OperationDescription* desc, + const char* attr_name, + TF_DataType value); +TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, + const char* attr_name, + const TF_DataType* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc, + const char* attr_name, + const char* placeholder); + +// Set a 'func' attribute to the specified name. +// `value` must point to a string of length `length` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, + const char* attr_name, + const char* value, size_t length); + +// Set `num_dims` to -1 to represent "unknown rank". Otherwise, +// `dims` points to an array of length `num_dims`. 
`dims[i]` must be +// >= -1, with -1 meaning "unknown dimension". +TF_CAPI_EXPORT extern void TF_SetAttrShape(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* dims, int num_dims); +// `dims` and `num_dims` must point to arrays of length `num_shapes`. +// Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise, +// `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]` +// must be >= -1, with -1 meaning "unknown dimension". +TF_CAPI_EXPORT extern void TF_SetAttrShapeList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* const* dims, + const int* num_dims, + int num_shapes); +// `proto` must point to an array of `proto_len` bytes representing a +// binary-serialized TensorShapeProto. +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProto( + TF_OperationDescription* desc, const char* attr_name, const void* proto, + size_t proto_len, TF_Status* status); +// `protos` and `proto_lens` must point to arrays of length `num_shapes`. +// `protos[i]` must point to an array of `proto_lens[i]` bytes +// representing a binary-serialized TensorShapeProto. +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProtoList( + TF_OperationDescription* desc, const char* attr_name, + const void* const* protos, const size_t* proto_lens, int num_shapes, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_SetAttrTensor(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* value, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorList(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* const* values, + int num_values, + TF_Status* status); + +// `proto` should point to a sequence of bytes of length `proto_len` +// representing a binary serialization of an AttrValue protocol +// buffer. +TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// If this function succeeds: +// * *status is set to an OK value, +// * a TF_Operation is added to the graph, +// * a non-null value pointing to the added operation is returned -- +// this value is valid until the underlying graph is deleted. +// Otherwise: +// * *status is set to a non-OK value, +// * the graph is not modified, +// * a null value is returned. +// In either case, it deletes `desc`. +TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperation( + TF_OperationDescription* desc, TF_Status* status); + +// TF_Operation functions. Operations are immutable once created, so +// these are all query functions. 
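Before the query functions below, a minimal sketch of how the construction calls above (TF_NewOperation, TF_SetAttrType, TF_SetAttrTensor, TF_FinishOperation) fit together. It is not part of the patch: it assumes the program links against libtensorflow and includes the public tensorflow/c/c_api.h (which declares the same symbols as this header); the node name "my_const" is arbitrary and error handling is reduced to a single check.

#include <stdio.h>
#include <string.h>
#include "tensorflow/c/c_api.h"  // assumed public header declaring these symbols

int main(void) {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // A scalar float tensor to use as the value of a Const node.
  float value = 2.0f;
  TF_Tensor* t =
      TF_AllocateTensor(TF_FLOAT, /*dims=*/NULL, /*num_dims=*/0, sizeof(float));
  memcpy(TF_TensorData(t), &value, sizeof(float));

  // Describe the node, set its attrs, then finalize it into the graph.
  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", "my_const");
  TF_SetAttrType(desc, "dtype", TF_FLOAT);
  TF_SetAttrTensor(desc, "value", t, status);
  TF_Operation* op = TF_FinishOperation(desc, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "TF_FinishOperation: %s\n", TF_Message(status));
  } else {
    printf("added %s (%s)\n", TF_OperationName(op), TF_OperationOpType(op));
  }

  TF_DeleteTensor(t);  // the attr keeps its own copy of the tensor value
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
  return 0;
}

Note that `desc` is consumed by TF_FinishOperation in both the success and failure cases, as documented above.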
+ +TF_CAPI_EXPORT extern const char* TF_OperationName(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationOpType(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationDevice(TF_Operation* oper); + +TF_CAPI_EXPORT extern int TF_OperationNumOutputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationOutputType(TF_Output oper_out); +TF_CAPI_EXPORT extern int TF_OperationOutputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); + +TF_CAPI_EXPORT extern int TF_OperationNumInputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationInputType(TF_Input oper_in); +TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); + +// In this code: +// TF_Output producer = TF_OperationInput(consumer); +// There is an edge from producer.oper's output (given by +// producer.index) to consumer.oper's input (given by consumer.index). +TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in); + +// Get list of all inputs of a specific operation. `inputs` must point to +// an array of length at least `max_inputs` (ideally set to +// TF_OperationNumInputs(oper)). Beware that a concurrent +// modification of the graph can increase the number of inputs of +// an operation. +TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper, + TF_Output* inputs, + int max_inputs); + +// Get the number of current consumers of a specific output of an +// operation. Note that this number can change when new operations +// are added to the graph. +TF_CAPI_EXPORT extern int TF_OperationOutputNumConsumers(TF_Output oper_out); + +// Get list of all current consumers of a specific output of an +// operation. `consumers` must point to an array of length at least +// `max_consumers` (ideally set to +// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent +// modification of the graph can increase the number of consumers of +// an operation. Returns the number of output consumers (should match +// TF_OperationOutputNumConsumers(oper_out)). +TF_CAPI_EXPORT extern int TF_OperationOutputConsumers(TF_Output oper_out, + TF_Input* consumers, + int max_consumers); + +// Get the number of control inputs to an operation. +TF_CAPI_EXPORT extern int TF_OperationNumControlInputs(TF_Operation* oper); + +// Get list of all control inputs to an operation. `control_inputs` must +// point to an array of length `max_control_inputs` (ideally set to +// TF_OperationNumControlInputs(oper)). Returns the number of control +// inputs (should match TF_OperationNumControlInputs(oper)). +TF_CAPI_EXPORT extern int TF_OperationGetControlInputs( + TF_Operation* oper, TF_Operation** control_inputs, int max_control_inputs); + +// Get the number of operations that have `*oper` as a control input. +// Note that this number can change when new operations are added to +// the graph. +TF_CAPI_EXPORT extern int TF_OperationNumControlOutputs(TF_Operation* oper); + +// Get the list of operations that have `*oper` as a control input. +// `control_outputs` must point to an array of length at least +// `max_control_outputs` (ideally set to +// TF_OperationNumControlOutputs(oper)). Beware that a concurrent +// modification of the graph can increase the number of control +// outputs. Returns the number of control outputs (should match +// TF_OperationNumControlOutputs(oper)). 
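As a companion sketch for the query functions in this section (the declaration of TF_OperationGetControlOutputs follows below), the fragment here builds a two-node graph and then walks its edges with TF_OperationNumInputs, TF_OperationOutputNumConsumers and TF_OperationOutputConsumers. The node names "x" and "y" are arbitrary, status checks are omitted for brevity, and the same c_api.h / libtensorflow assumptions apply.

#include <stdio.h>
#include <stdlib.h>
#include "tensorflow/c/c_api.h"  // assumed public header declaring these symbols

int main(void) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* g = TF_NewGraph();

  // x = Placeholder(float); y = Neg(x)
  TF_OperationDescription* d = TF_NewOperation(g, "Placeholder", "x");
  TF_SetAttrType(d, "dtype", TF_FLOAT);
  TF_Operation* x = TF_FinishOperation(d, s);
  d = TF_NewOperation(g, "Neg", "y");
  TF_Output x0 = {x, 0};
  TF_AddInput(d, x0);
  TF_Operation* y = TF_FinishOperation(d, s);

  // Walk the edges that were just created.
  printf("%s has %d input(s) and %d output(s)\n", TF_OperationName(y),
         TF_OperationNumInputs(y), TF_OperationNumOutputs(y));
  int n = TF_OperationOutputNumConsumers(x0);
  TF_Input* consumers = (TF_Input*)malloc((size_t)n * sizeof(TF_Input));
  TF_OperationOutputConsumers(x0, consumers, n);
  if (n > 0) printf("x:0 feeds %s\n", TF_OperationName(consumers[0].oper));
  free(consumers);

  TF_DeleteGraph(g);
  TF_DeleteStatus(s);
  return 0;
}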
+TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( + TF_Operation* oper, TF_Operation** control_outputs, + int max_control_outputs); + +// TF_AttrMetadata describes the value of an attribute on an operation. +typedef struct TF_AttrMetadata { + // A boolean: 1 if the attribute value is a list, 0 otherwise. + unsigned char is_list; + + // Length of the list if is_list is true. Undefined otherwise. + int64_t list_size; + + // Type of elements of the list if is_list != 0. + // Type of the single value stored in the attribute if is_list == 0. + TF_AttrType type; + + // Total size the attribute value. + // The units of total_size depend on is_list and type. + // (1) If type == TF_ATTR_STRING and is_list == 0 + // then total_size is the byte size of the string + // valued attribute. + // (2) If type == TF_ATTR_STRING and is_list == 1 + // then total_size is the cumulative byte size + // of all the strings in the list. + // (3) If type == TF_ATTR_SHAPE and is_list == 0 + // then total_size is the number of dimensions + // of the shape valued attribute, or -1 + // if its rank is unknown. + // (4) If type == TF_ATTR_SHAPE and is_list == 1 + // then total_size is the cumulative number + // of dimensions of all shapes in the list. + // (5) Otherwise, total_size is undefined. + int64_t total_size; +} TF_AttrMetadata; + +// Returns metadata about the value of the attribute `attr_name` of `oper`. +TF_CAPI_EXPORT extern TF_AttrMetadata TF_OperationGetAttrMetadata( + TF_Operation* oper, const char* attr_name, TF_Status* status); + +// Fills in `value` with the value of the attribute `attr_name`. `value` must +// point to an array of length at least `max_length` (ideally set to +// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrString(TF_Operation* oper, + const char* attr_name, + void* value, + size_t max_length, + TF_Status* status); + +// Get the list of strings in the value of the attribute `attr_name`. Fills in +// `values` and `lengths`, each of which must point to an array of length at +// least `max_values`. +// +// The elements of values will point to addresses in `storage` which must be at +// least `storage_size` bytes in length. Ideally, max_values would be set to +// TF_AttrMetadata.list_size and `storage` would be at least +// TF_AttrMetadata.total_size, obtained from TF_OperationGetAttrMetadata(oper, +// attr_name). +// +// Fails if storage_size is too small to hold the requested number of strings. +TF_CAPI_EXPORT extern void TF_OperationGetAttrStringList( + TF_Operation* oper, const char* attr_name, void** values, size_t* lengths, + int max_values, void* storage, size_t storage_size, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrInt(TF_Operation* oper, + const char* attr_name, + int64_t* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrIntList(TF_Operation* oper, + const char* attr_name, + int64_t* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloat(TF_Operation* oper, + const char* attr_name, + float* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. 
+// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloatList(TF_Operation* oper, + const char* attr_name, + float* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrBool(TF_Operation* oper, + const char* attr_name, + unsigned char* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrBoolList(TF_Operation* oper, + const char* attr_name, + unsigned char* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrType(TF_Operation* oper, + const char* attr_name, + TF_DataType* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTypeList(TF_Operation* oper, + const char* attr_name, + TF_DataType* values, + int max_values, + TF_Status* status); + +// Fills in `value` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `num_dims` (ideally set to +// TF_Attr_Meta.size from TF_OperationGetAttrMetadata(oper, attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrShape(TF_Operation* oper, + const char* attr_name, + int64_t* value, + int num_dims, + TF_Status* status); + +// Fills in `dims` with the list of shapes in the attribute `attr_name` of +// `oper` and `num_dims` with the corresponding number of dimensions. On return, +// for every i where `num_dims[i]` > 0, `dims[i]` will be an array of +// `num_dims[i]` elements. A value of -1 for `num_dims[i]` indicates that the +// i-th shape in the list is unknown. +// +// The elements of `dims` will point to addresses in `storage` which must be +// large enough to hold at least `storage_size` int64_ts. Ideally, `num_shapes` +// would be set to TF_AttrMetadata.list_size and `storage_size` would be set to +// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, +// attr_name). +// +// Fails if storage_size is insufficient to hold the requested shapes. +TF_CAPI_EXPORT extern void TF_OperationGetAttrShapeList( + TF_Operation* oper, const char* attr_name, int64_t** dims, int* num_dims, + int num_shapes, int64_t* storage, int storage_size, TF_Status* status); + +// Sets `value` to the binary-serialized TensorShapeProto of the value of +// `attr_name` attribute of `oper`'. +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* value, + TF_Status* status); + +// Fills in `values` with binary-serialized TensorShapeProto values of the +// attribute `attr_name` of `oper`. `values` must point to an array of length at +// least `num_values` (ideally set to TF_AttrMetadata.list_size from +// TF_OperationGetAttrMetadata(oper, attr_name)). 
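To show how TF_OperationGetAttrMetadata pairs with the typed getters above (the remaining declarations continue below), here is a sketch that sets a shape attribute and reads it back, sizing the receiving buffer from the metadata. The names and the fixed-size buffer are illustrative only; the usual c_api.h / libtensorflow assumptions apply and status checks are trimmed.

#include <stdio.h>
#include "tensorflow/c/c_api.h"  // assumed public header declaring these symbols

int main(void) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* g = TF_NewGraph();

  // Placeholder with an explicit shape attribute we can read back: [-1, 128].
  int64_t dims[2] = {-1, 128};
  TF_OperationDescription* d = TF_NewOperation(g, "Placeholder", "x");
  TF_SetAttrType(d, "dtype", TF_FLOAT);
  TF_SetAttrShape(d, "shape", dims, 2);
  TF_Operation* x = TF_FinishOperation(d, s);

  // The metadata says how many dimensions to expect before fetching them.
  TF_AttrMetadata meta = TF_OperationGetAttrMetadata(x, "shape", s);
  if (TF_GetCode(s) == TF_OK && meta.type == TF_ATTR_SHAPE && meta.total_size > 0) {
    int64_t got[8];  // large enough for this example's 2 dims
    TF_OperationGetAttrShape(x, "shape", got, (int)meta.total_size, s);
    printf("shape attr has %lld dims, first dim = %lld\n",
           (long long)meta.total_size, (long long)got[0]);
  }

  TF_DeleteGraph(g);
  TF_DeleteStatus(s);
  return 0;
}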
+TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProtoList( + TF_Operation* oper, const char* attr_name, TF_Buffer** values, + int max_values, TF_Status* status); + +// Gets the TF_Tensor valued attribute of `attr_name` of `oper`. +// +// Allocates a new TF_Tensor which the caller is expected to take +// ownership of (and can deallocate using TF_DeleteTensor). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensor(TF_Operation* oper, + const char* attr_name, + TF_Tensor** value, + TF_Status* status); + +// Fills in `values` with the TF_Tensor values of the attribute `attr_name` of +// `oper`. `values` must point to an array of TF_Tensor* of length at least +// `max_values` (ideally set to TF_AttrMetadata.list_size from +// TF_OperationGetAttrMetadata(oper, attr_name)). +// +// The caller takes ownership of all the non-null TF_Tensor* entries in `values` +// (which can be deleted using TF_DeleteTensor(values[i])). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorList(TF_Operation* oper, + const char* attr_name, + TF_Tensor** values, + int max_values, + TF_Status* status); + +// Sets `output_attr_value` to the binary-serialized AttrValue proto +// representation of the value of the `attr_name` attr of `oper`. +TF_CAPI_EXPORT extern void TF_OperationGetAttrValueProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); + +// Returns the operation in the graph with `oper_name`. Returns nullptr if +// no operation found. +TF_CAPI_EXPORT extern TF_Operation* TF_GraphOperationByName( + TF_Graph* graph, const char* oper_name); + +// Iterate through the operations of a graph. To use: +// size_t pos = 0; +// TF_Operation* oper; +// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { +// DoSomethingWithOperation(oper); +// } +TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, + size_t* pos); + +// Write out a serialized representation of `graph` (as a GraphDef protocol +// message) to `output_graph_def` (allocated by TF_NewBuffer()). +// `output_graph_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. +// +// May fail on very large graphs in the future. +TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, + TF_Buffer* output_graph_def, + TF_Status* status); + +// Returns the serialized OpDef proto with name `op_name`, or a bad status if no +// such op exists. This can return OpDefs of functions copied into the graph. +TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph, + const char* op_name, + TF_Buffer* output_op_def, + TF_Status* status); + +// Returns the serialized VersionDef proto for this graph. +TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, + TF_Buffer* output_version_def, + TF_Status* status); + +// TF_ImportGraphDefOptions holds options that can be passed to +// TF_GraphImportGraphDef. +typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; + +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( + TF_ImportGraphDefOptions* opts); + +// Set the prefix to be prepended to the names of nodes in `graph_def` that will +// be imported into `graph`. `prefix` is copied and has no lifetime +// requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( + TF_ImportGraphDefOptions* opts, const char* prefix); + +// Set the execution device for nodes in `graph_def`. 
+// Only applies to nodes where a device was not already explicitly specified. +// `device` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetDefaultDevice( + TF_ImportGraphDefOptions* opts, const char* device); + +// Set whether to uniquify imported operation names. If true, imported operation +// names will be modified if their name already exists in the graph. If false, +// conflicting names will be treated as an error. Note that this option has no +// effect if a prefix is set, since the prefix will guarantee all names are +// unique. Defaults to false. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); + +// If true, the specified prefix will be modified if it already exists as an +// operation name or prefix in the graph. If false, a conflicting prefix will be +// treated as an error. This option has no effect if no prefix is specified. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); + +// Set any imported nodes with input `src_name:src_index` to have that input +// replaced with `dst`. `src_name` refers to a node in the graph to be imported, +// `dst` references a node already existing in the graph being imported into. +// `src_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( + TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, + TF_Output dst); + +// Set any imported nodes with control input `src_name` to have that input +// replaced with `dst`. `src_name` refers to a node in the graph to be imported, +// `dst` references an operation already existing in the graph being imported +// into. `src_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( + TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); + +// Cause the imported graph to have a control dependency on `oper`. `oper` +// should exist in the graph being imported into. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( + TF_ImportGraphDefOptions* opts, TF_Operation* oper); + +// Add an output in `graph_def` to be returned via the `return_outputs` output +// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input +// mapping, the corresponding existing tensor in `graph` will be returned. +// `oper_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( + TF_ImportGraphDefOptions* opts, const char* oper_name, int index); + +// Returns the number of return outputs added via +// TF_ImportGraphDefOptionsAddReturnOutput(). +TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( + const TF_ImportGraphDefOptions* opts); + +// Add an operation in `graph_def` to be returned via the `return_opers` output +// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no +// lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( + TF_ImportGraphDefOptions* opts, const char* oper_name); + +// Returns the number of return operations added via +// TF_ImportGraphDefOptionsAddReturnOperation(). 
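As a sketch of the round trip these options support (the remaining option declarations continue below), the program here serializes one graph with TF_GraphToGraphDef and imports it into a second graph under a prefix. The prefix "copy" and node name "x" are arbitrary; the usual c_api.h / libtensorflow assumptions apply.

#include <stdio.h>
#include "tensorflow/c/c_api.h"  // assumed public header declaring these symbols

int main(void) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* src = TF_NewGraph();

  // One placeholder so the exported GraphDef is non-empty.
  TF_OperationDescription* d = TF_NewOperation(src, "Placeholder", "x");
  TF_SetAttrType(d, "dtype", TF_FLOAT);
  TF_FinishOperation(d, s);

  // Export src as a serialized GraphDef ...
  TF_Buffer* graph_def = TF_NewBuffer();
  TF_GraphToGraphDef(src, graph_def, s);

  // ... and import it into a second graph under the prefix "copy".
  TF_Graph* dst = TF_NewGraph();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "copy");
  TF_GraphImportGraphDef(dst, graph_def, opts, s);
  if (TF_GetCode(s) == TF_OK) {
    TF_Operation* copied = TF_GraphOperationByName(dst, "copy/x");
    printf("imported node: %s\n", copied ? TF_OperationName(copied) : "(missing)");
  } else {
    fprintf(stderr, "import failed: %s\n", TF_Message(s));
  }

  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);
  TF_DeleteGraph(dst);
  TF_DeleteGraph(src);
  TF_DeleteStatus(s);
  return 0;
}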
+TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts); + +// TF_ImportGraphDefResults holds results that are generated by +// TF_GraphImportGraphDefWithResults(). +typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults; + +// Fetches the return outputs requested via +// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is +// returned in `num_outputs`. The array of return outputs is returned in +// `outputs`. `*outputs` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs( + TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs); + +// Fetches the return operations requested via +// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched +// operations is returned in `num_opers`. The array of return operations is +// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( + TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); + +// Fetches any input mappings requested via +// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef +// and weren't used as input to any node in the imported graph def. The number +// of fetched mappings is returned in `num_missing_unused_input_mappings`. The +// array of each mapping's source node name is returned in `src_names`, and the +// array of each mapping's source index is returned in `src_indexes`. +// +// `*src_names`, `*src_indexes`, and the memory backing each string in +// `src_names` are owned by and have the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, + const char*** src_names, int** src_indexes); + +// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults( + TF_ImportGraphDefResults* results); + +// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and +// a bad status on error. Otherwise, returns a populated +// TF_ImportGraphDefResults instance. The returned instance must be deleted via +// TF_DeleteImportGraphDefResults(). +TF_CAPI_EXPORT extern TF_ImportGraphDefResults* +TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, + TF_Status* status); + +// Import the graph serialized in `graph_def` into `graph`. +// Convenience function for when only return outputs are needed. +// +// `num_return_outputs` must be the number of return outputs added (i.e. the +// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If +// `num_return_outputs` is non-zero, `return_outputs` must be of length +// `num_return_outputs`. Otherwise it can be null. +TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, + int num_return_outputs, TF_Status* status); + +// Import the graph serialized in `graph_def` into `graph`. +// Convenience function for when no results are needed. +TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status); + +// Adds a copy of function `func` and optionally its gradient function `grad` +// to `g`. 
Once `func`/`grad` is added to `g`, it can be called by creating +// an operation using the function's name. +// Any changes to `func`/`grad` (including deleting it) done after this method +// returns, won't affect the copy of `func`/`grad` in `g`. +// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no +// effect on them, but can establish the function->gradient relationship +// between them if `func` does not already have a gradient. If `func` already +// has a gradient different from `grad`, an error is returned. +// +// `func` must not be null. +// If `grad` is null and `func` is not in `g`, `func` is added without a +// gradient. +// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop. +// `grad` must have appropriate signature as described in the doc of +// GradientDef in tensorflow/core/framework/function.proto. +// +// If successful, status is set to OK and `func` and `grad` are added to `g`. +// Otherwise, status is set to the encountered error and `g` is unmodified. +TF_CAPI_EXPORT extern void TF_GraphCopyFunction(TF_Graph* g, + const TF_Function* func, + const TF_Function* grad, + TF_Status* status); + +// Returns the number of TF_Functions registered in `g`. +TF_CAPI_EXPORT extern int TF_GraphNumFunctions(TF_Graph* g); + +// Fills in `funcs` with the TF_Function* registered in `g`. +// `funcs` must point to an array of TF_Function* of length at least +// `max_func`. In usual usage, max_func should be set to the result of +// TF_GraphNumFunctions(g). In this case, all the functions registered in +// `g` will be returned. Else, an unspecified subset. +// +// If successful, returns the number of TF_Function* successfully set in +// `funcs` and sets status to OK. The caller takes ownership of +// all the returned TF_Functions. They must be deleted with TF_DeleteFunction. +// On error, returns 0, sets status to the encountered error, and the contents +// of funcs will be undefined. +TF_CAPI_EXPORT extern int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, + int max_func, TF_Status* status); + +// Note: The following function may fail on very large protos in the future. + +TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, + TF_Buffer* output_node_def, + TF_Status* status); + +// Create a TF_Function from a TF_Graph +// +// Params: +// fn_body - the graph whose operations (or subset of whose operations) will be +// converted to TF_Function. +// fn_name - the name of the new TF_Function. Should match the operation +// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. +// If `append_hash_to_fn_name` is false, `fn_name` must be distinct +// from other function and operation names (at least those +// registered in graphs where this function will be used). +// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name +// of the function will be `fn_name` appended with +// '_'. +// If set to 0, the function's name will be `fn_name`. +// num_opers - `num_opers` contains the number of elements in the `opers` array +// or a special value of -1 meaning that no array is given. +// The distinction between an empty array of operations and no +// array of operations is necessary to distinguish the case of +// creating a function with no body (e.g. identity or permutation) +// and the case of creating a function whose body contains all +// the nodes in the graph (except for the automatic skipping, see +// below). +// opers - Array of operations to become the body of the function or null. 
+// - If no array is given (`num_opers` = -1), all the +// operations in `fn_body` will become part of the function +// except operations referenced in `inputs`. These operations +// must have a single output (these operations are typically +// placeholders created for the sole purpose of representing +// an input. We can relax this constraint if there are +// compelling use cases). +// - If an array is given (`num_opers` >= 0), all operations +// in it will become part of the function. In particular, no +// automatic skipping of dummy input operations is performed. +// ninputs - number of elements in `inputs` array +// inputs - array of TF_Outputs that specify the inputs to the function. +// If `ninputs` is zero (the function takes no inputs), `inputs` +// can be null. The names used for function inputs are normalized +// names of the operations (usually placeholders) pointed to by +// `inputs`. These operation names should start with a letter. +// Normalization will convert all letters to lowercase and +// non-alphanumeric characters to '_' to make resulting names match +// the "[a-z][a-z0-9_]*" pattern for operation argument names. +// `inputs` cannot contain the same tensor twice. +// noutputs - number of elements in `outputs` array +// outputs - array of TF_Outputs that specify the outputs of the function. +// If `noutputs` is zero (the function returns no outputs), `outputs` +// can be null. `outputs` can contain the same tensor more than once. +// output_names - The names of the function's outputs. `output_names` array +// must either have the same length as `outputs` +// (i.e. `noutputs`) or be null. In the former case, +// the names should match the regular expression for ArgDef +// names - "[a-z][a-z0-9_]*". In the latter case, +// names for outputs will be generated automatically. +// opts - various options for the function, e.g. XLA's inlining control. +// description - optional human-readable description of this function. +// status - Set to OK on success and an appropriate error on failure. +// +// Note that when the same TF_Output is listed as both an input and an output, +// the corresponding function's output will equal to this input, +// instead of the original node's output. +// +// Callers must also satisfy the following constraints: +// - `inputs` cannot refer to TF_Outputs within a control flow context. For +// example, one cannot use the output of "switch" node as input. +// - `inputs` and `outputs` cannot have reference types. Reference types are +// not exposed through C API and are being replaced with Resources. We support +// reference types inside function's body to support legacy code. Do not +// use them in new code. +// - Every node in the function's body must have all of its inputs (including +// control inputs). In other words, for every node in the body, each input +// must be either listed in `inputs` or must come from another node in +// the body. In particular, it is an error to have a control edge going from +// a node outside of the body into a node in the body. This applies to control +// edges going from nodes referenced in `inputs` to nodes in the body when +// the former nodes are not in the body (automatically skipped or not +// included in explicitly specified body). +// +// Returns: +// On success, a newly created TF_Function instance. It must be deleted by +// calling TF_DeleteFunction. +// +// On failure, null. 
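The declaration itself follows; to make the parameter conventions above concrete, here is a sketch that turns a one-op body graph into a function and copies it into another graph. Passing -1 for num_opers and letting the input placeholder be skipped matches the behaviour described above. The function name "neg_fn" and node names are arbitrary, opts is left null (TF_FunctionOptions is still a placeholder type), and the usual c_api.h / libtensorflow assumptions apply.

#include <stdio.h>
#include "tensorflow/c/c_api.h"  // assumed public header declaring these symbols

int main(void) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* body = TF_NewGraph();

  // Function body: x (Placeholder) -> y = Neg(x).
  TF_OperationDescription* d = TF_NewOperation(body, "Placeholder", "x");
  TF_SetAttrType(d, "dtype", TF_FLOAT);
  TF_Operation* x = TF_FinishOperation(d, s);
  d = TF_NewOperation(body, "Neg", "y");
  TF_Output x0 = {x, 0};
  TF_AddInput(d, x0);
  TF_Operation* y = TF_FinishOperation(d, s);

  // Whole body (num_opers = -1), input x:0 (auto-skipped), output y:0.
  TF_Output in = {x, 0}, out = {y, 0};
  TF_Function* fn = TF_GraphToFunction(
      body, "neg_fn", /*append_hash_to_fn_name=*/0, /*num_opers=*/-1,
      /*opers=*/NULL, /*ninputs=*/1, &in, /*noutputs=*/1, &out,
      /*output_names=*/NULL, /*opts=*/NULL, "negates its input", s);
  if (TF_GetCode(s) == TF_OK) {
    TF_Graph* g = TF_NewGraph();
    TF_GraphCopyFunction(g, fn, /*grad=*/NULL, s);  // register it elsewhere
    printf("created function %s\n", TF_FunctionName(fn));
    TF_DeleteGraph(g);
    TF_DeleteFunction(fn);
  } else {
    fprintf(stderr, "TF_GraphToFunction: %s\n", TF_Message(s));
  }

  TF_DeleteGraph(body);
  TF_DeleteStatus(s);
  return 0;
}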
+TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + const TF_FunctionOptions* opts, const char* description, TF_Status* status); + +// Similar to TF_GraphToFunction but allows specifying control outputs of the +// function. +// +// The arguments of TF_GraphToFunction have the same meaning, but the new +// arguments are as follows: +// +// ncontrol_outputs: Number of control outputs of the function. +// control_outputs: vector of TF_Operation objects to be marked as control +// outputs of the function. Operations marked as control outputs are +// guaranteed to execute. +// control_output_names: Optional. If not nullptr, vector of strings, one +// per control output, with their names to be added to the function's +// OpDef. +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status); + +// Returns the name of the graph function. +// The return value points to memory that is only usable until the next +// mutation to *func. +TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); + +// Write out a serialized representation of `func` (as a FunctionDef protocol +// message) to `output_func_def` (allocated by TF_NewBuffer()). +// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. +// +// May fail on very large graphs in the future. +TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, + TF_Buffer* output_func_def, + TF_Status* status); + +// Construct and return the function whose FunctionDef representation is +// serialized in `proto`. `proto_len` must equal the number of bytes +// pointed to by `proto`. +// Returns: +// On success, a newly created TF_Function instance. It must be deleted by +// calling TF_DeleteFunction. +// +// On failure, null. +TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( + const void* proto, size_t proto_len, TF_Status* status); + +// Sets function attribute named `attr_name` to value stored in `proto`. +// If this attribute is already set to another value, it is overridden. +// `proto` should point to a sequence of bytes of length `proto_len` +// representing a binary serialization of an AttrValue protocol +// buffer. +TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Sets `output_attr_value` to the binary-serialized AttrValue proto +// representation of the value of the `attr_name` attr of `func`. +// If `attr_name` attribute is not present, status is set to an error. +TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( + TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); + +// Frees the memory used by the `func` struct. +// TF_DeleteFunction is a noop if `func` is null. 
+// Deleting a function does not remove it from any graphs it was copied to. +TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); + +// Attempts to evaluate `output`. This will only be possible if `output` doesn't +// depend on any graph inputs (this function is safe to call if this isn't the +// case though). +// +// If the evaluation is successful, this function returns true and `output`s +// value is returned in `result`. Otherwise returns false. An error status is +// returned if something is wrong with the graph or input. Note that this may +// return false even if no error status is set. +TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph, + TF_Output output, + TF_Tensor** result, + TF_Status* status); + +// TODO(josh11b): Register OpDef, available to all operations added +// to this graph. + +// -------------------------------------------------------------------------- +// API for driving Graph execution. + +typedef struct TF_Session TF_Session; + +// Return a new execution session with the associated graph, or NULL on +// error. Does not take ownership of any input parameters. +// +// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be +// kept alive for the lifetime of the returned TF_Session. New nodes can still +// be added to `graph` after this call. +TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, + const TF_SessionOptions* opts, + TF_Status* status); + +// This function creates a new TF_Session (which is created on success) using +// `session_options`, and then initializes state (restoring tensors and other +// assets) using `run_options`. +// +// Any NULL and non-NULL value combinations for (`run_options, `meta_graph_def`) +// are valid. +// +// - `export_dir` must be set to the path of the exported SavedModel. +// - `tags` must include the set of tags used to identify one MetaGraphDef in +// the SavedModel. +// - `graph` must be a graph newly allocated with TF_NewGraph(). +// +// If successful, populates `graph` with the contents of the Graph and +// `meta_graph_def` with the MetaGraphDef of the loaded model. +TF_CAPI_EXPORT extern TF_Session* TF_LoadSessionFromSavedModel( + const TF_SessionOptions* session_options, const TF_Buffer* run_options, + const char* export_dir, const char* const* tags, int tags_len, + TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status); + +// Close a session. +// +// Contacts any other processes associated with the session, if applicable. +// May not be called after TF_DeleteSession(). +TF_CAPI_EXPORT extern void TF_CloseSession(TF_Session*, TF_Status* status); + +// Destroy a session object. +// +// Even if error information is recorded in *status, this call discards all +// local resources associated with the session. The session may not be used +// during or after this call (and the session drops its reference to the +// corresponding graph). +TF_CAPI_EXPORT extern void TF_DeleteSession(TF_Session*, TF_Status* status); + +// Run the graph associated with the session starting with the supplied inputs +// (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). +// +// Any NULL and non-NULL value combinations for (`run_options`, +// `run_metadata`) are valid. +// +// - `run_options` may be NULL, in which case it will be ignored; or +// non-NULL, in which case it must point to a `TF_Buffer` containing the +// serialized representation of a `RunOptions` protocol buffer. 
+// - `run_metadata` may be NULL, in which case it will be ignored; or +// non-NULL, in which case it must point to an empty, freshly allocated +// `TF_Buffer` that may be updated to contain the serialized representation +// of a `RunMetadata` protocol buffer. +// +// The caller retains ownership of `input_values` (which can be deleted using +// TF_DeleteTensor). The caller also retains ownership of `run_options` and/or +// `run_metadata` (when not NULL) and should manually call TF_DeleteBuffer on +// them. +// +// On success, the tensors corresponding to outputs[0,noutputs-1] are placed in +// output_values[]. Ownership of the elements of output_values[] is transferred +// to the caller, which must eventually call TF_DeleteTensor on them. +// +// On failure, output_values[] contains NULLs. +TF_CAPI_EXPORT extern void TF_SessionRun( + TF_Session* session, + // RunOptions + const TF_Buffer* run_options, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // RunMetadata + TF_Buffer* run_metadata, + // Output status + TF_Status*); + +// Set up the graph with the intended feeds (inputs) and fetches (outputs) for a +// sequence of partial run calls. +// +// On success, returns a handle that is used for subsequent PRun calls. The +// handle should be deleted with TF_DeletePRunHandle when it is no longer +// needed. +// +// On failure, out_status contains a tensorflow::Status with an error +// message. *handle is set to nullptr. +TF_CAPI_EXPORT extern void TF_SessionPRunSetup( + TF_Session*, + // Input names + const TF_Output* inputs, int ninputs, + // Output names + const TF_Output* outputs, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output handle + const char** handle, + // Output status + TF_Status*); + +// Continue to run the graph with additional feeds and fetches. The +// execution state is uniquely identified by the handle. +TF_CAPI_EXPORT extern void TF_SessionPRun( + TF_Session*, const char* handle, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output status + TF_Status*); + +// Deletes a handle allocated by TF_SessionPRunSetup. +// Once called, no more calls to TF_SessionPRun should be made. +TF_CAPI_EXPORT extern void TF_DeletePRunHandle(const char* handle); + +// -------------------------------------------------------------------------- +// The deprecated session API. Please switch to the above instead of +// TF_ExtendGraph(). This deprecated API can be removed at any time without +// notice. 
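Before the deprecated API below, a minimal end-to-end sketch of TF_NewSession and TF_SessionRun as documented above: build a single Const node, fetch its value, and clean up. The constant 3.5f and node name "c" are arbitrary; no feeds, targets, RunOptions or RunMetadata are used; the usual c_api.h / libtensorflow assumptions apply.

#include <stdio.h>
#include <string.h>
#include "tensorflow/c/c_api.h"  // assumed public header declaring these symbols

int main(void) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* g = TF_NewGraph();

  // c = Const(3.5f)
  float v = 3.5f;
  TF_Tensor* t =
      TF_AllocateTensor(TF_FLOAT, /*dims=*/NULL, /*num_dims=*/0, sizeof(float));
  memcpy(TF_TensorData(t), &v, sizeof(float));
  TF_OperationDescription* d = TF_NewOperation(g, "Const", "c");
  TF_SetAttrType(d, "dtype", TF_FLOAT);
  TF_SetAttrTensor(d, "value", t, s);
  TF_Operation* c = TF_FinishOperation(d, s);

  // Create a session and fetch c:0.
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* sess = TF_NewSession(g, opts, s);
  TF_Output fetch = {c, 0};
  TF_Tensor* out = NULL;
  TF_SessionRun(sess, /*run_options=*/NULL,
                /*inputs=*/NULL, /*input_values=*/NULL, 0,
                &fetch, &out, 1,
                /*target_opers=*/NULL, 0,
                /*run_metadata=*/NULL, s);
  if (TF_GetCode(s) == TF_OK) {
    printf("c = %f\n", *(float*)TF_TensorData(out));
    TF_DeleteTensor(out);  // ownership of fetched tensors moves to the caller
  } else {
    fprintf(stderr, "TF_SessionRun: %s\n", TF_Message(s));
  }

  TF_CloseSession(sess, s);
  TF_DeleteSession(sess, s);
  TF_DeleteSessionOptions(opts);
  TF_DeleteTensor(t);
  TF_DeleteGraph(g);
  TF_DeleteStatus(s);
  return 0;
}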
+ +typedef struct TF_DeprecatedSession TF_DeprecatedSession; + +TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession( + const TF_SessionOptions*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_Reset(const TF_SessionOptions* opt, + const char** containers, int ncontainers, + TF_Status* status); +// Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and +// add the nodes in that GraphDef to the graph for the session. +// +// Prefer use of TF_Session and TF_GraphImportGraphDef over this. +TF_CAPI_EXPORT extern void TF_ExtendGraph(TF_DeprecatedSession*, + const void* proto, size_t proto_len, + TF_Status*); + +// See TF_SessionRun() above. +TF_CAPI_EXPORT extern void TF_Run(TF_DeprecatedSession*, + const TF_Buffer* run_options, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Buffer* run_metadata, TF_Status*); + +// See TF_SessionPRunSetup() above. +TF_CAPI_EXPORT extern void TF_PRunSetup(TF_DeprecatedSession*, + const char** input_names, int ninputs, + const char** output_names, int noutputs, + const char** target_oper_names, + int ntargets, const char** handle, + TF_Status*); + +// See TF_SessionPRun above. +TF_CAPI_EXPORT extern void TF_PRun(TF_DeprecatedSession*, const char* handle, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Status*); + +typedef struct TF_DeviceList TF_DeviceList; + +// Lists all devices in a TF_Session. +// +// Caller takes ownership of the returned TF_DeviceList* which must eventually +// be freed with a call to TF_DeleteDeviceList. +TF_CAPI_EXPORT extern TF_DeviceList* TF_SessionListDevices(TF_Session* session, + TF_Status* status); + +// Lists all devices in a TF_Session. +// +// Caller takes ownership of the returned TF_DeviceList* which must eventually +// be freed with a call to TF_DeleteDeviceList. +TF_CAPI_EXPORT extern TF_DeviceList* TF_DeprecatedSessionListDevices( + TF_DeprecatedSession* session, TF_Status* status); + +// Deallocates the device list. +TF_CAPI_EXPORT extern void TF_DeleteDeviceList(TF_DeviceList* list); + +// Counts the number of elements in the device list. +TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); + +// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) +// The return value will be a pointer to a null terminated string. The caller +// must not modify or delete the string. It will be deallocated upon a call to +// TF_DeleteDeviceList. +// +// If index is out of bounds, an error code will be set in the status object, +// and a null pointer will be returned. +TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, + int index, + TF_Status* status); + +// Retrieves the type of the device at the given index. +// +// The caller must not modify or delete the string. It will be deallocated upon +// a call to TF_DeleteDeviceList. +// +// If index is out of bounds, an error code will be set in the status object, +// and a null pointer will be returned. 
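To round out the device-list calls in this section (TF_DeviceListType is declared just below), a short sketch that creates a session on an empty graph and prints every device visible to it. The usual c_api.h / libtensorflow assumptions apply.

#include <stdio.h>
#include "tensorflow/c/c_api.h"  // assumed public header declaring these symbols

int main(void) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* g = TF_NewGraph();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* sess = TF_NewSession(g, opts, s);

  // Enumerate the devices visible to this session.
  TF_DeviceList* devices = TF_SessionListDevices(sess, s);
  if (TF_GetCode(s) == TF_OK) {
    int n = TF_DeviceListCount(devices);
    for (int i = 0; i < n; ++i) {
      printf("%s (%s), %lld bytes\n",
             TF_DeviceListName(devices, i, s),
             TF_DeviceListType(devices, i, s),
             (long long)TF_DeviceListMemoryBytes(devices, i, s));
    }
    TF_DeleteDeviceList(devices);
  }

  TF_CloseSession(sess, s);
  TF_DeleteSession(sess, s);
  TF_DeleteSessionOptions(opts);
  TF_DeleteGraph(g);
  TF_DeleteStatus(s);
  return 0;
}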
+TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, + int index, + TF_Status* status); + +// Retrieve the amount of memory associated with a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and -1 will be returned. +TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( + const TF_DeviceList* list, int index, TF_Status* status); + +// Retrieve the incarnation number of a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and 0 will be returned. +TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation( + const TF_DeviceList* list, int index, TF_Status* status); + +// -------------------------------------------------------------------------- +// Load plugins containing custom ops and kernels + +// TF_Library holds information about dynamically loaded TensorFlow plugins. +typedef struct TF_Library TF_Library; + +// Load the library specified by library_filename and register the ops and +// kernels present in that library. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, place OK in status and return the newly created library handle. +// The caller owns the library handle. +// +// On failure, place an error status in status and return NULL. +TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename, + TF_Status* status); + +// Get the OpList of OpDefs defined in the library pointed by lib_handle. +// +// Returns a TF_Buffer. The memory pointed to by the result is owned by +// lib_handle. The data in the buffer will be the serialized OpList proto for +// ops defined in the library. +TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); + +// Frees the memory associated with the library handle. +// Does NOT unload the library. +TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); + +// Get the OpList of all OpDefs defined in this address space. +// Returns a TF_Buffer, ownership of which is transferred to the caller +// (and can be freed using TF_DeleteBuffer). +// +// The data in the buffer will be the serialized OpList proto for ops registered +// in this address space. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); + +// TF_ApiDefMap encapsulates a collection of API definitions for an operation. +// +// This object maps the name of a TensorFlow operation to a description of the +// API to generate for it, as defined by the ApiDef protocol buffer ( +// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) +// +// The ApiDef messages are typically used to generate convenience wrapper +// functions for TensorFlow operations in various language bindings. +typedef struct TF_ApiDefMap TF_ApiDefMap; + +// Creates a new TF_ApiDefMap instance. +// +// Params: +// op_list_buffer - TF_Buffer instance containing serialized OpList +// protocol buffer. (See +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto +// for the OpList proto definition). +// status - Set to OK on success and an appropriate error on failure. +TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, + TF_Status* status); + +// Deallocates a TF_ApiDefMap. +TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap); + +// Add ApiDefs to the map. 
+// +// `text` corresponds to a text representation of an ApiDefs protocol message. +// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). +// +// The provided ApiDefs will be merged with existing ones in the map, with +// precedence given to the newly added version in case of conflicts with +// previous calls to TF_ApiDefMapPut. +TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, + const char* text, size_t text_len, + TF_Status* status); + +// Returns a serialized ApiDef protocol buffer for the TensorFlow operation +// named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, + const char* name, + size_t name_len, + TF_Status* status); + +// -------------------------------------------------------------------------- +// Kernel definition information. + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// registered kernels. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// kernels registered for the operation named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( + const char* name, TF_Status* status); + +// -------------------------------------------------------------------------- +// In-process TensorFlow server functionality, for use in distributed training. +// A Server instance encapsulates a set of devices and a Session target that +// can participate in distributed training. A server belongs to a cluster +// (specified by a ClusterSpec), and corresponds to a particular task in a +// named job. The server can communicate with any other server in the same +// cluster. + +// In-process TensorFlow server. +typedef struct TF_Server TF_Server; + +// Creates a new in-process TensorFlow server configured using a serialized +// ServerDef protocol buffer provided via `proto` and `proto_len`. +// +// The server will not serve any requests until TF_ServerStart is invoked. +// The server will stop serving requests once TF_ServerStop or +// TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, + size_t proto_len, + TF_Status* status); + +// Starts an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); + +// Stops an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); + +// Blocks until the server has been successfully stopped (via TF_ServerStop or +// TF_ServerClose). +TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); + +// Returns the target string that can be provided to TF_SetTarget() to connect +// a TF_Session to `server`. +// +// The returned string is valid only until TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); + +// Destroy an in-process TensorFlow server, frees memory. If server is running +// it will be stopped and joined. +TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); + +// Register a listener method that processes printed messages. +// +// If any listeners are registered, the print operator will call all listeners +// with the printed messages and immediately return without writing to the +// logs. 
+TF_CAPI_EXPORT extern void TF_RegisterLogListener( + void (*listener)(const char*)); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_C_CORE_API_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c25cb264ce7..2ec1f442780 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -42,7 +42,7 @@ tf_cuda_library( "//conditions:default": [ "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:fixed_array", - "//tensorflow/c:c_api", + "//tensorflow/c:c_core_api", "//tensorflow/c:c_api_internal", "//tensorflow/c:tf_tensor_internal", "//tensorflow/core:core_cpu", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 96dc288f213..67324a441f9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -28,7 +28,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" -#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_core_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/tf_tensor_internal.h" diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 070b3a9bb60..b951f45d0e1 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -20,7 +20,7 @@ limitations under the License. // WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be // stable and can change without notice. -#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_core_api.h" // Macro to control visibility of exported symbols in the shared library (.so, // .dylib, .dll). diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c38d7b84a74..f69f79eed7a 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -36,6 +36,7 @@ transitive_hdrs( "//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:reader", "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/c:c_core_api_no_xla", # WARNING: None of the C/C++ code under python/ has any API guarantees, and TF team # reserves the right to change APIs and other header-level interfaces. If your custom # op uses these headers, it may break when users upgrade their version of tensorflow. diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 4dfe616263b..64a4469e0da 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -246,6 +246,7 @@ headers = ( list(find_files('*.proto', 'tensorflow/compiler')) + list(find_files('*.proto', 'tensorflow/core')) + list(find_files('*.proto', 'tensorflow/python')) + + list(find_files('*.h', 'tensorflow/c')) + list(find_files('*.h', 'tensorflow/cc')) + list(find_files('*.h', 'tensorflow/compiler')) + list(find_files('*.h', 'tensorflow/core')) + From 1a101d91633e43940afa0bfd53d9a35d68aa85bc Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 17 Mar 2020 12:48:15 -0700 Subject: [PATCH 092/492] Update `Tensor.shape`, `Tensor.get_shape` and `tf.ensure_shape` docstrings. Tensor.shape is broken in nightly. 
PiperOrigin-RevId: 301432954 Change-Id: I00024bca82e12985273c8208bc1bfce4d7ba98d2 --- tensorflow/python/framework/ops.py | 249 ++++++++++++++++++++++------- tensorflow/python/ops/check_ops.py | 90 ++++++++--- 2 files changed, 259 insertions(+), 80 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 266413cc96e..7796322d6fc 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -450,50 +450,22 @@ class Tensor(_TensorLike): @property def shape(self): - """Returns the `TensorShape` that represents the shape of this tensor. + """Returns a `tf.TensorShape` that represents the shape of this tensor. - The shape is computed using shape inference functions that are - registered in the Op for each `Operation`. See - `tf.TensorShape` - for more details of what a shape represents. + >>> t = tf.constant([1,2,3,4,5]) + >>> t.shape + TensorShape([5]) - The inferred shape of a tensor is used to provide shape - information without having to execute the underlying kernel. This - can be used for debugging and providing early error messages. For - example: + `tf.Tensor.shape` is equivalent to `tf.Tensor.get_shape()`. - ```python - >>> c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - >>> print(c.shape) # will be TensorShape([2, 3]) - (2, 3) + In a `tf.function` or when building a model using + `tf.keras.Input`, they return the build-time shape of the + tensor, which may be partially unknown. - >>> d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]) - >>> print(d.shape) - (4, 2) - - # Raises a ValueError, because `c` and `d` do not have compatible - # inner dimensions. - >>> e = tf.matmul(c, d) - Traceback (most recent call last): - ... - tensorflow.python.framework.errors_impl.InvalidArgumentError: Matrix - size-incompatible: In[0]: [2,3], In[1]: [4,2] [Op:MatMul] name: MatMul/ - - # This works because we have compatible shapes. - >>> f = tf.matmul(c, d, transpose_a=True, transpose_b=True) - >>> print(f.shape) - (3, 4) - - ``` - - In some cases, the inferred shape may have unknown dimensions. If - the caller has additional information about the values of these - dimensions, `Tensor.set_shape()` can be used to augment the - inferred shape. - - Returns: - A `tf.TensorShape` representing the shape of this tensor. + A `tf.TensorShape` is not a tensor. Use `tf.shape(t)` to get a tensor + containing the shape, calculated at runtime. + See `tf.Tensor.get_shape()`, and `tf.TensorShape` for details and examples. """ if self._shape_val is None: self._shape_val = self._c_api_shape() @@ -591,37 +563,192 @@ class Tensor(_TensorLike): return self.shape.ndims def get_shape(self): - """Alias of `tf.Tensor.shape`.""" + """Returns a `tf.TensorShape` that represents the shape of this tensor. + + In eager execution the shape is always fully-known. + + >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + >>> print(a.shape) + (2, 3) + + `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`. + + + When executing in a `tf.function` or building a model using + `tf.keras.Input`, `Tensor.shape` may return a partial shape (including + `None` for unknown dimensions). See `tf.TensorShape` for more details. + + >>> inputs = tf.keras.Input(shape = [10]) + >>> # Unknown batch size + >>> print(inputs.shape) + (None, 10) + + The shape is computed using shape inference functions that are + registered for each `tf.Operation`. 
+
+    The returned `tf.TensorShape` is determined at *build* time, without
+    executing the underlying kernel. It is not a `tf.Tensor`. If you need a
+    shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or
+    use the `tf.shape(tensor)` function, which returns the tensor's shape at
+    *execution* time.
+
+    This is useful for debugging and providing early errors. For example, when
+    tracing a `tf.function`, no ops are executed, so shapes may be unknown (see
+    the [Concrete Functions
+    Guide](https://www.tensorflow.org/guide/concrete_function) for details).
+
+    >>> @tf.function
+    ... def my_matmul(a, b):
+    ...   result = a@b
+    ...   # the `print` executes during tracing.
+    ...   print("Result shape: ", result.shape)
+    ...   return result
+
+    The shape inference functions propagate shapes to the extent possible:
+
+    >>> _ = my_matmul.get_concrete_function(
+    ...   tf.TensorSpec([None,3]),
+    ...   tf.TensorSpec([3,5]))
+    Result shape:  (None, 5)
+
+    Tracing may fail if a shape mismatch is detected:
+
+    >>> _ = my_matmul.get_concrete_function(
+    ...   tf.TensorSpec([None,3]),
+    ...   tf.TensorSpec([4,5]))
+    Traceback (most recent call last):
+    ...
+    ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op:
+    'MatMul') with input shapes: [?,3], [4,5].
+
+    In some cases, the inferred shape may have unknown dimensions. If
+    the caller has additional information about the values of these
+    dimensions, `Tensor.set_shape()` can be used to augment the
+    inferred shape.
+
+    >>> @tf.function
+    ... def my_fun(a):
+    ...   a.set_shape([5, 5])
+    ...   # the `print` executes during tracing.
+    ...   print("Result shape: ", a.shape)
+    ...   return a
+
+    >>> _ = my_fun.get_concrete_function(
+    ...   tf.TensorSpec([None, None]))
+    Result shape:  (5, 5)
+
+    Returns:
+      A `tf.TensorShape` representing the shape of this tensor.
+
+    """
     return self.shape
 
   def set_shape(self, shape):
     """Updates the shape of this tensor.
 
-    This method can be called multiple times, and will merge the given
-    `shape` with the current shape of this tensor. It can be used to
-    provide additional information about the shape of this tensor that
-    cannot be inferred from the graph alone. For example, this can be used
-    to provide additional information about the shapes of images:
+    With eager execution this operates as a shape assertion.
+    Here the shapes match:
 
-    ```python
-    _, image_data = tf.compat.v1.TFRecordReader(...).read(...)
-    image = tf.io.decode_png(image_data, channels=3)
+    >>> t = tf.constant([[1,2,3]])
+    >>> t.set_shape([1, 3])
 
-    # The height and width dimensions of `image` are data dependent, and
-    # cannot be computed without executing the op.
-    print(image.shape)
-    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
+    Passing a `None` in the new shape allows any value for that axis:
 
-    # We know that each image in this dataset is 28 x 28 pixels.
-    image.set_shape([28, 28, 3])
-    print(image.shape)
-    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
-    ```
+    >>> t.set_shape([1,None])
 
-    NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
-    result in inconsistencies between the statically-known graph and the runtime
-    value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
-    instead.
+    An error is raised if an incompatible shape is passed:
+
+    >>> t.set_shape([1,5])
+    Traceback (most recent call last):
+    ...
+    ValueError: Tensor's shape (1, 3) is not compatible with supplied
+    shape [1, 5]
+
+    When executing in a `tf.function`, or building a model using
+    `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with
+    the current shape of this tensor, and set the tensor's shape to the
+    merged value (see `tf.TensorShape.merge_with` for details):
+
+    >>> t = tf.keras.Input(shape=[None, None, 3])
+    >>> print(t.shape)
+    (None, None, None, 3)
+
+    Dimensions set to `None` are not updated:
+
+    >>> t.set_shape([None, 224, 224, None])
+    >>> print(t.shape)
+    (None, 224, 224, 3)
+
+    The main use case for this is to provide additional shape information
+    that cannot be inferred from the graph alone.
+
+    For example, if you know all the images in a dataset have shape [28,28,3]
+    you can set it with `Tensor.set_shape`:
+
+    >>> @tf.function
+    ... def load_image(filename):
+    ...   raw = tf.io.read_file(filename)
+    ...   image = tf.image.decode_png(raw, channels=3)
+    ...   # the `print` executes during tracing.
+    ...   print("Initial shape: ", image.shape)
+    ...   image.set_shape([28, 28, 3])
+    ...   print("Final shape: ", image.shape)
+    ...   return image
+
+    Trace the function; see the [Concrete Functions
+    Guide](https://www.tensorflow.org/guide/concrete_function) for details.
+
+    >>> _ = load_image.get_concrete_function(
+    ...     tf.TensorSpec([], dtype=tf.string))
+    Initial shape:  (None, None, 3)
+    Final shape:  (28, 28, 3)
+
+    Similarly, the `tf.io.parse_tensor` function could return a tensor of
+    any shape, even one whose rank (`tf.rank`) is unknown. If you know that
+    all your serialized tensors will be 2-D, set it with `set_shape`:
+
+    >>> @tf.function
+    ... def my_parse(string_tensor):
+    ...   result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
+    ...   # the `print` executes during tracing.
+    ...   print("Initial shape: ", result.shape)
+    ...   result.set_shape([None, None])
+    ...   print("Final shape: ", result.shape)
+    ...   return result
+
+    Trace the function:
+
+    >>> concrete_parse = my_parse.get_concrete_function(
+    ...     tf.TensorSpec([], dtype=tf.string))
+    Initial shape:  <unknown>
+    Final shape:  (None, None)
+
+    Make sure it works:
+
+    >>> t = tf.ones([5,3], dtype=tf.float32)
+    >>> serialized = tf.io.serialize_tensor(t)
+    >>> print(serialized.dtype)
+    <dtype: 'string'>
+    >>> print(serialized.shape)
+    ()
+    >>> t2 = concrete_parse(serialized)
+    >>> print(t2.shape)
+    (5, 3)
+
+    Caution: `set_shape` ensures that the applied shape is compatible with
+    the existing shape, but it does not check at runtime. Setting
+    incorrect shapes can result in inconsistencies between the
+    statically-known graph and the runtime value of tensors. For runtime
+    validation of the shape, use `tf.ensure_shape` instead; it also modifies
+    the static `shape` of the tensor.
+
+    >>> # Serialize a rank-3 tensor
+    >>> t = tf.ones([5,5,5], dtype=tf.float32)
+    >>> serialized = tf.io.serialize_tensor(t)
+    >>> # The function still runs, even though it called `set_shape([None,None])`
+    >>> t2 = concrete_parse(serialized)
+    >>> print(t2.shape)
+    (5, 5, 5)
 
     Args:
       shape: A `TensorShape` representing the shape of this tensor, a
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 242c41b2927..e8945f95ca1 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -2134,33 +2134,85 @@ def assert_scalar(tensor, name=None, message=None):
 def ensure_shape(x, shape, name=None):
   """Updates the shape of a tensor and checks at runtime that the shape holds.
-  For example:
-  ```python
-  x = tf.compat.v1.placeholder(tf.int32)
-  print(x.shape)
-  ==> TensorShape(None)
-  y = x * 2
-  print(y.shape)
-  ==> TensorShape(None)
+  With eager execution this is a shape assertion that returns its input:
 
-  y = tf.ensure_shape(y, (None, 3, 3))
-  print(y.shape)
-  ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])
+  >>> x = tf.constant([1,2,3])
+  >>> print(x.shape)
+  (3,)
+  >>> x = tf.ensure_shape(x, [3])
+  >>> x = tf.ensure_shape(x, [5])
+  Traceback (most recent call last):
+  ...
+  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
+  compatible with expected shape [5]. [Op:EnsureShape]
 
-  with tf.compat.v1.Session() as sess:
-    # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
-    # compatible with the shape (None, 3, 3)
-    sess.run(y, feed_dict={x: [1, 2, 3]})
+  Inside a `tf.function` or `v1.Graph` context it checks both the build-time
+  and runtime shapes. This is stricter than `tf.Tensor.set_shape`, which only
+  checks the build-time shape.
 
-  ```
-
-  NOTE: This differs from `Tensor.set_shape` in that it sets the static shape
+  Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape
   of the resulting tensor and enforces it at runtime, raising an error if the
   tensor's runtime shape is incompatible with the specified shape.
-  `Tensor.set_shape` sets the static shape of the tensor without enforcing it
+  `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it
   at runtime, which may result in inconsistencies between the statically-known
   shape of tensors and the runtime value of tensors.
 
+  For example, consider loading images of a known size:
+
+  >>> @tf.function
+  ... def decode_image(png):
+  ...   image = tf.image.decode_png(png, channels=3)
+  ...   # the `print` executes during tracing.
+  ...   print("Initial shape: ", image.shape)
+  ...   image = tf.ensure_shape(image,[28, 28, 3])
+  ...   print("Final shape: ", image.shape)
+  ...   return image
+
+  When tracing a function, no ops are executed, so shapes may be unknown.
+  See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
+  for details.
+
+  >>> concrete_decode = decode_image.get_concrete_function(
+  ...     tf.TensorSpec([], dtype=tf.string))
+  Initial shape:  (None, None, 3)
+  Final shape:  (28, 28, 3)
+
+  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
+  >>> image = tf.cast(image,tf.uint8)
+  >>> png = tf.image.encode_png(image)
+  >>> image2 = concrete_decode(png)
+  >>> print(image2.shape)
+  (28, 28, 3)
+
+  >>> image = tf.concat([image,image], axis=0)
+  >>> print(image.shape)
+  (56, 28, 3)
+  >>> png = tf.image.encode_png(image)
+  >>> image2 = concrete_decode(png)
+  Traceback (most recent call last):
+  ...
+  tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not
+  compatible with expected shape [28,28,3].
+
+  Caution: if you don't use the result of `tf.ensure_shape`, the check may not
+  run.
+
+  >>> @tf.function
+  ... def bad_decode_image(png):
+  ...   image = tf.image.decode_png(png, channels=3)
+  ...   # the `print` executes during tracing.
+  ...   print("Initial shape: ", image.shape)
+  ...   # BAD: forgot to use the returned tensor.
+  ...   tf.ensure_shape(image,[28, 28, 3])
+  ...   print("Final shape: ", image.shape)
+  ...   return image
+
+  >>> image = bad_decode_image(png)
+  Initial shape:  (None, None, 3)
+  Final shape:  (None, None, 3)
+  >>> print(image.shape)
+  (56, 28, 3)
+
   Args:
     x: A `Tensor`.
shape: A `TensorShape` representing the shape of this tensor, a From ac3165622406999e2a2f47b6f9be3cf742396123 Mon Sep 17 00:00:00 2001 From: Anna R Date: Tue, 17 Mar 2020 13:12:19 -0700 Subject: [PATCH 093/492] Temporarily disable multiple symbol check to fix windows build. PiperOrigin-RevId: 301437819 Change-Id: Ifbf9541a2f77bb3addb187f9e10edfc98232d087 --- tensorflow/python/util/tf_export.py | 6 +++--- tensorflow/python/util/tf_export_test.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index 04c96d03617..e4d6bebc3db 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -341,9 +341,9 @@ class api_export(object): # pylint: disable=invalid-name # their own _tf_api_names as opposed to just inheriting it. if api_names_attr in func.__dict__: if not self._allow_multiple_exports: - raise SymbolAlreadyExposedError( - 'Symbol %s is already exposed as %s.' % - (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access + # TODO(annarev): temporarily removing check to fix builds. + # Need to investigate why symbols get reported multiple times. + return setattr(func, api_names_attr, names) def export_constant(self, module_name, name): diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py index 20625792e9b..51d9901fdb4 100644 --- a/tensorflow/python/util/tf_export_test.py +++ b/tensorflow/python/util/tf_export_test.py @@ -152,7 +152,8 @@ class ValidateExportTest(test.TestCase): (('NAME_E', 'NAME_F'), 0.5)], module2._tf_api_constants) - def testRaisesExceptionIfAlreadyHasAPINames(self): + # TODO(b/151745456): re-enable + def DISABLED_testRaisesExceptionIfAlreadyHasAPINames(self): _test_function._tf_api_names = ['abc'] export_decorator = tf_export.tf_export('nameA', 'nameB') with self.assertRaises(tf_export.SymbolAlreadyExposedError): From 6ae6b65d2f87177a9fa65f7cd48ed3127a154aaf Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Tue, 17 Mar 2020 13:27:04 -0700 Subject: [PATCH 094/492] Update docs to refer to dynamic-range quantization universally. 
PiperOrigin-RevId: 301440536 Change-Id: Ie54aba7649aed76d9c6e61e4e4a37cd07ffab82b --- .../g3doc/performance/model_optimization.md | 2 +- .../performance/post_training_quantization.md | 38 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tensorflow/lite/g3doc/performance/model_optimization.md b/tensorflow/lite/g3doc/performance/model_optimization.md index befc40aa738..5a5772b4a1f 100644 --- a/tensorflow/lite/g3doc/performance/model_optimization.md +++ b/tensorflow/lite/g3doc/performance/model_optimization.md @@ -89,7 +89,7 @@ The following types of quantization are available in TensorFlow Lite: Technique | Data requirements | Size reduction | Accuracy | Supported hardware -------------------------------------------------------------------------------------------------------------- | -------------------------------- | -------------- | --------------------------- | ------------------ [Post-training float16 quantization](post_training_float16_quant.ipynb) | No data | Up to 50% | Insignificant accuracy loss | CPU, GPU -[Post-training weight quantization](post_training_quant.ipynb) | No data | Up to 75% | Accuracy loss | CPU +[Post-training dynamic range quantization](post_training_quant.ipynb) | No data | Up to 75% | Accuracy loss | CPU [Post-training integer quantization](post_training_integer_quant.ipynb) | Unlabelled representative sample | Up to 75% | Smaller accuracy loss | CPU, EdgeTPU, Hexagon DSP [Quantization-aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize) | Labelled training data | Up to 75% | Smallest accuracy loss | CPU, EdgeTPU, Hexagon DSP diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md index 7712cffc83f..194d102d43d 100644 --- a/tensorflow/lite/g3doc/performance/post_training_quantization.md +++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md @@ -14,13 +14,13 @@ Note: The procedures on this page require TensorFlow 1.15 or higher. There are several post-training quantization options to choose from. Here is a summary table of the choices and the benefits they provide: -| Technique | Benefits | Hardware | -| -------------------------- | ------------------------- | ------------------- | -| Weight quantization | 4x smaller, 2-3x speedup, | CPU | -: : accuracy : : -| Full integer quantization | 4x smaller, 3x+ speedup | CPU, Edge TPU, etc. | -| Float16 quantization | 2x smaller, potential GPU | CPU/GPU | -: : acceleration : : +| Technique | Benefits | Hardware | +| ------------------------- | ------------------------- | ------------------- | +| Dynamic range | 4x smaller, 2-3x speedup, | CPU | +: quantization : accuracy : : +| Full integer quantization | 4x smaller, 3x+ speedup | CPU, Edge TPU, etc. | +| Float16 quantization | 2x smaller, potential GPU | CPU/GPU | +: : acceleration : : This decision tree can help determine which post-training quantization method is best for your use case: @@ -34,29 +34,29 @@ However, doing so requires some model modifications to add fake quantization nodes, whereas the post-training quantization techniques on this page use an existing pre-trained model. +### Dynamic range quantization -### Weight quantization - -The simplest form of post-training quantization quantizes only the weights from -floating point to 8-bits of precision (also called "hybrid" quantization). 
This -technique is enabled as an option in the [TensorFlow Lite -converter](../convert/): +The simplest form of post-training quantization statically quantizes only the +weights from floating point to 8-bits of precision. This technique is enabled as +an option in the [TensorFlow Lite converter](../convert/): ``` import tensorflow as tf converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) -converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] +converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_quant_model = converter.convert() ``` At inference, weights are converted from 8-bits of precision to floating point and computed using floating-point kernels. This conversion is done once and cached to reduce latency. -To further improve latency, hybrid operators dynamically quantize activations to 8-bits and -perform computations with 8-bit weights and activations. This optimization provides latencies -close to fully fixed-point inference. However, the outputs are still stored using -floating point, so that the speedup with hybrid ops is less than a full fixed-point computation. -Hybrid ops are available for the most compute-intensive operators in a network: +To further improve latency, "dynamic-range" operators dynamically quantize +activations based on their range to 8-bits and perform computations with 8-bit +weights and activations. This optimization provides latencies close to fully +fixed-point inference. However, the outputs are still stored using floating +point, so that the speedup with dynamic-range ops is less than a full +fixed-point computation. Dynamic-range ops are available for the most +compute-intensive operators in a network: * [tf.contrib.layers.fully_connected](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/fully_connected) * [tf.nn.conv2d](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d) From 3a15248de39e08d4618960c24e98804f93cf59d1 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Tue, 17 Mar 2020 13:41:56 -0700 Subject: [PATCH 095/492] [tf.data] Adding support for overriding external state policy for checkpointing. 
PiperOrigin-RevId: 301443563 Change-Id: I852269b86039a71466ddeadfe3ce03d75dc45fda --- tensorflow/core/framework/dataset.h | 43 +++++++++++++++++-- .../core/kernels/data/batch_dataset_op.cc | 5 ++- .../core/kernels/data/cache_dataset_ops.cc | 23 ++++++---- .../kernels/data/concatenate_dataset_op.cc | 5 ++- .../assert_cardinality_dataset_op.cc | 5 ++- .../experimental/assert_next_dataset_op.cc | 5 ++- .../choose_fastest_branch_dataset_op.cc | 10 +++-- .../experimental/choose_fastest_dataset_op.cc | 7 +-- .../data/experimental/csv_dataset_op.cc | 3 +- .../dense_to_sparse_batch_dataset_op.cc | 5 ++- .../directed_interleave_dataset_op.cc | 7 +-- .../group_by_reducer_dataset_op.cc | 20 +++++---- .../group_by_window_dataset_op.cc | 18 ++++---- .../experimental/ignore_errors_dataset_op.cc | 5 ++- .../data/experimental/lmdb_dataset_op.cc | 3 +- .../experimental/map_and_batch_dataset_op.cc | 8 ++-- .../non_serializable_dataset_op.cc | 5 ++- .../parallel_interleave_dataset_op.cc | 15 ++++--- .../data/experimental/random_dataset_op.cc | 3 +- .../data/experimental/rebatch_dataset_op.cc | 5 ++- .../data/experimental/sampling_dataset_op.cc | 5 ++- .../data/experimental/scan_dataset_op.cc | 8 ++-- .../set_stats_aggregator_dataset_op.cc | 5 ++- .../data/experimental/sleep_dataset_op.cc | 5 ++- .../experimental/sliding_window_dataset_op.cc | 5 ++- .../data/experimental/snapshot_dataset_op.cc | 18 +++++--- .../data/experimental/sql_dataset_op.cc | 3 +- .../data/experimental/stats_dataset_ops.cc | 10 +++-- .../experimental/take_while_dataset_op.cc | 8 ++-- .../experimental/threadpool_dataset_op.cc | 15 ++++--- .../data/experimental/unbatch_dataset_op.cc | 5 ++- .../data/experimental/unique_dataset_op.cc | 5 ++- .../core/kernels/data/filter_dataset_op.cc | 8 ++-- .../data/fixed_length_record_dataset_op.cc | 6 ++- .../core/kernels/data/flat_map_dataset_op.cc | 10 +++-- .../core/kernels/data/generator_dataset_op.cc | 3 +- .../kernels/data/interleave_dataset_op.cc | 15 ++++--- .../core/kernels/data/map_dataset_op.cc | 8 ++-- .../core/kernels/data/model_dataset_op.cc | 5 ++- .../kernels/data/padded_batch_dataset_op.cc | 5 ++- .../data/parallel_interleave_dataset_op.cc | 27 +++++++----- .../kernels/data/parallel_map_dataset_op.cc | 8 ++-- .../core/kernels/data/prefetch_dataset_op.cc | 5 ++- .../core/kernels/data/range_dataset_op.cc | 3 +- .../core/kernels/data/repeat_dataset_op.cc | 13 +++--- .../core/kernels/data/shard_dataset_op.cc | 5 ++- .../core/kernels/data/shuffle_dataset_op.cc | 15 ++++--- .../core/kernels/data/skip_dataset_op.cc | 8 ++-- .../data/sparse_tensor_slice_dataset_op.cc | 3 +- .../core/kernels/data/take_dataset_op.cc | 8 ++-- .../core/kernels/data/tensor_dataset_op.cc | 3 +- .../kernels/data/tensor_slice_dataset_op.cc | 3 +- .../core/kernels/data/text_line_dataset_op.cc | 3 +- .../core/kernels/data/tf_record_dataset_op.cc | 3 +- .../core/kernels/data/window_dataset.cc | 3 +- .../core/kernels/data/window_dataset_op.cc | 5 ++- .../core/kernels/data/zip_dataset_op.cc | 5 ++- .../kernel_tests/auto_shard_dataset_test.py | 3 +- .../data/experimental/ops/distribute.py | 7 ++- .../data/kernel_tests/checkpoint_test.py | 32 ++++++++++++++ tensorflow/python/data/ops/dataset_ops.py | 25 +++++------ tensorflow/python/data/ops/iterator_ops.py | 12 +++++- 62 files changed, 359 insertions(+), 192 deletions(-) diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index a3e5f87f66b..92f7a52b632 100644 --- a/tensorflow/core/framework/dataset.h +++ 
b/tensorflow/core/framework/dataset.h @@ -481,6 +481,25 @@ class SerializationContext { kFail = 2, }; + // Handles the CheckExternalState status according to the external state + // policy. + Status HandleCheckExternalStateStatus(Status s) { + if (s.ok()) { + return s; + } + switch (params_.external_state_policy) { + case ExternalStatePolicy::kWarn: + LOG(WARNING) << s.ToString(); + return Status::OK(); + case ExternalStatePolicy::kIgnore: + VLOG(2) << "Ignoring error status: " << s.ToString(); + return Status::OK(); + case ExternalStatePolicy::kFail: + return s; + } + LOG(FATAL) << "Control should never reach here"; + } + struct Params { std::vector>* input_list = nullptr; // Not owned. @@ -589,7 +608,7 @@ class IteratorBase { // Saves the state of this iterator. virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) { - return SaveInternal(writer); + return SaveInternal(ctx, writer); } protected: @@ -604,9 +623,17 @@ class IteratorBase { // This is needed so that sub-classes of IteratorBase can call // `SaveInternal` on their input iterators. + Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer, + const std::unique_ptr& input) { + return input->SaveInternal(ctx, writer); + } + + // TODO(jsimsa): Remove this override when all callers are migrated to the + // override that uses SerializationContext. Status SaveInput(IteratorStateWriter* writer, const std::unique_ptr& input) { - return input->SaveInternal(writer); + SerializationContext ctx(/*params=*/{}); + return input->SaveInternal(&ctx, writer); } // This is needed so that sub-classes of IteratorBase can call @@ -620,7 +647,17 @@ class IteratorBase { // // This method is used to store the state of the iterator in a checkpoint. // implementations have an override. - virtual Status SaveInternal(IteratorStateWriter* writer) = 0; + virtual Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) { + return SaveInternal(writer); + } + + // TODO(jsimsa): Remove this override when all subclasses are migrated to the + // override that accepts SerializationContext and make that override pure + // virtual. + virtual Status SaveInternal(IteratorStateWriter* writer) { + return errors::Unimplemented("checkpointing is not supported"); + } // Restores the state of this iterator. 
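+  // For illustration, a subclass override of the context-aware SaveInternal
+  // might look like the sketch below. The member names (`captured_func_`,
+  // `input_impl_`, `counter_`) are placeholders borrowed from typical dataset
+  // kernels, not part of this interface.
+  //
+  //   Status SaveInternal(SerializationContext* ctx,
+  //                       IteratorStateWriter* writer) override {
+  //     // Surface, warn about, or ignore external state according to the
+  //     // policy carried by the serialization context.
+  //     TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
+  //         dataset()->captured_func_->CheckExternalState()));
+  //     // Forward the same context to the input iterator.
+  //     TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
+  //     return writer->WriteScalar(full_name("counter"), counter_);
+  //   }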
// diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 0d454a0abf2..c915f80c2c6 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -257,12 +257,13 @@ class BatchDatasetOp::Dataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), "")); } else { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } return Status::OK(); } diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index e773efc6c2e..f99ac114dc2 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -164,10 +164,11 @@ class CacheDatasetOp::FileDataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_)); - return SaveInput(writer, iterator_); + return SaveInput(ctx, writer, iterator_); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { @@ -303,7 +304,8 @@ class CacheDatasetOp::FileDataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kCurIndex), cur_index_)); @@ -333,7 +335,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase { lockfile_ = strings::StrCat(filename_, kLockFileSuffix); lockfile_created_ = false; } - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kShardId), shard_id_)); return Status::OK(); } @@ -532,7 +534,8 @@ class CacheDatasetOp::FileDataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kCurIndex), cur_index_)); @@ -785,14 +788,15 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (cache_->IsCompleted()) { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), "")); TF_RETURN_IF_ERROR(SaveCache( writer, cache_, [this](const string& s) { return full_name(s); })); } - return SaveInput(writer, iterator_); + return SaveInput(ctx, writer, iterator_); } Status RestoreInternal(IteratorContext* ctx, @@ -867,14 +871,15 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!cache_->IsCompleted()) { 
TF_RETURN_IF_ERROR( SaveCache(writer, &temp_cache_, [this](const string& s) { return full_name(s); })); } - return SaveInput(writer, input_impl_); + return SaveInput(ctx, writer, input_impl_); } Status RestoreInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index 4cf1228f0fd..34faafeb178 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -146,11 +146,12 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), i_)); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kInputImplUninitialized), "")); diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc index c4dfaf78d7c..1dd38dcaa04 100644 --- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc @@ -124,10 +124,11 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("num_elements"), num_elements_)); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc index dea2ac8efe8..0fe35ed4b15 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -119,8 +119,9 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc index b5f89996b5d..e911d82b7d4 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc @@ -105,7 +105,8 @@ class WrapperDataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1.0); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return Status::OK(); } @@ -393,9 +394,10 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { // TODO(rachelim): Save and restore histogram state as well. 
Currently, // if an iterator is saved and restored, the histograms start recording // from scratch. - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"), experiment_counter_)); TF_RETURN_IF_ERROR( @@ -403,7 +405,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("fastest_index"), fastest_index_)); if (current_iterator_) { - TF_RETURN_IF_ERROR(SaveInput(writer, current_iterator_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_iterator_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc index 6531e766183..f346dcc70c3 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc @@ -238,7 +238,8 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { // TODO(rachelim): Save and restore histogram state as well. Currently, // if an iterator is saved and restored, the histograms start recording // from scratch. - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"), experiment_counter_)); @@ -246,13 +247,13 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("fastest_index"), fastest_index_)); if (fastest_index_ != -1) { - TF_RETURN_IF_ERROR(SaveInput(writer, fastest_input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, fastest_input_impl_)); } else if (input_impls_.empty()) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impls_empty"), "")); } else { for (auto& input_impl : input_impls_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl)); } } return Status::OK(); diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc index 8b2d3343c78..8d1bd7acfd9 100644 --- a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc @@ -269,7 +269,8 @@ class CSVDatasetOp : public DatasetOpKernel { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), current_file_index_)); diff --git a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc index b0aee7a2af2..d09922988b9 100644 --- a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc @@ -289,9 +289,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { 
DatasetIterator>::dataset()->batch_size_); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(Iterator::SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(Iterator::SaveInput(ctx, writer, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index 630efbebb6f..48a446be42c 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -212,10 +212,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { return model::MakeInterleaveManyNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (selector_input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("selector_input_impl_empty"), "")); @@ -223,7 +224,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { for (size_t i = 0; i < data_input_impls_.size(); ++i) { const auto& data_input_impl = data_input_impls_[i]; if (data_input_impl) { - TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl)); } else { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat("data_input_impl_empty[", i, "]")), diff --git a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc index cde4fe2a591..7418fb2c9a3 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc @@ -285,16 +285,18 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_key_func_->CheckExternalState()); - TF_RETURN_IF_ERROR( - dataset()->captured_init_func_->CheckExternalState()); - TF_RETURN_IF_ERROR( - dataset()->captured_reduce_func_->CheckExternalState()); - TF_RETURN_IF_ERROR( - dataset()->captured_finalize_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_key_func_->CheckExternalState())); + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_init_func_->CheckExternalState())); + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_reduce_func_->CheckExternalState())); + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_finalize_func_->CheckExternalState())); mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); if (end_of_input_) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index 51b54019c0b..462f8ce6ef7 
100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -294,14 +294,16 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_key_func_->CheckExternalState()); - TF_RETURN_IF_ERROR( - dataset()->captured_reduce_func_->CheckExternalState()); - TF_RETURN_IF_ERROR( - dataset()->captured_window_size_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_key_func_->CheckExternalState())); + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_reduce_func_->CheckExternalState())); + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_window_size_func_->CheckExternalState())); mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); if (end_of_input_) { TF_RETURN_IF_ERROR( @@ -342,7 +344,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } if (current_group_iterator_) { - TF_RETURN_IF_ERROR(SaveInput(writer, current_group_iterator_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_group_iterator_)); // Saving current_key_ TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc index e7a1675b664..e177fe27d18 100644 --- a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc @@ -115,10 +115,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); else TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impls_empty"), "")); diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc index d661623dd00..1852bc51407 100644 --- a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc @@ -134,7 +134,8 @@ class LMDBDatasetOp : public DatasetOpKernel { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return errors::Unimplemented( "Checkpointing is currently not supported for LMDBDataset."); } diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index a21a97b762a..c016711bedc 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -251,15 +251,17 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { /*max=*/ctx->runner_threadpool_size())}); } - Status SaveInternal(IteratorStateWriter* writer) override { - 
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); mutex_lock l(*mu_); // Wait for all in-flight calls to complete. while (num_calls_ > 0) { cond_var_->wait(l); } DCHECK_EQ(num_calls_, 0); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kCallCounter), call_counter_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kBatchResultsSize), diff --git a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc index 3341827abdc..1e752931157 100644 --- a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc @@ -106,8 +106,9 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel { return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index 74482d8f3e0..c4b34b5163c 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -396,13 +396,15 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { /*parameters=*/{}); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); // The order of locking is important here to avoid deadlock. 
mutex_lock l(mu_); mutex_lock ckpt_l(ckpt_mu_); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } else { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInputExhausted, "")); } @@ -416,7 +418,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i)); } for (int i = 0; i < worker_thread_states_.size(); ++i) { - TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i)); + TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(ctx, writer, i)); } TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInterleaveSize, interleave_indices_.size())); @@ -932,13 +934,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return Status::OK(); } - Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer, int index) + Status WriteWorkerThreadStateLocked(SerializationContext* ctx, + IteratorStateWriter* writer, int index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) { string iterator_name = strings::StrCat(prefix(), "::", kWorkerThread, "_", index); if (worker_thread_states_[index].iterator != nullptr) { TF_RETURN_IF_ERROR( - SaveInput(writer, worker_thread_states_[index].iterator)); + SaveInput(ctx, writer, worker_thread_states_[index].iterator)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(iterator_name, kIteratorExhausted, "")); diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc index e1d2b5afafe..460c18ce7a3 100644 --- a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc @@ -104,7 +104,8 @@ class RandomDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"), num_random_samples_)); diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc index 5f224b8a5f4..dd7084d0b26 100644 --- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc @@ -179,13 +179,14 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); } else { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("slice_number"), slice_number_)); diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc index 36f195d1d1e..00869eea85c 100644 --- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc @@ -145,7 +145,8 @@ class SamplingDatasetOp::Dataset : public DatasetBase { generator_.Skip(num_random_samples_); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* 
ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); // Save state needed to restore the random number generators. TF_RETURN_IF_ERROR(writer->WriteScalar( @@ -156,7 +157,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase { writer->WriteScalar(this->full_name("seed2"), seeds_.second)); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); diff --git a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc index 9c2f4c4b403..2b7ece1661b 100644 --- a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc @@ -248,10 +248,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); if (!state_.empty()) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("state_size"), state_.size())); diff --git a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc index a3a43085316..e96de29d759 100644 --- a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc @@ -199,9 +199,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); - return SaveInput(writer, input_impl_); + return SaveInput(ctx, writer, input_impl_); } Status RestoreInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc index ec13bb8108c..f2195804cfd 100644 --- a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc @@ -142,8 +142,9 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { - return SaveInput(writer, input_impl_); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + return SaveInput(ctx, writer, input_impl_); } Status RestoreInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc index f6be3c08d6d..04ebd5bfd34 100644 --- a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc @@ -239,13 +239,14 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { dataset()->window_shift_); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + 
IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); } else { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } // Save buffer. TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"), diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index d7eff8df710..e9873fd226e 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -389,9 +389,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, iterator_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_)); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kState), static_cast(state_))); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kHashDir), hash_dir_)); @@ -623,7 +624,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kHashDir), hash_dir_)); @@ -1037,9 +1039,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); if (end_of_sequence_) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kEndOfSequence), "")); @@ -1482,8 +1485,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(IteratorStateWriter* writer) override { - return SaveInput(writer, input_impl_); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + return SaveInput(ctx, writer, input_impl_); } Status RestoreInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc index 976cb1e87ba..f6720aa1c88 100644 --- a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc @@ -158,7 +158,8 @@ class SqlDatasetOp : public DatasetOpKernel { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (query_connection_initialized_) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc index bdf7a29ca26..08d208fc340 100644 --- a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc @@ -128,9 +128,10 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status 
SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } @@ -246,9 +247,10 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc index 3868c65af5e..fd4b4fccb7e 100644 --- a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc @@ -166,11 +166,13 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); mutex_lock l(mu_); if (input_impl_) - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); else TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impls_empty"), "")); diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index 26a58d10593..65252e3dbcf 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -216,10 +216,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); DCHECK(input_impl_ != nullptr); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } @@ -348,10 +349,11 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); DCHECK(input_impl_ != nullptr); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } @@ -470,10 +472,11 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); DCHECK(input_impl_ != nullptr); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc index c19ad6ca7ae..111afa218df 100644 --- a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc 
+++ b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc @@ -164,10 +164,11 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc index dd35f44f9ac..a4319234082 100644 --- a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc @@ -103,10 +103,11 @@ class UniqueDatasetOp::Dataset : public DatasetBase { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index f7828e60fe4..1301aed3cb4 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -194,11 +194,13 @@ class FilterDatasetOp::Dataset : public DatasetBase { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); mutex_lock l(mu_); if (input_impl_) - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); else TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kInputImplsEmpty), "")); diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index 2aa2d5e4ce2..15bfeb01a65 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc @@ -190,7 +190,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { } protected: - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex), current_file_index_)); @@ -374,7 +375,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex), current_file_index_)); diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc 
b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index dca4973f0ce..eba5097a1bb 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -162,11 +162,13 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { return model::MakeInterleaveManyNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); mutex_lock l(mu_); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kElementIndex), element_index_)); if (current_element_iterator_) { @@ -178,7 +180,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { full_name(strings::StrCat(kCapturedFuncInputs, "[", i, "]")), captured_func_inputs_[i])); } - TF_RETURN_IF_ERROR(SaveInput(writer, current_element_iterator_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_element_iterator_)); } else { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(kCurrentElementIteratorUninitialized), "")); diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 8eca7057bec..fcdbe4ab9a5 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -158,7 +158,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return errors::Unimplemented( "GeneratorDataset does not support checkpointing."); } diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 312aaa5219c..0a795c1cf82 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -195,10 +195,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { return model::MakeInterleaveManyNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kCycleIndex), cycle_index_)); TF_RETURN_IF_ERROR( @@ -207,7 +209,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kEndOfInput), "")); } TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumOpen), num_open_)); - TF_RETURN_IF_ERROR(SaveCurrentElements(writer)); + TF_RETURN_IF_ERROR(SaveCurrentElements(ctx, writer)); return Status::OK(); } @@ -234,11 +236,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { } private: - Status SaveCurrentElements(IteratorStateWriter* writer) + Status SaveCurrentElements(SerializationContext* ctx, + 
IteratorStateWriter* writer) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { for (int idx = 0; idx < current_elements_.size(); idx++) { if (current_elements_[idx]) { - TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx])); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_elements_[idx])); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat(kArgsSize, "[", idx, "]")), args_list_[idx].size())); diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 1bbf277ee58..cbd0aa093e5 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -174,9 +174,11 @@ class MapDatasetOp::Dataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 4a227078a78..87e61e1d37c 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -151,9 +151,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index 12800c27eff..a35fb2c3952 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -343,10 +343,11 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); else TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), "")); return Status::OK(); diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index c744977c8bd..0a58a082948 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -377,8 +377,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // TODO(aaudibert): Refactor the implementations to avoid the need for // `IteratorContext` when saving the state of the iterator. 
- Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + dataset()->captured_func_->CheckExternalState())); mutex_lock l(*mu_); wait_for_checkpoint_ = true; // Wait for all in-flight calls to complete. @@ -400,7 +402,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { wait_for_checkpoint_ = false; DCHECK_EQ(num_active_workers_, 0); VLOG(4) << "State before save:\n" << DebugString(); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kBlockIndex, block_index_)); TF_RETURN_IF_ERROR( @@ -410,8 +412,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kElementIdCounter, element_id_counter_)); - TF_RETURN_IF_ERROR(WriteCurrentElements(writer)); - TF_RETURN_IF_ERROR(WriteFutureElements(writer)); + TF_RETURN_IF_ERROR(WriteCurrentElements(ctx, writer)); + TF_RETURN_IF_ERROR(WriteFutureElements(ctx, writer)); // Wake workers back up. current_workers_cond_var_.notify_all(); future_workers_cond_var_.notify_all(); @@ -1124,13 +1126,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::StrCat(kResultsSuffix, "[", idx, "]", kErrorMessageSuffix); } - Status WriteElement(std::shared_ptr element, int idx, + Status WriteElement(SerializationContext* ctx, + std::shared_ptr element, int idx, const string& key_prefix, IteratorStateWriter* writer) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { const auto& iterator_name = absl::StrCat(prefix(), "::", key_prefix, "::", idx); if (element->iterator) { - TF_RETURN_IF_ERROR(SaveInput(writer, element->iterator)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, element->iterator)); TF_RETURN_IF_ERROR( writer->WriteScalar(iterator_name, kIdSuffix, element->id)); TF_RETURN_IF_ERROR(writer->WriteScalar( @@ -1165,26 +1168,28 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return Status::OK(); } - Status WriteCurrentElements(IteratorStateWriter* writer) + Status WriteCurrentElements(SerializationContext* ctx, + IteratorStateWriter* writer) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurrentElementsSize, current_elements_.size())); for (int idx = 0; idx < current_elements_.size(); idx++) { if (current_elements_[idx]) { - TF_RETURN_IF_ERROR(WriteElement(current_elements_[idx], idx, + TF_RETURN_IF_ERROR(WriteElement(ctx, current_elements_[idx], idx, kCurrentElements, writer)); } } return Status::OK(); } - Status WriteFutureElements(IteratorStateWriter* writer) + Status WriteFutureElements(SerializationContext* ctx, + IteratorStateWriter* writer) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kFutureElementsSize, future_elements_.size())); for (int idx = 0; idx < future_elements_.size(); idx++) { if (future_elements_[idx]) { - TF_RETURN_IF_ERROR(WriteElement(future_elements_[idx], idx, + TF_RETURN_IF_ERROR(WriteElement(ctx, future_elements_[idx], idx, kFutureElements, writer)); } } diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 55b29ed2a08..ca547fb6339 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ 
b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -377,8 +377,10 @@ class ParallelMapIterator : public DatasetBaseIterator { /*max=*/ctx->runner_threadpool_size())}); } - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(parallel_map_functor_->CheckExternalState()); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( + parallel_map_functor_->CheckExternalState())); mutex_lock l(*mu_); // Wait for all in-flight calls to complete. while (num_calls_ > 0) { @@ -388,7 +390,7 @@ class ParallelMapIterator : public DatasetBaseIterator { return errors::FailedPrecondition( "Unexpected outstanding calls encountered."); } - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat(kInvocationResults, kSizeSuffix)), invocation_results_.size())); diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 27c2ca57854..0c3fe43d6c3 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -206,12 +206,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { /*max=*/std::numeric_limits::max())}); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { // Acquire both locks to ensure that the prefetch thread and // all GetNext threads are blocked. mutex_lock input_l(input_mu_); mutex_lock l(*mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kBufferSize, buffer_.size())); for (size_t i = 0; i < buffer_.size(); i++) { diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index a03fe82d815..f6993ab2797 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -132,7 +132,8 @@ class RangeDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNext), next_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index d968661b1c6..dd6a0e9d03e 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -124,7 +124,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase { /*ratio=*/kKnownRatio); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, @@ -172,13 +173,14 @@ class RepeatDatasetOp::Dataset : public DatasetBase { /*ratio=*/kKnownRatio); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIteration), i_)); if (!input_impl_) { 
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), "")); } else { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } return Status::OK(); } @@ -249,10 +251,11 @@ class RepeatDatasetOp::Dataset : public DatasetBase { /*ratio=*/kKnownRatio); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!first_call_) - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); else TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kUninitialized), "")); return Status::OK(); diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc index 9d6f81ced96..03c9525a7ab 100644 --- a/tensorflow/core/kernels/data/shard_dataset_op.cc +++ b/tensorflow/core/kernels/data/shard_dataset_op.cc @@ -169,12 +169,13 @@ class ShardDatasetOp::Dataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), dataset()->num_shards_); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), "")); } else { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kNextIndex), next_index_)); } diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 337c82c1c61..26f68431203 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -278,7 +278,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { generator_.Skip(num_random_samples_); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); // Save state needed to restore the random number generators. TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kNumRandomSamples), @@ -292,7 +293,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(this->full_name(kEndOfInputSequence), "")); } else { - TF_RETURN_IF_ERROR(this->SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(this->SaveInput(ctx, writer, input_impl_)); } // Save the epoch counter, buffer, and buffer slices. @@ -526,14 +527,15 @@ class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { // Save RNG state of Dataset. TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kDSNumRandomSamples), seed_generator_->num_random_samples())); // Save the Iterator. 
- return ShuffleDatasetBase::Iterator::SaveInternal(writer); + return ShuffleDatasetBase::Iterator::SaveInternal(ctx, writer); } Status RestoreInternal(IteratorContext* ctx, @@ -634,14 +636,15 @@ class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase { return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { // Save state of the seed generator. TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kDSNumRandomSamples), seed_generator_->num_random_samples())); // Save the iterator state. - return ShuffleDatasetBase::Iterator::SaveInternal(writer); + return ShuffleDatasetBase::Iterator::SaveInternal(ctx, writer); } Status RestoreInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 9a3e8a9c980..952d5cae97b 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -110,7 +110,8 @@ class SkipDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return Status::OK(); } @@ -172,11 +173,12 @@ class SkipDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_)); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } else { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), "")); } diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index adcdc954ffc..1e3ed53d6c6 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -161,7 +161,8 @@ class Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(Iterator::full_name("i"), i_)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 05ed06f459e..627467f291b 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -95,7 +95,8 @@ class TakeDataset::EmptyIterator : public DatasetIterator { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return Status::OK(); } @@ -142,11 +143,12 @@ class TakeDataset::FiniteIterator : public DatasetIterator { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_)); if (input_impl_) { -
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } else { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), "")); } diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index 8abdbe7b757..20540cf9a57 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -117,7 +117,8 @@ class TensorDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (produced_) TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kProduced), "")); diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index a2e4222033e..8831f8d548d 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -136,7 +136,8 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc index 7f6c3b20f0e..c2c3190bd7f 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc @@ -139,7 +139,8 @@ class TextLineDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex), current_file_index_)); diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc index c63d211d926..a72d05c5155 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc @@ -156,7 +156,8 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex), current_file_index_)); diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index 2207577f8db..ad6273f8941 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -97,7 +97,8 @@ class WindowDataset : public DatasetBase { return Status::OK(); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 7db3f0a6a5b..35437a9231c 100644 --- 
a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -255,12 +255,13 @@ class WindowDatasetOp::Dataset : public DatasetBase { dataset()->window_shift_); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), "")); } else { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } // Save buffer. TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index 8b6f9a88d00..b59dc2c3a22 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -174,14 +174,15 @@ class ZipDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impls_.empty()) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kInputImplsEmpty), "")); } else { for (auto& input_impl : input_impls_) - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl)); } return Status::OK(); } diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py index 5f13bdae849..b2b348a436e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py @@ -258,8 +258,7 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, graph_def = dataset._as_serialized_graph( strip_device_assignment=True, - external_state_policy= - dataset.options().experimental_external_state_policy) + external_state_policy=distribute_options.ExternalStatePolicy.WARN) options = dataset_ops.Options() options.experimental_distribute.auto_shard_policy = sharding_policy diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py index 1d7e642e5e7..206f1b1df35 100644 --- a/tensorflow/python/data/experimental/ops/distribute.py +++ b/tensorflow/python/data/experimental/ops/distribute.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.compat import compat from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy +from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import ops @@ -162,10 +163,12 @@ def replicate(dataset, devices): with ops.colocate_with(dataset._variant_tensor): dataset = dataset._apply_options() - external_state_policy = dataset.options().experimental_external_state_policy + policy = dataset.options().experimental_external_state_policy + if policy is None: + policy = ExternalStatePolicy.WARN graph_def = dataset._as_serialized_graph( strip_device_assignment=True, - external_state_policy=external_state_policy) + external_state_policy=policy) for device in devices: ds = _RemoteDataset(graph_def, device, dataset.element_spec) datasets[device] = ds diff --git 
a/tensorflow/python/data/kernel_tests/checkpoint_test.py b/tensorflow/python/data/kernel_tests/checkpoint_test.py index b0c9a77dd1e..1341744d221 100644 --- a/tensorflow/python/data/kernel_tests/checkpoint_test.py +++ b/tensorflow/python/data/kernel_tests/checkpoint_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os from absl.testing import parameterized +from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.experimental.ops import grouping from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.experimental.ops import scan_ops @@ -35,6 +36,7 @@ from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import script_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -531,6 +533,36 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.apply(take_while_ops.take_while(self._statefulBoolFunc)) self._assertNotCheckpointable(dataset) + @combinations.generate(test_base.eager_only_combinations()) + def testStatefulExternalPolicy(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.range(4) + + def fn(x): + return x * x + + dataset = dataset.map( + lambda x: script_ops.eager_py_func(fn, [x], dtypes.int64)) + + options = dataset_ops.Options() + options.experimental_external_state_policy = ( + distribute_options.ExternalStatePolicy.WARN) + dataset = dataset.with_options(options) + + iterator = iter(dataset) + get_next = iterator.get_next + checkpoint = trackable_utils.Checkpoint(iterator=iterator) + self.assertEqual(0, get_next().numpy()) + self.assertEqual(1, get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertEqual(4, get_next().numpy()) + self.assertEqual(9, get_next().numpy()) + checkpoint.restore(save_path).run_restore_ops() + self.assertEqual(4, get_next().numpy()) + self.assertEqual(9, get_next().numpy()) + with self.assertRaises(errors.OutOfRangeError): + get_next() if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index b81de54f16e..32ab469363e 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -225,9 +225,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): serialized graph. 
""" if external_state_policy: - policy = None - if external_state_policy: - policy = external_state_policy.value + policy = external_state_policy.value return gen_dataset_ops.dataset_to_graph_v2( self._variant_tensor, external_state_policy=policy, @@ -1031,14 +1029,14 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): Example: If we had the following files on our filesystem: - + - /path/to/dir/a.txt - /path/to/dir/b.py - /path/to/dir/c.py - + If we pass "/path/to/dir/*.py" as the directory, the dataset would produce: - + - /path/to/dir/b.py - /path/to/dir/c.py @@ -2731,15 +2729,12 @@ class Options(options_lib.OptionsBase): experimental_external_state_policy = options_lib.create_option( name="experimental_external_state_policy", ty=distribute_options.ExternalStatePolicy, - docstring="By default, tf.data will refuse to serialize a dataset or " - "checkpoint its iterator if the dataset contains a stateful op as the " - "serialization / checkpointing won't be able to capture its state. " - "Users can -- at their own risk -- override this restriction by " - "explicitly specifying that they are fine throwing away the state " - "in these ops. There are three settings available - IGNORE: in which we" - "completely ignore any state; WARN: We warn the user that some state " - "might be thrown away; FAIL: We fail if any state is being captured.", - default_factory=lambda: distribute_options.ExternalStatePolicy.WARN) + docstring="This option can be used to override the default policy for " + "how to handle external state when serializing a dataset or " + "checkpointing its iterator. There are three settings available - " + "IGNORE: in which we completely ignore any state; WARN: We warn the " + "user that some state might be thrown away; FAIL: We fail if any state " + "is being captured.") def _graph_rewrites(self): """Produces the list of enabled static graph rewrites.""" diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 668af74acf6..09187705c16 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -743,7 +743,17 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor): def _gather_saveables_for_checkpoint(self): def _saveable_factory(name): - return _IteratorSaveable(self._iterator_resource, name) + """Returns a SaveableObject for serialization/deserialization.""" + policy = None + if self._dataset: + policy = self._dataset.options().experimental_external_state_policy + if policy: + return _IteratorSaveable( + self._iterator_resource, + name, + external_state_policy=policy) + else: + return _IteratorSaveable(self._iterator_resource, name) return {"ITERATOR": _saveable_factory} From bbe150901badaf042ae66bd0dd6857aca2ac42fd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 13:46:46 -0700 Subject: [PATCH 096/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301444456 Change-Id: I3d17afd682011faf048b2fdcb39f188da770a271 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 6456f104ad3..52a9bf9551b 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From 523f1d04204c15279935e534751860c85eb0686f Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Tue, 17 Mar 2020 14:00:19 -0700 Subject: [PATCH 097/492] Change HLO op names to match XLA op code names DivOp op code in XLA is "divide" instead of "div". Op names are still using the shorter name as that is what XLA builders are using. Similarly for MulOp, MaxOp, MinOp and SubOp. PiperOrigin-RevId: 301447200 Change-Id: I50f19562ba6ae2add404c2cca0acac7d97cfc955 --- .../quantization/xla/tests/materialize.mlir | 16 +- .../quantization/xla/tests/weight-only.mlir | 4 +- .../mlir/tensorflow/tests/legalize_hlo.mlir | 22 +-- .../compiler/mlir/xla/ir/hlo_client_ops.td | 10 +- tensorflow/compiler/mlir/xla/ir/hlo_ops.td | 10 +- tensorflow/compiler/mlir/xla/ir/lhlo_ops.td | 10 +- .../mlir/xla/tests/hlo-legalize-to-lhlo.mlir | 20 +-- .../xla/tests/hlo-legalize-to-linalg.mlir | 12 +- .../compiler/mlir/xla/tests/legalize-tf.mlir | 144 +++++++++--------- .../mlir/xla/tests/legalize-to-std.mlir | 24 +-- .../xla/tests/lhlo-legalize-to-affine.mlir | 22 +-- .../xla/tests/lhlo-legalize-to-linalg.mlir | 4 +- .../compiler/mlir/xla/tests/lhlo_ops.mlir | 12 +- .../mlir/xla/tests/lower-complex.mlir | 132 ++++++++-------- .../xla/tests/materialize-broadcasts.mlir | 20 +-- .../mlir/xla/tests/translate/export.mlir | 10 +- .../fully_connected_reference_model.hlotxt | 4 +- .../mlir/xla/tests/translate/import.hlotxt | 12 +- .../mlir/xla/tests/unfuse_batch_norm.mlir | 12 +- .../xla/transforms/hlo_legalize_to_lhlo.cc | 8 +- .../mlir/xla/transforms/legalize_tf.cc | 26 ++-- .../service/mlir_gpu/tests/add_multiply.hlo | 2 +- 22 files changed, 268 insertions(+), 268 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/materialize.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/materialize.mlir index ab93d9154ec..c731d72f752 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/materialize.mlir +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/materialize.mlir @@ -6,49 +6,49 @@ func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> { // CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32, // CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16> // CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32> -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[cast]] : tensor<2x4xf32> +// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[cast]] : tensor<2x4xf32> // CHECK-NEXT: return %[[mul]] : tensor<2x4xf32> %w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32> %q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform> %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform>) -> tensor<2x4xf32> - %mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32> + %mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32> return %mul: tensor<2x4xf32> } // CHECK-LABEL: func @quantize_small func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> { // CHECK: %[[w:.*]] = constant dense<1.000000e+00> : 
tensor<1x4xf32> -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<1x4xf32> +// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<1x4xf32> // CHECK-NEXT: return %[[mul]] : tensor<1x4xf32> %w = constant dense<1.0> : tensor<1x4xf32> %q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> %dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> - %mul = xla_hlo.mul %arg0, %dq : tensor<1x4xf32> + %mul = xla_hlo.multiply %arg0, %dq : tensor<1x4xf32> return %mul: tensor<1x4xf32> } // CHECK-LABEL: func @quantize_non_cst func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> { -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %arg0 : tensor<2x4xf32> +// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %arg0 : tensor<2x4xf32> // CHECK-NEXT: return %[[mul]] : tensor<2x4xf32> %q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform> %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform>) -> tensor<2x4xf32> - %mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32> + %mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32> return %mul: tensor<2x4xf32> } // CHECK-LABEL: func @quantize_non_4x func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> { // CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32> -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<2x5xf32> +// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<2x5xf32> // CHECK-NEXT: return %[[mul]] : tensor<2x5xf32> %w = constant dense<1.0> : tensor<2x5xf32> %q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform> %dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform>) -> tensor<2x5xf32> - %mul = xla_hlo.mul %arg0, %dq : tensor<2x5xf32> + %mul = xla_hlo.multiply %arg0, %dq : tensor<2x5xf32> return %mul: tensor<2x5xf32> } diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/weight-only.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/weight-only.mlir index ba384f10698..8f0936c41af 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/weight-only.mlir +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/weight-only.mlir @@ -5,10 +5,10 @@ func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32> // CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> // CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[dq]] : tensor<2x2xf32> +// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[dq]] : tensor<2x2xf32> // CHECK-NEXT: return %[[mul]] : tensor<2x2xf32> %w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32> - %mul = xla_hlo.mul %arg0, %w : tensor<2x2xf32> + %mul = xla_hlo.multiply %arg0, %w : tensor<2x2xf32> return %mul: tensor<2x2xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index c1b53debd7c..4f9e12736e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -34,7 +34,7 @@ return %0 : tensor<4x4x4x4xi32> // CHECK: return [[VAL_8]] : tensor<4x4x4x4xi32> func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> +%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> return %0 : 
tensor<2xi32> } // CHECK-LABEL: func @div( @@ -43,7 +43,7 @@ return %0 : tensor<2xi32> // CHECK: return [[VAL_10]] : tensor<2xi32> func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } // CHECK-LABEL: func @broadcast_div( @@ -61,7 +61,7 @@ return %0 : tensor<4xi32> // CHECK: return [[VAL_16]] : tensor<4xi32> func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { -%0 = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor +%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @div_dynamic( @@ -79,7 +79,7 @@ return %0 : tensor // CHECK: return [[VAL_22]] : tensor func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -%0 = xla_hlo.max %arg0, %arg1 : tensor<4xf32> +%0 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } // CHECK-LABEL: func @maximum( @@ -88,7 +88,7 @@ return %0 : tensor<4xf32> // CHECK: return [[VAL_25]] : tensor<4xf32> func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -%0 = xla_hlo.min %arg0, %arg1 : tensor<4xf32> +%0 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } // CHECK-LABEL: func @minimum( @@ -97,7 +97,7 @@ return %0 : tensor<4xf32> // CHECK: return [[VAL_28]] : tensor<4xf32> func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.mul %arg0, %arg0 : tensor<2xi32> +%0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } // CHECK-LABEL: func @mul( @@ -106,7 +106,7 @@ return %0 : tensor<2xi32> // CHECK: return [[VAL_30]] : tensor<2xi32> func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } // CHECK-LABEL: func @broadcast_mul( @@ -115,7 +115,7 @@ return %0 : tensor<1x2xi32> // CHECK: return [[VAL_33]] : tensor<1x2xi32> func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> +%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } // CHECK-LABEL: func @real_div( @@ -124,7 +124,7 @@ return %0 : tensor<2xi32> // CHECK: return [[VAL_35]] : tensor<2xi32> func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } // CHECK-LABEL: func @broadcast_real_div( @@ -133,7 +133,7 @@ return %0 : tensor<1x2xi32> // CHECK: return [[VAL_38]] : tensor<1x2xi32> func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.sub %arg0, %arg0 : tensor<2xi32> +%0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> 
} // CHECK-LABEL: func @sub( @@ -142,7 +142,7 @@ return %0 : tensor<2xi32> // CHECK: return [[VAL_40]] : tensor<2xi32> func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +%0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } // CHECK-LABEL: func @broadcast_sub( diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td index 2048604915d..6a60a42861a 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td @@ -83,16 +83,16 @@ def HLOClient_AddOp : HLOClient_BinaryElementwiseOp<"add", def HLOClient_Atan2Op : HLOClient_BinaryElementwiseOp<"atan2", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op; -def HLOClient_DivOp : HLOClient_BinaryElementwiseOp<"div", +def HLOClient_DivOp : HLOClient_BinaryElementwiseOp<"divide", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp; -def HLOClient_MaxOp : HLOClient_BinaryElementwiseOp<"max", +def HLOClient_MaxOp : HLOClient_BinaryElementwiseOp<"maximum", [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp; -def HLOClient_MinOp : HLOClient_BinaryElementwiseOp<"min", +def HLOClient_MinOp : HLOClient_BinaryElementwiseOp<"minimum", [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp; -def HLOClient_MulOp : HLOClient_BinaryElementwiseOp<"mul", +def HLOClient_MulOp : HLOClient_BinaryElementwiseOp<"multiply", [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp; def HLOClient_PowOp : HLOClient_BinaryElementwiseOp<"pow", @@ -110,7 +110,7 @@ def HLOClient_ShiftRightArithmeticOp : HLOClient_BinaryElementwiseOp<"shift_righ def HLOClient_ShiftRightLogicalOp : HLOClient_BinaryElementwiseOp<"shift_right_logical", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp; -def HLOClient_SubOp : HLOClient_BinaryElementwiseOp<"sub", +def HLOClient_SubOp : HLOClient_BinaryElementwiseOp<"subtract", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index d85a44eca10..8f8f6ac62e3 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -288,16 +288,16 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add", def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op; -def HLO_DivOp : HLO_BinaryElementwiseOp<"div", +def HLO_DivOp : HLO_BinaryElementwiseOp<"divide", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp; -def HLO_MaxOp : HLO_BinaryElementwiseOp<"max", +def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp; -def HLO_MinOp : HLO_BinaryElementwiseOp<"min", +def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp; -def HLO_MulOp : HLO_BinaryElementwiseOp<"mul", +def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, 
BASE_HLO_MulOp; def HLO_PowOp : HLO_BinaryElementwiseOp<"pow", @@ -315,7 +315,7 @@ def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp; -def HLO_SubOp : HLO_BinaryElementwiseOp<"sub", +def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index a37c530532d..92e084bf6a2 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -120,18 +120,18 @@ class LHLO_BinaryElementwiseOp traits> : def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp; -def LHLO_DivOp : LHLO_BinaryElementwiseOp<"div", []>, BASE_HLO_DivOp; +def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp; -def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"max", []>, BASE_HLO_MaxOp; +def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp; -def LHLO_MinOp : LHLO_BinaryElementwiseOp<"min", []>, BASE_HLO_MinOp; +def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum", []>, BASE_HLO_MinOp; -def LHLO_MulOp : LHLO_BinaryElementwiseOp<"mul", []>, BASE_HLO_MulOp; +def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply", []>, BASE_HLO_MulOp; def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", []>, BASE_HLO_RemOp; -def LHLO_SubOp : LHLO_BinaryElementwiseOp<"sub", []>, BASE_HLO_SubOp; +def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract", []>, BASE_HLO_SubOp; def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", []>, BASE_HLO_AndOp; diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 2aeb5f1041d..2858c6d9978 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -21,16 +21,16 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - %1 = xla_hlo.max %arg0, %arg1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) + %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + // CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) %2 = xla_hlo.add %arg0, %1 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) - %3 = xla_hlo.min %arg0, %arg1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.min"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) - %4 = xla_hlo.sub %arg1, %3 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.sub"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) - %5 = xla_hlo.mul %2, %4 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) + %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + // CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) + %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32> + // CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) + %5 = xla_hlo.multiply %2, %4 : tensor<4xf32> + // CHECK-NEXT: 
"xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> @@ -55,9 +55,9 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> - %tensor_result = "xla_hlo.mul"(%sum, %tensor_multiplier) + %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index c2fb840ad10..0f7b7369035 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: mulf - %0 = "xla_hlo.mul"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: muli - %0 = "xla_hlo.mul"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: subf - %0 = "xla_hlo.sub"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: subi - %0 = "xla_hlo.sub"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -313,7 +313,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { // CHECK-LABEL: func @minf func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = "xla_hlo.min"(%lhs, %rhs) + %0 = "xla_hlo.minimum"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -327,7 +327,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @maxi func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla_hlo.max"(%lhs, %rhs) + %0 = "xla_hlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index b759fe593c2..f30bd961fca 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ 
b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -43,7 +43,7 @@ func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32> // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: xla_hlo.constant - // CHECK: "xla_hlo.mul"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK: "xla_hlo.multiply"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor) -> tensor<8xf32> return %0#0 : tensor<8x8x8x8xf32> } @@ -92,8 +92,8 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -104,10 +104,10 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -150,8 +150,8 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: 
%[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -162,10 +162,10 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -235,8 +235,8 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant 
dense<0.000000e+00> : tensor @@ -247,10 +247,10 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -320,8 +320,8 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -332,10 +332,10 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, 
tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -416,7 +416,7 @@ func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi3 // CHECK-LABEL: func @div func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -424,7 +424,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @broadcast_div func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } @@ -438,7 +438,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-LABEL: func @div_dynamic func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.Div"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor } @@ -452,21 +452,21 @@ func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.max %arg0, %arg1 : tensor<4xf32> + // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // CHECK-LABEL: func @minimum func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.min %arg0, %arg1 : tensor<4xf32> + // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // CHECK-LABEL: func @mul func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -474,28 +474,28 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @broadcast_mul func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } // CHECK-LABEL: func @real_div func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.div 
%arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } // CHECK-LABEL: func @broadcast_real_div func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } // CHECK-LABEL: func @sub func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.sub %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -503,7 +503,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @broadcast_sub func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } @@ -662,15 +662,15 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.sub [[ABS2]], [[ZEROS3]] + // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.neg"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = "xla_hlo.div"([[NEG]], [[ABS3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV2:%.+]] = "xla_hlo.divide"([[NEG]], [[ABS3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -684,15 +684,15 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} 
+ // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.sub [[ABS2]], [[ZEROS3]] + // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.neg"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_hlo.div [[NEG]], [[ABS3]] + // CHECK-DAG: [[DIV2:%.+]] = xla_hlo.divide [[NEG]], [[ABS3]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -701,7 +701,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-LABEL: func @floordiv_f32 func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %[[DIV:.*]] = xla_hlo.div %arg0, %arg0 + // CHECK-NEXT: %[[DIV:.*]] = xla_hlo.divide %arg0, %arg0 // CHECK-NEXT: %[[FLOOR:.*]] = "xla_hlo.floor"(%[[DIV]]) // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> @@ -712,7 +712,7 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: xla_hlo.convert - // CHECK-NEXT: xla_hlo.div + // CHECK-NEXT: xla_hlo.divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: return @@ -722,7 +722,7 @@ func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-LABEL: func @floordiv_f16_broadcast func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: xla_hlo.div + // CHECK-NEXT: xla_hlo.divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: return %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> @@ -1250,7 +1250,7 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: ten // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> - // CHECK: %[[OFFSET:.*]] = xla_hlo.sub %[[X]], %[[Y]] : tensor<64x64xbf16> + // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<*xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor @@ -1270,7 +1270,7 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: ten func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<12x24x48xbf16> { // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xbf16> // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> - // CHECK: %[[OFFSET:.*]] = xla_hlo.sub %[[X]], %[[Y]] : tensor<24x48xbf16> + // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<*xi1> @@ -1311,7 
+1311,7 @@ func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor, %arg2: t func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { // CHECK: %[[INIT:.*]] = xla_hlo.constant dense<-2147483648> : tensor // CHECK: "xla_hlo.reduce_window"(%[[ARG]], %[[INIT]]) - // CHECK: xla_hlo.max + // CHECK: xla_hlo.maximum // CHECK: xla_hlo.return // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} @@ -1502,7 +1502,7 @@ func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (te // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: "xla_hlo.max"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -1510,7 +1510,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked func @relu_unranked(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: "xla_hlo.max"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1644,12 +1644,12 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK-DAG: %[[NEG_INF:.*]] = xla_hlo.constant dense<0xFF800000> : tensor // CHECK-DAG: %[[CASTED_INP:.*]] = "xla_hlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) - // CHECK: xla_hlo.max + // CHECK: xla_hlo.maximum // CHECK: "xla_hlo.return" // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. @@ -1661,7 +1661,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = "xla_hlo.divide"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // return %[[RESULT]] %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> @@ -1671,7 +1671,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Verify intermediate and final shape are correct with dynamic shapes. 
// CHECK-LABEL: func @dynamic_softmax func @dynamic_softmax(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.div"({{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor) -> tensor + // CHECK: "xla_hlo.divide"({{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor) -> tensor %0 = "tf.Softmax"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1697,7 +1697,7 @@ func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { // CHECK: "xla_hlo.reduce" // CHECK: dimensions = dense<3> - // CHECK: "xla_hlo.div"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: "xla_hlo.divide"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> return %0: tensor<2x3x4x5xf16> } @@ -1714,12 +1714,12 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK-DAG: %[[CASTED_INP:.*]] = "xla_hlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-DAG: %[[NEG_INF:.*]] = xla_hlo.constant dense<0xFF800000> : tensor // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) - // CHECK: xla_hlo.max + // CHECK: xla_hlo.maximum // CHECK: "xla_hlo.return" // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. 
@@ -1732,7 +1732,7 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.sub"(%[[SHIFTED_INP]], %[[LOG]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = "xla_hlo.subtract"(%[[SHIFTED_INP]], %[[LOG]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // return %[[RESULT]] %0 = "tf.LogSoftmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> @@ -2036,9 +2036,9 @@ func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-DAG: [[R0:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor // CHECK-DAG: [[R1:%.+]] = "xla_hlo.broadcast"([[R0]]) {broadcast_sizes = dense<2> : tensor<1xi64>} : (tensor) -> tensor<2xf32> - // CHECK-DAG: [[R2:%.+]] = xla_hlo.mul %arg0, [[R1]] : tensor<2xf32> + // CHECK-DAG: [[R2:%.+]] = xla_hlo.multiply %arg0, [[R1]] : tensor<2xf32> // CHECK-DAG: [[R3:%.+]] = "xla_hlo.tanh"([[R2]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK-DAG: [[R4:%.+]] = xla_hlo.mul [[R3]], [[R1]] : tensor<2xf32> + // CHECK-DAG: [[R4:%.+]] = xla_hlo.multiply [[R3]], [[R1]] : tensor<2xf32> // CHECK-DAG: [[R5:%.+]] = xla_hlo.add [[R4]], [[R1]] : tensor<2xf32> %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> @@ -2523,7 +2523,7 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> // CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<8.000000e+00> : tensor - // CHECK: %[[MEAN:.*]] = "xla_hlo.div"(%[[REDUCED]], %[[DIVISOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = "xla_hlo.divide"(%[[REDUCED]], %[[DIVISOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> @@ -2590,7 +2590,7 @@ func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0xFC00> : tensor // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.max %[[ARGA]], %[[ARGB]] : tensor + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.maximum %[[ARGA]], %[[ARGB]] : tensor // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor) -> tensor<4xf16> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> @@ -2607,7 +2607,7 @@ func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0xFC00> : tensor // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.max %[[ARGA]], %[[ARGB]] : tensor + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.maximum %[[ARGA]], %[[ARGB]] : tensor // CHECK: 
"xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf16>, tensor) -> tensor<4xf16> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> @@ -2624,7 +2624,7 @@ func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0x7C00> : tensor // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.min %[[ARGA]], %[[ARGB]] : tensor + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.minimum %[[ARGA]], %[[ARGB]] : tensor // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor) -> tensor<4xf16> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> @@ -2641,7 +2641,7 @@ func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.mul %[[ARGA]], %[[ARGB]] : tensor + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.multiply %[[ARGA]], %[[ARGB]] : tensor // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> @@ -2827,7 +2827,7 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota" - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.mul"([[IOTA]], [[DELTA]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[DELTA]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} // CHECK: "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> return %3 : tensor<5xf32> @@ -2838,10 +2838,10 @@ func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { // CHECK-DAG: [[NUM:%.*]] = xla_hlo.constant dense<4> // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = "xla_hlo.convert"([[NUM]]) - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.sub [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = xla_hlo.div [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] + // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]] + // CHECK-DAG: [[STEP:%.*]] = xla_hlo.divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.mul"([[IOTA]], [[STEP]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[STEP]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} // CHECK-DAG: [[LINSPACE:%.*]] = "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} // CHECK: return 
[[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor @@ -3022,13 +3022,13 @@ func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor) { // CHECK: %[[CONST:.*]] = xla_hlo.constant dense<1> // CHECK: %[[DIM_0:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 0 - // CHECK: %[[MUL_0:.*]] = xla_hlo.mul %[[CONST]], %[[DIM_0]] + // CHECK: %[[MUL_0:.*]] = xla_hlo.multiply %[[CONST]], %[[DIM_0]] // CHECK: %[[DIM_1:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 1 - // CHECK: %[[MUL_1:.*]] = xla_hlo.mul %[[MUL_0]], %[[DIM_1]] + // CHECK: %[[MUL_1:.*]] = xla_hlo.multiply %[[MUL_0]], %[[DIM_1]] // CHECK: %[[DIM_2:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 2 - // CHECK: %[[MUL_2:.*]] = xla_hlo.mul %[[MUL_1]], %[[DIM_2]] + // CHECK: %[[MUL_2:.*]] = xla_hlo.multiply %[[MUL_1]], %[[DIM_2]] %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor // CHECK: return %[[MUL_2]] return %size : tensor @@ -3240,7 +3240,7 @@ func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor : tensor<2xi64>} : (tensor) -> tensor<4x64xf32> // CHECK: [[SCATTER:%.*]] = "xla_hlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( { // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[MUL:%.*]] = xla_hlo.mul [[LHS]], [[RHS]] : tensor + // CHECK: [[MUL:%.*]] = xla_hlo.multiply [[LHS]], [[RHS]] : tensor // CHECK: "xla_hlo.return"([[MUL]]) // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor, tensor<8x?x64xf32>) -> tensor<4x?xf32> // CHECK: return [[SCATTER]] @@ -3253,7 +3253,7 @@ func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor : tensor} : () -> tensor // CHECK: xla_hlo.constant dense<0x7F800000> : tensor // CHECK: xla_hlo.scatter - // CHECK: xla_hlo.min + // CHECK: xla_hlo.minimum %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) return %0: tensor<4x?xf32> } @@ -3263,7 +3263,7 @@ func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor : tensor} : () -> tensor // CHECK: xla_hlo.constant dense<0xFF800000> : tensor // CHECK: xla_hlo.scatter - // CHECK: xla_hlo.max + // CHECK: xla_hlo.maximum %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) return %0: tensor<4x?xf32> } @@ -3611,7 +3611,7 @@ func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> // CHECK: "xla_hlo.return"([[ADD]]) // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = xla_hlo.constant dense<4.000000e+00> : tensor - // CHECK: [[DIV:%.+]] = "xla_hlo.div"([[REDUCE]], [[COUNT]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> + // CHECK: [[DIV:%.+]] = "xla_hlo.divide"([[REDUCE]], [[COUNT]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[CONV16:%.+]] = 
"xla_hlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> // CHECK: return [[CONV16]] %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index da6adf8cbe1..5bb965fa320 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -6,13 +6,13 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32> - %1 = "xla_hlo.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32> - %2 = "xla_hlo.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32> - %3 = "xla_hlo.div"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32> %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> @@ -27,13 +27,13 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32> - %1 = "xla_hlo.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32> - %2 = "xla_hlo.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32> - %3 = "xla_hlo.div"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32> %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> @@ -52,18 +52,18 @@ func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tens name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %1 = "xla_hlo.mul"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %1 = "xla_hlo.mul"(%0, %arg1) { + // CHECK-NEXT: %1 = "xla_hlo.multiply"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> + %1 = "xla_hlo.multiply"(%0, %arg1) { name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %2 = "xla_hlo.sub"(%1, 
%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %2 = "xla_hlo.sub"(%1, %arg1) { + // CHECK-NEXT: %2 = "xla_hlo.subtract"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> + %2 = "xla_hlo.subtract"(%1, %arg1) { name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %3 = "xla_hlo.div"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %3 = "xla_hlo.div"(%2, %arg1) { + // CHECK-NEXT: %3 = "xla_hlo.divide"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> + %3 = "xla_hlo.divide"(%2, %arg1) { name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir index 74fea0cc687..0aa7834b4fb 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir @@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> // CHECK: return - "xla_lhlo.min"(%lhs, %rhs, %result) {name = "min.1"} : + "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () return } @@ -52,7 +52,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: divf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.div"(%lhs, %rhs, %result) {name = "div.1"} + "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -60,7 +60,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: divi_signed %{{.*}}, %{{.*}} : i32 - "xla_lhlo.div"(%lhs, %rhs, %result) {name = "div.1"} + "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 - "xla_lhlo.max"(%lhs, %rhs, %result) {name = "max.1"} + "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 - "xla_lhlo.max"(%lhs, %rhs, %result) {name = "max.1"} + "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: %[[CHECK:.*]] = cmpf 
"olt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 - "xla_lhlo.min"(%lhs, %rhs, %result) {name = "min.1"} + "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 - "xla_lhlo.min"(%lhs, %rhs, %result) {name = "min.1"} + "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: mulf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.mul"(%lhs, %rhs, %result) {name = "mul.1"} + "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: muli %{{.*}}, %{{.*}} : i32 - "xla_lhlo.mul"(%lhs, %rhs, %result) {name = "mul.1"} + "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: subf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.sub"(%lhs, %rhs, %result) {name = "sub.1"} + "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: subi %{{.*}}, %{{.*}} : i32 - "xla_lhlo.sub"(%lhs, %rhs, %result) {name = "sub.1"} + "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index 5d0c767a716..5e4a7fd719f 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -47,7 +47,7 @@ func @element_wise_scalar(%lhs: memref, %rhs: memref, // CHECK-LABEL: func @minf func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.min"(%lhs, %rhs, %result) + "xla_lhlo.minimum"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -62,7 +62,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @maxi func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.max"(%lhs, %rhs, %result) + "xla_lhlo.maximum"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 2953fc84d71..04d9d23fe8b 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -90,7 +90,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: 
memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @div_memref func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.div"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -98,7 +98,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @max_memref func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.max"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -106,7 +106,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @min_memref func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.min"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -114,7 +114,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @mul_memref func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.mul"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -122,7 +122,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @sub_memref func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.sub"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -172,7 +172,7 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m %1 = tensor_load %input2 : memref<10xf32> %2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %3 = tensor_load %input3 : memref<10xf32> - %4 = "xla_hlo.mul"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> tensor_store %4, %out : memref<10xf32> "xla_lhlo.terminator"() : () -> () } ) : () -> () diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir index 0c0ac91beb0..915771923d0 100644 --- a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir @@ -50,9 +50,9 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.sub %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.sub %arg1, %arg3 - %4 = "xla_hlo.sub"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3 + %4 = "xla_hlo.subtract"(%2, %3) : 
(tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) @@ -65,9 +65,9 @@ func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : te %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.sub"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.sub"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.sub"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.subtract"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.subtract"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} + %4 = "xla_hlo.subtract"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) @@ -80,9 +80,9 @@ func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.sub %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.sub %arg1, %arg3 - %4 = "xla_hlo.sub"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3 + %4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) @@ -95,13 +95,13 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg3 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul %arg0, %arg3 - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg1, %arg2 + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2 // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.mul"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) @@ -114,13 +114,13 @@ func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : 
tensor<1x2xf32>, %arg2 : te %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.mul"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.mul"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.mul"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.mul"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.multiply"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.mul"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) + %4 = "xla_hlo.multiply"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) @@ -133,13 +133,13 @@ func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg3 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul %arg0, %arg3 - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg1, %arg2 + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2 // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.mul"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) @@ -156,26 +156,26 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // Compute the numerator's real component: // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg0, %arg2 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.mul %arg1, [[VAL0]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]] + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = 
xla_hlo.subtract [[VAL1]], [[VAL2]] // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.mul %arg1, %arg2 - // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.mul %arg0, [[VAL0]] + // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]] // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] // Divide the numerator by the real valued denominator. - // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.div [[VAL3]], [[VAL6]] - // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.div [[VAL9]], [[VAL6]] - %4 = "xla_hlo.div"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]] + %4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) @@ -195,26 +195,26 @@ func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : te // Compute the numerator's real component: // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.mul"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.mul"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]] + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.multiply"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.mul"(%arg1, %arg2) - // CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.mul"(%arg0, [[VAL0]]) + // CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) + // CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.multiply"(%arg0, [[VAL0]]) // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] // Divide the numerator by the real valued denominator. 
- // CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.div"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.div"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.div"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) + // CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.divide"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.divide"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + %4 = "xla_hlo.divide"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) @@ -234,26 +234,26 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // Compute the numerator's real component: // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg0, %arg2 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.mul %arg1, [[VAL0]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]] + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.mul %arg1, %arg2 - // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.mul %arg0, [[VAL0]] + // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]] // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] // Divide the numerator by the real valued denominator. 
- // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.div [[VAL3]], [[VAL6]] - // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.div [[VAL9]], [[VAL6]] - %4 = "xla_hlo.div"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]] + %4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) @@ -266,8 +266,8 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg0 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg1 + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1 // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]] // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]]) %1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) @@ -284,8 +284,8 @@ func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tenso // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0) // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1) // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul [[VAL0]], [[VAL2]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] %1 = "xla_hlo.exp"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) %3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) @@ -301,8 +301,8 @@ func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf3 // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0) // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1) // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul [[VAL0]], [[VAL2]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] %1 = "xla_hlo.exp"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) %2 = "xla_hlo.real"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) %3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir index 682b153d474..91eb7493648 100644 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir @@ -88,8 +88,8 @@ func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor< func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.div 
%[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.divide %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> + %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -99,8 +99,8 @@ func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x func @maxBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.max %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.max"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.maximum %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> + %0 = "xla_hlo.maximum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -110,8 +110,8 @@ func @maxBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x func @minBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.min %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.min"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.minimum %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> + %0 = "xla_hlo.minimum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -121,8 +121,8 @@ func @minBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x func @mulBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.mul %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.multiply %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> + %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -187,8 +187,8 @@ func 
@shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32> func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.sub %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.subtract %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> + %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 8af27bb586a..2436ef32c07 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -18,7 +18,7 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { %0 = "xla_hlo.all_reduce"(%arg0) ({ // Perform max reduction inside the region ^bb0(%lhs: tensor, %rhs: tensor): - %max = xla_hlo.max %lhs, %rhs : tensor + %max = xla_hlo.maximum %lhs, %rhs : tensor "xla_hlo.return"(%max) : (tensor) -> () }) { @@ -210,7 +210,7 @@ func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - %1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "xla_hlo.multiply"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0, %1 : tensor<4xi32>, tensor<4xi32> } @@ -605,8 +605,8 @@ func @main(%token: !xla_hlo.token) -> tuple, !xla_hlo.token> { func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor, %arg3 : tensor) -> (tensor<1xf32>, tensor<1xi32>) { %result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( { ^bb0(%fa: tensor, %ia : tensor, %fb: tensor, %ib: tensor): // no predecessors - %fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor, tensor) -> tensor - %imax = "xla_hlo.max"(%ia, %ib) {} : (tensor, tensor) -> tensor + %fmax = "xla_hlo.maximum"(%fa, %fb) {} : (tensor, tensor) -> tensor + %imax = "xla_hlo.maximum"(%ia, %ib) {} : (tensor, tensor) -> tensor "xla_hlo.return"(%fmax, %imax) : (tensor, tensor) -> () }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor, tensor) -> (tensor<1xf32>, tensor<1xi32>) return %result0, %result1 : tensor<1xf32>, tensor<1xi32> @@ -632,7 +632,7 @@ func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> { %0 = xla_hlo.constant dense<-2147483648> : tensor %1 = "xla_hlo.reduce_window"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = xla_hlo.max %arg1, %arg2 : tensor + %2 = xla_hlo.maximum %arg1, %arg2 : tensor "xla_hlo.return"(%2) : (tensor) -> () }) { window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt 
b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt index 01a24c06d2c..05d6a2a9af2 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -30,7 +30,7 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %5 = "xla_hlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} - // CHECK-NEXT: %6 = xla_hlo.mul %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32> + // CHECK-NEXT: %6 = xla_hlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32> %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9) // CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor @@ -85,7 +85,7 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %18 = xla_hlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32> %add.39 = f32[300,5] add(%dot.36, %broadcast.38) - // CHECK-NEXT: %19 = xla_hlo.max %10, %18 {name = "maximum.42"} : tensor<300x5xf32> + // CHECK-NEXT: %19 = xla_hlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32> %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39) // CHECK-NEXT: %20 = "xla_hlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 23307427e97..54ba9704ac5 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -302,7 +302,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.div %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: xla_hlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -512,7 +512,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.max %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -521,7 +521,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.min %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -530,7 +530,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -692,7 +692,7 @@ add { // CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor %reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3 - // CHECK: %4 = xla_hlo.sub [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor + // CHECK: %4 = xla_hlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor %sub.5 = f32[] subtract(%reduce.3, %reduce.4) ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5) @@ -858,7 +858,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: 
xla_hlo.sub %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: xla_hlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } diff --git a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir index b5e1eaf104a..12d542ce5f6 100644 --- a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir @@ -18,9 +18,9 @@ func @batchNormInference_2D_inner_features( // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.sub %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.mul %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.div %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : @@ -125,9 +125,9 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor - // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.sub %[[X]], %[[MEAN_BCAST]] : tensor - // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.mul %[[X_CENTER]], %[[SCALE_BCAST]] : tensor - // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.div %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor + // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor + // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor + // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 0.001 : f32, feature_index = 1 : i64} : diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index eb6e7e1cd3d..b2b17a8dd75 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ 
b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -276,7 +276,7 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // %2 = "xla_hlo.add"(%0, %1) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // %3 = tensor_load %arg0 : memref<2x2xf32> -// %4 = "xla_hlo.mul"(%2, %3) : +// %4 = "xla_hlo.multiply"(%2, %3) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () @@ -293,7 +293,7 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // %0 = alloc() : memref<2x2xf32> // "xla_lhlo.add"(%arg1, %arg2, %0) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// "xla_lhlo.mul"(%0, %arg0, %arg3) : +// "xla_lhlo.multiply"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // dealloc %0 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () @@ -305,7 +305,7 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // FuncOp signature conversion example: // // func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -// %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> +// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> // tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>, // tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> // } @@ -318,7 +318,7 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // %arg2: memref<4xf32>) { // %0 = alloc() : memref<4xf32> // %1 = alloc() : memref<4xf32> -// "xla_lhlo.max"(%arg0, %arg1, %0) : +// "xla_lhlo.maximum"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // "xla_lhlo.add"(%arg0, %0, %1) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 2fc98ebd676..f6358d6cde7 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -1299,7 +1299,7 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Sample result for VALID padding mode: // // %init = constant dense<...> : tensor -// %max_pool = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.max"] +// %max_pool = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.maximum"] // {window_dimensions = ..., window_strides = ... } // class ConvertMaxPoolOp : public OpRewritePattern { @@ -1421,11 +1421,11 @@ class ConvertSelectV2Op : public OpRewritePattern { // : (tensor) -> tensor<2xf32> // // // Compute Tanh of half the logits of the values. -// %halved_logits = xla_hlo.mul %logits, %half_array : tensor<2xf32> +// %halved_logits = xla_hlo.multiply %logits, %half_array : tensor<2xf32> // %tanh = "xla_hlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32> // // // Have the result of Tanh and add 0.5. -// %halved_tanh = xla_hlo.mul %tanh, %half : tensor<2xf32> +// %halved_tanh = xla_hlo.multiply %tanh, %half : tensor<2xf32> // %sigmoid = xla_hlo.add %halved_tanh, %half : tensor<2xf32> // class ConvertSigmoidOp : public OpRewritePattern { @@ -1479,7 +1479,7 @@ class ConvertSigmoidOp : public OpRewritePattern { // // stability. 
// %max = "tf.Max"(%input, %reduce_dim) // : (tensor, tensor<1xi64>) -> tensor -// %sub = "xla_hlo.sub"(%inp, %max) {broadcast_dimensions = 0} +// %sub = "xla_hlo.subtract"(%inp, %max) {broadcast_dimensions = 0} // : (tensor, tensor) -> tensor // // %exp = "xla_hlo.exp"(%sub) : (tensor) -> tensor @@ -1487,7 +1487,7 @@ class ConvertSigmoidOp : public OpRewritePattern { // : (tensor, tensor<1xi64>) -> tensor // // // Softmax computation: -// %softmax = "xla_hlo.div"(%exp, %sum_f16) {broadcast_dimensions = 0} +// %softmax = "xla_hlo.divide"(%exp, %sum_f16) {broadcast_dimensions = 0} // : (tensor, tensor) -> tensor template class ConvertSoftmaxOp : public OpRewritePattern { @@ -1559,13 +1559,13 @@ class ConvertSoftmaxOp : public OpRewritePattern { // %const = xla_hlo.constant dense<1> : tensor // %dim_0 = "xla_hlo.get_dimension_size"(%input) {dimension = 0 : i32} : // (tensor<2x?x8xf32>) -> tensor -// %prod_0 = xla_hlo.mul %const, %dim_0 : tensor +// %prod_0 = xla_hlo.multiply %const, %dim_0 : tensor // %dim_1 = "xla_hlo.get_dimension_size"(%input) {dimension = 1 : i32} : // (tensor<2x?x8xf32>) -> tensor -// %prod_1 = xla_hlo.mul %prod_0, %dim_1 : tensor +// %prod_1 = xla_hlo.multiply %prod_0, %dim_1 : tensor // %dim_2 = "xla_hlo.get_dimension_size"(%input) {dimension = 2 : i32} : // (tensor<2x?x8xf32>) -> tensor -// %size = xla_hlo.mul %prod_1, %dim_2 : tensor +// %size = xla_hlo.multiply %prod_1, %dim_2 : tensor class ConvertSizeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2064,7 +2064,7 @@ class ConvertStridedSliceGradOp /// /// Output would be: /// %iota = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32> -/// %scaled = "xla_hlo.mul"(%iota, %delta) +/// %scaled = "xla_hlo.multiply"(%iota, %delta) /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : /// (tensor<5xf32>, tensor) -> tensor<5xf32> /// %result = "xla_hlo.add"(%scaled, %offset) @@ -2233,7 +2233,7 @@ class GenericConvertReductionOp : public OpRewritePattern { // %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"] // {dimensions = ...} // %divisor = constant dense<...> : tensor -// %mean = "xla_hlo.div"(%sum, %divisor) +// %mean = "xla_hlo.divide"(%sum, %divisor) class ConvertMeanOp : public GenericConvertReductionOp { public: @@ -2263,7 +2263,7 @@ class ConvertSumOp // Converts Max op to HLO Reduce op. 
// // %init = constant dense<...> : tensor -// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.max"] +// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.maximum"] // {dimensions = ...} class ConvertMaxOp : public GenericConvertReductionOp : tensor -// %min = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.min"] +// %min = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.minimum"] // {dimensions = ...} class ConvertMinOp : public GenericConvertReductionOp : tensor -// %prod = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.mul"] +// %prod = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.multiply"] // {dimensions = ...} class ConvertProdOp : public GenericConvertReductionOp { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo index f4f2e4d2c91..58cba9711f3 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo @@ -14,7 +14,7 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { // CHECK: %[[REF1:.*]] = tensor_load %[[ARG1]] : [[TYPE]] // CHECK: %[[REF2:.*]] = tensor_load %[[ARG2]] : [[TYPE]] // CHECK: %[[ADD:.*]] = xla_hlo.add %[[REF1]], %[[REF2]] -// CHECK: %[[MUL:.*]] = xla_hlo.mul %[[ADD]], %[[REF0]] +// CHECK: %[[MUL:.*]] = xla_hlo.multiply %[[ADD]], %[[REF0]] // CHECK: tensor_store %[[MUL]], %[[RESULT]] // CHECK: "xla_lhlo.terminator"() // CHECK-NEXT: } From b7dedbffc2e97eb3a3296679b2db93be0ea49bc0 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Tue, 17 Mar 2020 14:11:00 -0700 Subject: [PATCH 098/492] [XLA] Create explicit phi graph optimization during dataflow analysis. Previously, HLO dataflow analysis tried to create and remove phi values at the same time during propagation, which led to several deadlock cases, some of which were hard to fix. This CL splits the process into two phases: 1. During value propagation, dataflow analysis always creates phi values once it sees multiple inputs merging at the same point. It then records those phi values, as well as their inputs, in a phi graph. 2. After value propagation, dataflow analysis can run optimizations on the phi graph to prune unnecessary phi nodes. Both phases are guaranteed to terminate, so the deadlocks are avoided.
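To make the two phases concrete, below is a minimal, self-contained C++ sketch of the idea. It is illustrative only and is not the XLA implementation: the real pass works on HloValue objects in the hlo_phi_graph.cc/.h files added by this patch, and the simplified class, method signatures, and plain integer value ids used here are assumptions made for the example, not the actual HloPhiGraph API. Phase 1 eagerly registers every phi together with its inputs; phase 2 iterates to a fixpoint, pruning any phi whose inputs all resolve to a single value, so both phases are guaranteed to terminate.

// Illustrative sketch only; names and signatures are simplified assumptions.
#include <cstdio>
#include <map>
#include <set>
#include <utility>
#include <vector>

class PhiGraph {
 public:
  // Phase 1: during propagation, unconditionally record a phi and its inputs.
  void RegisterPhi(int phi_id, std::vector<int> input_ids) {
    inputs_[phi_id] = std::move(input_ids);
  }

  // Follow any replacements made so far to find the value an id stands for.
  int Resolve(int id) const {
    auto it = replacement_.find(id);
    while (it != replacement_.end()) {
      id = it->second;
      it = replacement_.find(id);
    }
    return id;
  }

  // Phase 2: after propagation, prune trivial phis. A phi whose inputs all
  // resolve to one value (ignoring self-references) is replaced by that value.
  // Each iteration either prunes a phi or stops, so the loop terminates.
  void Optimize() {
    bool changed = true;
    while (changed) {
      changed = false;
      for (const auto& entry : inputs_) {
        int phi_id = entry.first;
        if (replacement_.count(phi_id) > 0) continue;  // Already pruned.
        std::set<int> distinct;
        for (int input : entry.second) {
          int resolved = Resolve(input);
          if (resolved != phi_id) distinct.insert(resolved);
        }
        if (distinct.size() == 1) {
          replacement_[phi_id] = *distinct.begin();
          changed = true;
        }
      }
    }
  }

 private:
  std::map<int, std::vector<int>> inputs_;  // phi id -> input value ids.
  std::map<int, int> replacement_;          // pruned phi id -> surviving id.
};

int main() {
  PhiGraph graph;
  // Value 1 is a real definition; 2 and 3 are phis created eagerly during
  // propagation, e.g. a while-body parameter fed by itself and the init value.
  graph.RegisterPhi(2, {1, 2});
  graph.RegisterPhi(3, {2, 1});
  graph.Optimize();
  std::printf("phi 2 -> %d, phi 3 -> %d\n", graph.Resolve(2),
              graph.Resolve(3));  // Both resolve to 1.
  return 0;
}

A while-loop body parameter is the canonical case: it is fed by its own previous iteration and by the loop's init value, so the eagerly created phi collapses to the init value during the pruning phase rather than being special-cased during propagation.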
PiperOrigin-RevId: 301449515 Change-Id: I85f545ed9935ad5aee85b3f5bc05c2ba19da074a --- tensorflow/compiler/xla/service/BUILD | 52 ++++ .../xla/service/hlo_dataflow_analysis.cc | 111 ++++++-- .../xla/service/hlo_dataflow_analysis.h | 29 +- .../xla/service/hlo_dataflow_analysis_test.cc | 253 +++++++++++------- .../compiler/xla/service/hlo_phi_graph.cc | 233 ++++++++++++++++ .../compiler/xla/service/hlo_phi_graph.h | 100 +++++++ .../xla/service/hlo_phi_graph_test.cc | 86 ++++++ 7 files changed, 733 insertions(+), 131 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_phi_graph.cc create mode 100644 tensorflow/compiler/xla/service/hlo_phi_graph.h create mode 100644 tensorflow/compiler/xla/service/hlo_phi_graph_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 98851fddd2d..925afd689f7 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2763,6 +2763,7 @@ cc_library( ":call_graph", ":hlo", ":hlo_casting_utils", + ":hlo_phi_graph", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -2771,10 +2772,13 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -2783,6 +2787,7 @@ tf_cc_test( name = "hlo_dataflow_analysis_test", srcs = ["hlo_dataflow_analysis_test.cc"], deps = [ + ":flatten_call_graph", ":hlo", ":hlo_creation_utils", ":hlo_dataflow_analysis", @@ -2803,6 +2808,53 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_phi_graph", + srcs = ["hlo_phi_graph.cc"], + hdrs = ["hlo_phi_graph.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_casting_utils", + ":hlo_value", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_phi_graph_test", + srcs = ["hlo_phi_graph_test.cc"], + deps = [ + ":hlo", + ":hlo_dataflow_analysis", + ":hlo_graph_dumper", + ":hlo_matchers", + ":hlo_ordering", + ":hlo_phi_graph", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_replication_analysis", srcs = ["hlo_replication_analysis.cc"], diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 36da176b62f..6a0b9e5dfb8 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ 
b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -17,18 +17,23 @@ limitations under the License. #include #include +#include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -118,10 +123,11 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { } void HloDataflowAnalysis::DeleteMarkedValues() { -#ifndef NDEBUG - // Verify that no marked-for-deletion values are in any of the value sets. + // Use a set to prevent deleting an id twice. absl::flat_hash_set id_set(value_ids_to_delete_.begin(), value_ids_to_delete_.end()); +#ifndef NDEBUG + // Verify that no marked-for-deletion values are in any of the value sets. for (const auto& pair : value_sets_) { const HloInstruction* instruction = pair.first; const InstructionValueSet& instruction_value_set = pair.second; @@ -138,7 +144,7 @@ void HloDataflowAnalysis::DeleteMarkedValues() { } #endif - for (HloValue::Id value_id : value_ids_to_delete_) { + for (HloValue::Id value_id : id_set) { values_.erase(value_id); } value_ids_to_delete_.clear(); @@ -216,22 +222,13 @@ bool HloDataflowAnalysis::Phi( const HloValue* current_value = value_set.values().size() == 1 ? value_set.values()[0] : nullptr; - // Construct a vector of unique value IDs of the inputs. - // Don't add value ids where the input is equal to the definition. + // Construct a vector of value IDs of the inputs. std::vector input_value_ids; for (const InstructionValueSet* input : inputs) { for (const HloValue* value : input->element(index).values()) { - if (value->defining_instruction() == instruction && - value->defining_index() == index) { - continue; - } input_value_ids.push_back(value->id()); } } - absl::c_sort(input_value_ids); - input_value_ids.erase( - std::unique(input_value_ids.begin(), input_value_ids.end()), - input_value_ids.end()); // Remove the existing phi value (if it exists). The phi can be its own // input, for example, in while body parameters where the body passes @@ -240,14 +237,7 @@ bool HloDataflowAnalysis::Phi( (current_value != nullptr && current_value->defining_instruction() == instruction && current_value->defining_index() == index); - if (current_value_defined_here) { - VLOG(5) << "current_value_defined_here: " << current_value->ToString(); - CHECK(current_value->is_phi()); - auto it = absl::c_find(input_value_ids, current_value->id()); - if (it != input_value_ids.end()) { - input_value_ids.erase(it); - } - } + VLOG(5) << "after input_value_ids.size = " << input_value_ids.size(); if (input_value_ids.empty()) { // A value set which has at least one element should never have its value @@ -277,11 +267,33 @@ bool HloDataflowAnalysis::Phi( // Multiple distinct values reach this point. A phi value is // necessary. 
CHECK_GT(input_value_ids.size(), 1); - if (current_value == nullptr || - !(current_value->is_phi() && current_value_defined_here)) { + bool phi_defined_here = + current_value_defined_here && current_value->is_phi(); + if (current_value == nullptr || !phi_defined_here) { value_set.Clear(); value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true)); + + std::vector inputs; + inputs.reserve(input_value_ids.size()); + for (HloValue::Id id : input_value_ids) { + inputs.push_back(&GetValue(id)); + } + // Register the phi into phi graph. + phi_graph_.RegisterPhi(*value_set.values()[0], inputs); changed = true; + } else if (phi_defined_here) { + std::vector new_inputs; + new_inputs.reserve(input_value_ids.size()); + for (HloValue::Id id : input_value_ids) { + new_inputs.push_back(&GetValue(id)); + } + + if (!phi_graph_.InputsEqualTo(*current_value, new_inputs)) { + VLOG(1) << current_value->ToShortString() << " has new phi inputs: "; + // Update phi inputs. + phi_graph_.RegisterPhi(*current_value, new_inputs); + changed = true; + } } } } @@ -564,9 +576,9 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { CHECK_EQ(parameter->parameter_number(), 0); inputs.push_back( &GetInstructionValueSet(callsite.instruction()->operand(0))); - // If the parameter *is* the root, then don't consider it's current state - // (InstructionValueSet) as we are recomputing its current - // state. Otherwise, the parameter state would never be updated. + // If the parameter *is not* the root, parameter state would be + // updated by the root, otherwise don't consider it's current state + // (InstructionValueSet) as we are recomputing its current state. if (parameter != callsite.instruction()->while_body()->root_instruction()) { inputs.push_back(&GetInstructionValueSet( @@ -599,7 +611,6 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { "called from call, while, or conditional instructions"; } } - if (ssa_form_ && need_phi) { return Phi(parameter, inputs); } else { @@ -722,10 +733,18 @@ void HloDataflowAnalysis::Propagate() { add_to_worklist(instruction); } } + VLOG(1) << "SSA_FORM_: " << ssa_form_; while (!worklist.empty()) { HloInstruction* instruction = worklist.front(); + auto add_to_worklist = [&](HloInstruction* todo) { + if (workset.insert(todo).second) { + VLOG(1) << " Adding todo : " << todo->name(); + worklist.push(todo); + } + }; worklist.pop(); + workset.erase(workset.find(instruction)); VLOG(3) << "Worklist top: " << instruction->name(); @@ -913,6 +932,43 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { return Status::OK(); } +void HloDataflowAnalysis::OptimizePhiValues() { + // Only applicable to SSA form where phis are defined. 
+ if (!ssa_form_) { + return; + } + + VLOG(1) << "Before phi graph optimization"; + XLA_VLOG_LINES(1, phi_graph_.ToString()); + phi_graph_.Optimize(); + VLOG(1) << "After phi graph optimization"; + XLA_VLOG_LINES(1, phi_graph_.ToString()); + + for (const HloComputation* computation : module_.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + InstructionValueSet& instruction_value_set = + GetInstructionValueSet(instruction); + VLOG(1) << "inst: " << instruction->name(); + VLOG(1) << instruction_value_set.ToString(); + instruction_value_set.ForEachMutableElement( + [&](const xla::ShapeIndex& index, HloValueSet* value_set) { + auto values = value_set->values(); + if (!(values.size() == 1 && values[0]->is_phi())) { + return; + } + HloValue::Id phi_id = values[0]->id(); + HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id); + if (new_id != phi_id) { + value_set->Clear(); + const HloValue& new_value = GetValue(new_id); + value_set->AddValue(&new_value); + MarkValueForDeletion(phi_id); + } + }); + } + } +} + /* static */ StatusOr> HloDataflowAnalysis::Run( const HloModule& module, bool ssa_form, bool bitcast_defines_value, @@ -925,6 +981,7 @@ StatusOr> HloDataflowAnalysis::Run( TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); dataflow_analysis->Propagate(); + dataflow_analysis->OptimizePhiValues(); // Delete all values marked for deletion. dataflow_analysis->DeleteMarkedValues(); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 294ffea6792..75bcf7ea318 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -20,15 +20,19 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ +#include #include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_phi_graph.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -60,7 +64,8 @@ class HloDataflowAnalysis { // SSA form is minimal in that a new phi value is defined only if the // merge point is reachable by multiple different values. The SSA form is // also in loop-closed form in that no values defined inside of a loop - // (while body) is used outside of the loop. + // (while body) is used outside of the loop. Example use of this ssa_form + // mode is to reason about live range interference of buffers. // // If ssa_form is false, then merge points do not define new // values. Rather, the HloValueSet for the merge point contains the union @@ -138,8 +143,8 @@ class HloDataflowAnalysis { // Returns true if 'user' cannot possibly use the buffer at 'index' in // 'operand'. Returns false otherwise. // - // 'operand' does not have to be an operand of 'user'. This can be the case - // with indirect uses. + // 'operand' does not have to be an operand of 'user'. This can be the + // case with indirect uses. 
bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const; @@ -160,9 +165,22 @@ class HloDataflowAnalysis { bool bitcast_defines_value = false, const CanShareBuffer& can_share_buffer = nullptr); + // 1. During value propagation (Propagate function), always create phi + // values once it see multiple inputs merging at the same point. It then + // records those phi values as well as their inputs in a phi graph. + // + // 2. Post value propagation, Dataflow analysis can then do certain + // optimization(OptimizePhiValues) on the phi graph to prune uncessary phi + // nodes. + // + // Note that this applies in SSA form, and Both of the functions are + // guaranteed to exit. + // + void OptimizePhiValues(); + // Returns a new HloValue defined at the given instruction and shape index. HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, - bool is_phi = false); + bool is_phi); // Marks the HloValue with the given ID for deletion. void MarkValueForDeletion(HloValue::Id value_id); @@ -248,6 +266,9 @@ class HloDataflowAnalysis { // The Id to use for the next HloValue. HloValue::Id next_value_id_ = 0; + // An explicit graph holding phi values and edges. + PhiGraph phi_graph_; + // Backend specific function that decides whether an instruction can share // a buffer with its operand. CanShareBuffer can_share_buffer_ = nullptr; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 074d14fd810..1bbbb248bbc 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" @@ -50,6 +51,8 @@ class HloDataflowAnalysisTest : public HloTestBase, // reference to the generated analysis stored in analysis_. const HloDataflowAnalysis& RunAnalysis(bool ssa_form, bool bitcast_defines_value = false) { + FlattenCallGraph flatten; + EXPECT_TRUE(flatten.Run(module_.get()).ok()); analysis_ = HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); @@ -299,102 +302,6 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } -TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { - // Test a subcomputation which is called twice with identical values. 
- auto subbuilder = HloComputation::Builder("Subcomputation"); - auto subparam0 = subbuilder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param0")); - auto subparam1 = subbuilder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape_, "param1")); - auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( - scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); - HloComputation* called_computation = - module_->AddEmbeddedComputation(subbuilder.Build()); - - auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); - auto call1 = builder.AddInstruction(HloInstruction::CreateCall( - scalar_shape_, {constant1, constant2}, called_computation)); - auto call2 = builder.AddInstruction(HloInstruction::CreateCall( - scalar_shape_, {constant1, constant2}, called_computation)); - auto sub = builder.AddInstruction(HloInstruction::CreateBinary( - scalar_shape_, HloOpcode::kSubtract, call1, call2)); - module_->AddEntryComputation(builder.Build()); - SCOPED_TRACE(module_->ToString()); - - bool ssa_form = GetParam(); - const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - - EXPECT_EQ(analysis.values().size(), 4); - - // Definitions should be identical to the single callsite case. - EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); - EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1)); - EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(call1)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(call2)); - EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); - - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}}, - HloUse{add, 0, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}}, - HloUse{add, 1, {}})); - // The Add from the subcomputation is used as both operands of the Subtract. - EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), - UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); - - EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); -} - -TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { - // Test a subcomputation which is called twice with different argument values. 
- auto subbuilder = HloComputation::Builder("Subcomputation"); - auto subparam0 = subbuilder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param0")); - auto subparam1 = subbuilder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape_, "param1")); - auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( - scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); - HloComputation* called_computation = - module_->AddEmbeddedComputation(subbuilder.Build()); - - auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); - auto call1 = builder.AddInstruction(HloInstruction::CreateCall( - scalar_shape_, {constant1, constant2}, called_computation)); - auto call2 = builder.AddInstruction(HloInstruction::CreateCall( - scalar_shape_, {call1, constant2}, called_computation)); - module_->AddEntryComputation(builder.Build()); - SCOPED_TRACE(module_->ToString()); - - bool ssa_form = GetParam(); - const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - - EXPECT_FALSE(analysis.ValueIsDefinedAt(call1)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(call2)); - - EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0)); - - EXPECT_THAT(HloValuesAt(subparam0), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), - analysis.GetValueDefinedAt(add))); - EXPECT_THAT(HloValuesAt(subparam1), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant2))); - - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); -} - TEST_P(HloDataflowAnalysisTest, NestedCalls) { // Test a module with nested computations. HLO is: // @@ -637,6 +544,100 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); } +TEST_P(HloDataflowAnalysisTest, MultiLevelNestedWhile) { + // Test nested while instructions. The level0 body (most inner while) and + // level1 body pass through the parameter, while level2 (most outer while) + // modifies it. + // + // level0_body((F32[]) %tuple_param): + // return Tuple(%tuple_param{0}) + // + // level1_body((F32[]) %tuple_param): + // return While(%tuple_param{0}), body=level0 + // + // level2_body((F32[]) %tuple_param): + // while = While(%tuple_param{0}), body=level1 + //. return negate(%while{0}) + // + // entry: + // %constant = Constant(1.0) + // %tuple = Tuple(%constant) + // return While(%tuple), body=level2 + // + const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_}); + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + // level 0 passes transparently through the body. 
+ auto level0_builder = HloComputation::Builder("level0_body"); + auto level0_param = level0_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto level0_element_0 = level0_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, level0_param, 0)); + auto level0_root = level0_builder.AddInstruction( + HloInstruction::CreateTuple({level0_element_0})); + HloComputation* level0_body = + module_->AddEmbeddedComputation(level0_builder.Build()); + + // Element 1 passes transparently through the body. + auto level1_builder = HloComputation::Builder("level1_body"); + auto level1_param = level1_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto level1_root = level1_builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, condition, level0_body, level1_param)); + HloComputation* level1_body = + module_->AddEmbeddedComputation(level1_builder.Build()); + + // Element 1 passes transparently through the body. + auto level2_builder = HloComputation::Builder("level2_body"); + auto level2_param = level2_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto level2_while = level2_builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, condition, level1_body, level2_param)); + auto level2_element_0 = level2_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, level2_while, 0)); + auto negate = level2_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, level2_element_0)); + level2_builder.AddInstruction(HloInstruction::CreateTuple({negate})); + HloComputation* level2_body = + module_->AddEmbeddedComputation(level2_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); + builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, level2_body, tuple)); + module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + + bool ssa_form = GetParam(); + if (!ssa_form) { + return; + } + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + // Phi node on inner parameters and roots should have been eliminated. + EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_param, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_param, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_root, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_root, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(level2_param, /*index=*/{0})); + EXPECT_EQ(HloValuesAt(level1_param, /*index=*/{0}), + HloValuesAt(level2_param, /*index=*/{0})); + EXPECT_EQ(HloValuesAt(level0_param, /*index=*/{0}), + HloValuesAt(level2_param, /*index=*/{0})); + EXPECT_EQ(HloValuesAt(level1_root, /*index=*/{0}), + HloValuesAt(level2_param, /*index=*/{0})); + EXPECT_EQ(HloValuesAt(level0_root, /*index=*/{0}), + HloValuesAt(level2_param, /*index=*/{0})); +} + TEST_P(HloDataflowAnalysisTest, NestedWhiles) { // Test nested while instructions. The inner body passes through element 0 of // its parameter, and the outer body passes through element 1. 
HLO: @@ -757,6 +758,58 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { } } +TEST_P(HloDataflowAnalysisTest, SwizzlingWhileSharedInput) { + // Test a while instruction with a body which permutes it's tuple parameter + // elements. HLO: + // + // body((F32[], F32[]) %tuple_param): + // return Tuple(%tuple_param{1}, %tuple_param{0}) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %constant1 = Constant(1.0) + // %tuple = Tuple(%constant1, %constant1) + // return While(%tuple, body, condition) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_1, body_element_0})); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant1})); + builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); + module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); +} + TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { // Test a while instruction with a body which permutes it's tuple parameter // elements. HLO: @@ -1621,8 +1674,8 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { DependencyHloOrdering ordering(module_.get()); - // Exp only use is the call so it should not interfere with values inside the - // embedded computation. + // Exp only use is the call so it should not interfere with values inside + // the embedded computation. EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log)); // Negate is live across the call and should interfere with values in the @@ -2134,8 +2187,8 @@ TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { // The fusion instruction never uses tuple element 0, but does use element 1. EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); - // The same holds for the parameter tuple, except that the tuple elements are - // swapped in 'tuple'. + // The same holds for the parameter tuple, except that the tuple elements + // are swapped in 'tuple'. 
EXPECT_TRUE( dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion)); EXPECT_FALSE( diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph.cc b/tensorflow/compiler/xla/service/hlo_phi_graph.cc new file mode 100644 index 00000000000..9b69771dab2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_phi_graph.cc @@ -0,0 +1,233 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_phi_graph.h" + +#include + +namespace xla { +HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) { + Node* node = value_id_to_node_[value.id()]; + return node->value_id; +} + +// Returns true if the input to a hlo value is the same as `inputs`. +bool PhiGraph::InputsEqualTo(const HloValue& value, + absl::Span inputs) { + auto iter = value_id_to_node_.find(value.id()); + CHECK(iter != value_id_to_node_.end()); + absl::flat_hash_set existing_set; + for (Node* operand : iter->second->operands) { + existing_set.insert(operand->value_id); + } + absl::flat_hash_set new_set; + for (const HloValue* input : inputs) { + new_set.insert(input->id()); + } + return existing_set == new_set; +} + +HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) { + auto iter = value_id_to_node_.find(id); + CHECK(iter != value_id_to_node_.end()); + return iter->second->value_id; +} + +PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) { + auto iter = value_id_to_node_.find(value.id()); + if (iter == value_id_to_node_.end()) { + node_storage_.emplace_back(absl::make_unique()); + Node* node = node_storage_.back().get(); + node->value_id = value.id(); + value_id_to_node_[value.id()] = node; + node_to_value_id_[node].push_back(value.id()); + return node; + } else { + // A node is already registered with this value, check the value_id + // is the same as previously registrated. + CHECK_NE(iter->second, nullptr); + CHECK_EQ(iter->second->value_id, value.id()); + return iter->second; + } +} + +void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) { + // Update users. + CHECK(node->is_phi); + for (Node* user : node->users) { + absl::c_replace(user->operands, node, replace); + } + + // Update operand's users + for (Node* operand : node->operands) { + absl::c_replace(operand->users, node, replace); + } + for (HloValue::Id value_id : node_to_value_id_[node]) { + CHECK(value_id_to_node_.contains(value_id)); + value_id_to_node_[value_id] = replace; + } + // Update mappings to HloValue::Id. 
+ absl::c_copy(node_to_value_id_[node], + std::back_inserter(node_to_value_id_[replace])); + node_to_value_id_[node].clear(); + node->mark_as_dead = true; +} + +void PhiGraph::RegisterPhi(const HloValue& value, + absl::Span inputs) { + Node* node = CreateOrReuseNode(value); + CHECK(value.is_phi()); + node->is_phi = true; + node->operands.clear(); + for (auto input : inputs) { + CHECK(input != nullptr); + Node* input_node = CreateOrReuseNode(*input); + node->operands.push_back(input_node); + } +} + +std::string PhiGraph::ToString() { + std::string out = "PhiGraph: \n"; + for (auto& node : node_storage_) { + std::string is_phi = node->is_phi ? ", phi" : ""; + std::string is_optimized = node->mark_as_dead ? ", dead" : ""; + absl::StrAppend(&out, node->value_id); + absl::StrAppend(&out, is_phi); + absl::StrAppend(&out, is_optimized, ":\n"); + for (Node* input : node->operands) { + absl::StrAppend(&out, " ", input->value_id); + absl::StrAppend(&out, "\n"); + } + } + return out; +} + +void PhiGraph::Optimize() { + // Set up users for each node. + for (auto& node : node_storage_) { + for (Node* input : node->operands) { + input->users.push_back(node.get()); + } + } + + // input_node->users.push_back(node); + bool changed = true; + + // Run the optimization to a fixed point. + while (changed) { + changed = false; + absl::flat_hash_set checked_for_closure; + for (auto& node : node_storage_) { + // Only optimize phi node. + if (!node->is_phi) { + continue; + } + // Skip dead nodes + if (node->mark_as_dead) { + continue; + } + + Node* node_ptr = node.get(); + + CHECK_GE(node_ptr->operands.size(), 1); + + // Remove self-referencing ids from users and operands. + auto it = absl::c_find(node_ptr->operands, node_ptr); + while (it != node_ptr->operands.end()) { + node_ptr->operands.erase(it); + it = absl::c_find(node_ptr->operands, node_ptr); + } + + it = absl::c_find(node_ptr->users, node_ptr); + while (it != node_ptr->users.end()) { + node_ptr->users.erase(it); + it = absl::c_find(node_ptr->users, node_ptr); + } + + // If all inputs to phi (after self referencing ids are removed) are the + // same value, replace the phi with that value. + // + // phi(A, A, ... A) => A + // phi(A, self) = phi(A) => A + CHECK_GE(node_ptr->operands.size(), 1); + bool all_inputs_are_same = absl::c_all_of( + node_ptr->operands, + [&](Node* elem) { return elem == node_ptr->operands[0]; }); + + if (all_inputs_are_same) { + ReplaceNodeWith(node_ptr, node_ptr->operands[0]); + changed = true; + continue; + } + + // Find a closure of inter-connected phis and one non-phi node. Replace + // all phis with that non-phi node. + // + // def A = phi(B, C) + // def B = phi(C, D) + // def C = phi(A, B) + // def D = non-phi + // Replace A, B, and C with D: + // A = phi(B, C) => D + // B = phi(C, D) => D + // C = phi(A, B) => D + if (checked_for_closure.contains(node_ptr)) { + continue; + } + // Keeps track of nodes in the current closure being tested. + absl::flat_hash_set workset; + std::queue worklist; + Node* non_phi = nullptr; + worklist.push(node_ptr); + while (!worklist.empty()) { + Node* todo = worklist.front(); + worklist.pop(); + if (workset.contains(todo)) { + continue; + } + checked_for_closure.insert(todo); + workset.insert(todo); + for (Node* operand : todo->operands) { + worklist.push(operand); + } + if (!todo->is_phi) { + if (non_phi != nullptr && non_phi != todo) { + // We see distinct non-phi nodes in the closure, can't apply the + // optimization. 
+ non_phi = nullptr; + // Break the while loop non_phi setting to nullptr, signaling that + // the optimization can't be applied. + break; + } else { + // This is the non_phi node we are seeing so far. + non_phi = todo; + } + } + } + if (non_phi != nullptr) { + // Replace all phi nodes in the closure/workset with the non_phi node. + for (Node* node : workset) { + if (!node->is_phi) { + CHECK_EQ(node, non_phi); + continue; + } + ReplaceNodeWith(node, non_phi); + changed = true; + } + } + } + } +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph.h b/tensorflow/compiler/xla/service/hlo_phi_graph.h new file mode 100644 index 00000000000..a0eb994438e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_phi_graph.h @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PHI_GRAPH_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PHI_GRAPH_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" + +namespace xla { +// Phi graph is a graph that contains and connects phi nodes build on top of +// HloValues with explicit edges, as well as non-phi nodes that are direct +// inputs to the phi nodes. The graph can be viewed as an 'overlay' on top of +// HloValues, with some edges that connect them together. After optimization, +// some phis nodes will be simplified away and the user can then ask two useful +// questions: +// +// 1. Which HloValue should a phi node being replaced with? +// 2. TODO(yunxing): What are the set of aliased HloValues that are connecting +// to the same phi (Must-aliasing). +class PhiGraph { + public: + // Register an hlo value into the phi node. + void RegisterPhi(const HloValue& value, + absl::Span inputs); + + HloValue::Id GetOptimizedId(const HloValue& value); + + // Returns true if the input to a hlo value is the same as `inputs`. + bool InputsEqualTo(const HloValue& value, + absl::Span inputs); + + // Given `id`, returns the new id that `id` should be replaced with. If the + // node is not optimized, returns the same value. + HloValue::Id FindOptimizedValue(const HloValue::Id id); + + // Optimize the entire graph. + void Optimize(); + + std::string ToString(); + + private: + struct Node { + bool is_phi; + // Users of this node. Non-phi node has no operands. + std::vector users; + // Operands of this node. + std::vector operands; + + // The value that the node is originally registered with. + HloValue::Id value_id; + + // mark_as_dead is set to true when a phi node is simplified away + // + // Precondition: node is a phi. 
+ bool mark_as_dead = false; + }; + + Node* CreateOrReuseNode(const HloValue& value); + + // Relace `node` with `replace`. Redirect all users to the `replace` and + // all HloValues pointing to the `node` to `replace`. Also mark `node` as + // dead. + // + // Precondition: node is a phi -- It's only possible to simplify away a + // phi node. + void ReplaceNodeWith(Node* node, Node* replace); + + // A reverse mapping of a node in the phi graph and all HloValues pointing + // to that phi. + absl::flat_hash_map> node_to_value_id_; + + // A mapping between a HloValue and node in the phi graph. + absl::flat_hash_map value_id_to_node_; + std::vector> node_storage_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PHI_GRAPH_H_ diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc b/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc new file mode 100644 index 00000000000..41f0454fe55 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_phi_graph.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { +class PhiGraphTest : public ::testing::Test { + protected: + HloValue NewHloValue(bool is_phi) { + static int64 id = 0; + return HloValue(id++, dummy_inst_.get(), {}, is_phi); + } + + void SetUp() override { + dummy_inst_ = HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)); + } + + // Dummy instruction used to fill unrelated argument when creating a + // HloValue. + std::unique_ptr dummy_inst_; +}; + +TEST_F(PhiGraphTest, SelfReferencingPhi) { + // Def A = non-phi + // Def B = phi(B, A) + // + // Optimize B into A. + PhiGraph phi_graph; + HloValue A = NewHloValue(false); + HloValue B = NewHloValue(true); + phi_graph.RegisterPhi(B, {&A, &B}); + phi_graph.Optimize(); + EXPECT_EQ(A.id(), phi_graph.FindOptimizedValue(B.id())); +} + +TEST_F(PhiGraphTest, PhiWithSameInputs) { + // Def A = non-phi + // Def B = phi(A, A) + // + // Optimize B into A. 
+ PhiGraph phi_graph; + HloValue A = NewHloValue(false); + HloValue B = NewHloValue(true); + phi_graph.RegisterPhi(B, {&A, &A}); + phi_graph.Optimize(); + EXPECT_EQ(A.id(), phi_graph.FindOptimizedValue(B.id())); +} + +TEST_F(PhiGraphTest, CircularPhi) { + // def A = phi(B, C) + // def B = phi(C, D) + // def C = phi(A, B) + // def D = non-phi + // Replace A, B, and C with D: + PhiGraph phi_graph; + HloValue A = NewHloValue(true); + HloValue B = NewHloValue(true); + HloValue C = NewHloValue(true); + HloValue D = NewHloValue(false); + phi_graph.RegisterPhi(A, {&B, &C}); + phi_graph.RegisterPhi(B, {&D, &C}); + phi_graph.RegisterPhi(C, {&A, &B}); + phi_graph.Optimize(); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(A.id())); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(B.id())); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id())); +} + +} // namespace +} // namespace xla From 5ba1dd891725f5c071874b472beca06b9e562537 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Tue, 17 Mar 2020 14:19:46 -0700 Subject: [PATCH 099/492] Add auto generated TensorFlow op Inv PiperOrigin-RevId: 301451430 Change-Id: If8ea3a5d76429e82e5b5d7b497868f047cac2d8b --- .../mlir/tensorflow/ir/tf_generated_ops.td | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index d2bbbd32b7c..39b24ad353f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -2607,6 +2607,24 @@ def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> { TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>; } +def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the reciprocal of x element-wise."; + + let description = [{ +I.e., \\(y = 1 / x\\). + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_InvertOp : TF_Op<"Invert", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Invert (flip) each bit of supported types; for example, type `uint8` value 01010101 becomes 10101010. 
From adddc552048d0ba9f638444ec3d9403b07966a9e Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Tue, 17 Mar 2020 14:31:14 -0700 Subject: [PATCH 100/492] Reuse CreateDenseElementsAttrFromLiteral util function in HLO importer PiperOrigin-RevId: 301453826 Change-Id: Ic926347fbd8d8cf73d2cde8f91b25bfe073a8cfa --- .../mlir/xla/hlo_function_importer.cc | 34 ++----------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index fa029bd50d0..95421d95504 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -66,36 +66,6 @@ string SanitizeFunctionName(llvm::StringRef name) { return output; } -StatusOr CreateDenseAttrFromLiteral(ShapedType type, - const Literal& literal) { -#define DENSE_ELEMENT_ATTR_BUILDER(xla_type, cpp_type) \ - case xla_type: { \ - auto data_span = literal.data(); \ - return DenseElementsAttr::get( \ - type, llvm::makeArrayRef(data_span.data(), data_span.size())); \ - } - - switch (literal.shape().element_type()) { - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::PRED, bool) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::F32, float) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::F64, double) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S8, int8) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S16, int16) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S32, int32) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S64, int64) - // TODO(b/130356985): Update once MLIR supports unsigned integers. - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::U8, uint8) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::U16, uint16) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::U32, uint32) - DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::U64, uint64) - default: - return tensorflow::errors::Internal( - absl::StrCat("Unsupported type: ", - PrimitiveType_Name(literal.shape().element_type()))); - } -#undef DENSE_ELEMENT_ATTR_BUILDER -} - // Returns whether the instruction is a default dot operation. bool DotIsDefault(const HloInstruction* instruction) { auto dnums = instruction->dot_dimension_numbers(); @@ -209,8 +179,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return nullptr; } case HloOpcode::kConstant: { - auto attr = CreateDenseAttrFromLiteral( - result_type.cast(), instruction->literal()); + const Literal& literal = instruction->literal(); + auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_); if (!attr.ok()) return attr.status(); mlir::Operation* new_operation = func_builder->create(loc, attr.ValueOrDie()); From 1cea2490cb1fb1e930694caa04c36a3049491535 Mon Sep 17 00:00:00 2001 From: Anna R Date: Tue, 17 Mar 2020 14:50:10 -0700 Subject: [PATCH 101/492] Temporarily disable multiple symbol check to fix windows build. PiperOrigin-RevId: 301458183 Change-Id: I79b1f0f2bc0f986c480a54c9fa36671eded06095 --- tensorflow/python/util/tf_export.py | 6 +++--- tensorflow/python/util/tf_export_test.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index e4d6bebc3db..04c96d03617 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -341,9 +341,9 @@ class api_export(object): # pylint: disable=invalid-name # their own _tf_api_names as opposed to just inheriting it. if api_names_attr in func.__dict__: if not self._allow_multiple_exports: - # TODO(annarev): temporarily removing check to fix builds. 
- # Need to investigate why symbols get reported multiple times. - return + raise SymbolAlreadyExposedError( + 'Symbol %s is already exposed as %s.' % + (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access setattr(func, api_names_attr, names) def export_constant(self, module_name, name): diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py index 51d9901fdb4..20625792e9b 100644 --- a/tensorflow/python/util/tf_export_test.py +++ b/tensorflow/python/util/tf_export_test.py @@ -152,8 +152,7 @@ class ValidateExportTest(test.TestCase): (('NAME_E', 'NAME_F'), 0.5)], module2._tf_api_constants) - # TODO(b/151745456): re-enable - def DISABLED_testRaisesExceptionIfAlreadyHasAPINames(self): + def testRaisesExceptionIfAlreadyHasAPINames(self): _test_function._tf_api_names = ['abc'] export_decorator = tf_export.tf_export('nameA', 'nameB') with self.assertRaises(tf_export.SymbolAlreadyExposedError): From 69565ec4003902794bc94e10ba5fe9469a0b3ae4 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Tue, 17 Mar 2020 14:54:56 -0700 Subject: [PATCH 102/492] Create Variables to track mini-batches seen in Model.fit / evaluate / predict. Use these counters in the TensorBoard Callback. PiperOrigin-RevId: 301459298 Change-Id: I8e92e119ef4cef37c41532d11caefd601d4395a7 --- tensorflow/python/keras/callbacks.py | 448 +++++++----------- tensorflow/python/keras/callbacks_test.py | 6 +- tensorflow/python/keras/callbacks_v1.py | 29 +- tensorflow/python/keras/engine/training.py | 63 ++- tensorflow/python/keras/engine/training_v1.py | 3 + .../keras/tests/model_subclassing_test.py | 15 + .../python/keras/utils/version_utils.py | 22 + ...orflow.keras.callbacks.-tensor-board.pbtxt | 2 + ...orflow.keras.callbacks.-tensor-board.pbtxt | 1 + 9 files changed, 305 insertions(+), 284 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index bb9e61d01a2..9177d89c67b 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -35,21 +35,19 @@ import six from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import distributed_file_utils from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras.distribute import multi_worker_training_state as training_state from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 -from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.training import checkpoint_management @@ -1614,7 +1612,7 @@ class LearningRateScheduler(Callback): @keras_export('keras.callbacks.TensorBoard', v1=[]) -class TensorBoard(Callback): +class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): # pylint: 
disable=line-too-long """Enable visualizations for TensorBoard. @@ -1676,11 +1674,10 @@ class TensorBoard(Callback): batches. Note that writing too frequently to TensorBoard can slow down your training. profile_batch: Profile the batch(es) to sample compute characteristics. - profile_batch must be a non-negative integer or a comma separated string - of pair of positive integers. A pair of positive integers signify a - range of batches to profile. By default, it will profile the second - batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow - eager mode. + profile_batch must be a non-negative integer or a tuple of integers. + A pair of positive integers signify a range of batches to profile. + By default, it will profile the second batch. Set profile_batch=0 + to disable profiling. Must run in TensorFlow eager mode. embeddings_freq: frequency (in epochs) at which embedding layers will be visualized. If set to 0, embeddings won't be visualized. embeddings_metadata: a dictionary which maps layer name to a file name in @@ -1713,30 +1710,18 @@ class TensorBoard(Callback): self.histogram_freq = histogram_freq self.write_graph = write_graph self.write_images = write_images - if update_freq == 'batch': - self.update_freq = 1 - else: - self.update_freq = update_freq + self.update_freq = 1 if update_freq == 'batch' else update_freq self.embeddings_freq = embeddings_freq self.embeddings_metadata = embeddings_metadata + self._init_profile_batch(profile_batch) + self._epoch = 0 - self._samples_seen = 0 - self._samples_seen_at_last_write = 0 - self._current_batch = 0 - - # A collection of file writers currently in use, to be closed when - # training ends for this callback. Writers are keyed by the - # directory name under the root logdir: e.g., "train" or - # "validation". - self._train_run_name = 'train' - self._validation_run_name = 'validation' + # Lazily initialized in order to avoid creating event files when + # not needed. self._writers = {} - self._start_batch, self._stop_batch = self._init_profile_batch( - profile_batch) - if self._start_batch > 0: - profiler.warmup() # Improve the profiling accuracy. - # True when a trace is running. - self._is_tracing = False + + # Used to restore any existing `SummaryWriter` after training ends. + self._prev_summary_state = [] def _validate_kwargs(self, kwargs): """Handle arguments were supported in V1.""" @@ -1768,37 +1753,56 @@ class TensorBoard(Callback): def set_model(self, model): """Sets Keras model and writes graph if specified.""" self.model = model + self._log_write_dir = self._get_log_write_dir() - # In case this callback is used via native Keras, _get_distribution_strategy does not exist. - if hasattr(self.model, '_get_distribution_strategy'): - # TensorBoard callback involves writing a summary file in a - # possibly distributed settings. 
- self._log_write_dir = distributed_file_utils.write_dirpath( - self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access - else: - self._log_write_dir = self.log_dir + self._train_dir = os.path.join(self._log_write_dir, 'train') + self._train_step = self.model._train_counter # pylint: disable=protected-access - with context.eager_mode(): - self._close_writers() - if self.write_graph: - with self._get_writer(self._train_run_name).as_default(): - with summary_ops_v2.always_record_summaries(): - if not model.run_eagerly: - summary_ops_v2.graph(K.get_graph(), step=0) + self._val_dir = os.path.join(self._log_write_dir, 'validation') + self._val_step = self.model._test_counter # pylint: disable=protected-access - summary_writable = ( - self.model._is_graph_network or # pylint: disable=protected-access - self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access - if summary_writable: - summary_ops_v2.keras_model('keras', self.model, step=0) + self._writers = {} # Resets writers. + if self.write_graph: + self._write_keras_model_graph() if self.embeddings_freq: self._configure_embeddings() - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - self._prev_summary_recording = summary_state.is_recording - self._prev_summary_writer = summary_state.writer - self._prev_summary_step = summary_state.step + @property + def _train_writer(self): + if 'train' not in self._writers: + self._writers['train'] = summary_ops_v2.create_file_writer_v2( + self._train_dir) + return self._writers['train'] + + @property + def _val_writer(self): + if 'val' not in self._writers: + self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir) + return self._writers['val'] + + def _get_log_write_dir(self): + """For multi-worker, only chief should write, others write to '/tmp'.""" + return distributed_file_utils.write_dirpath(self.log_dir, + self.model.distribute_strategy) + + def _delete_tmp_write_dir(self): + """Deletes tmp write directories for multi-worker.""" + distributed_file_utils.remove_temp_dirpath(self.log_dir, + self.model.distribute_strategy) + + def _write_keras_model_graph(self): + """Writes Keras graph networks to TensorBoard.""" + with self._train_writer.as_default(): + with summary_ops_v2.always_record_summaries(): + if not self.model.run_eagerly: + summary_ops_v2.graph(K.get_graph(), step=0) + + summary_writable = ( + self.model._is_graph_network or # pylint: disable=protected-access + self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access + if summary_writable: + summary_ops_v2.keras_model('keras', self.model, step=0) def _configure_embeddings(self): """Configure the Projector for embeddings.""" @@ -1839,74 +1843,44 @@ class TensorBoard(Callback): writer = DummyWriter(self._log_write_dir) projector.visualize_embeddings(writer, config) - def _close_writers(self): - """Close all remaining open file writers owned by this callback. - - If there are no such file writers, this is a no-op. - """ - with context.eager_mode(): - for writer in six.itervalues(self._writers): - writer.close() - self._writers.clear() - - def _get_writer(self, writer_name): - """Get a summary writer for the given subdirectory under the logdir. - - A writer will be created if it does not yet exist. - - Arguments: - writer_name: The name of the directory for which to create or - retrieve a writer. Should be either `self._train_run_name` or - `self._validation_run_name`. - - Returns: - A `SummaryWriter` object. 
- """ - if writer_name not in self._writers: - path = os.path.join(self._log_write_dir, writer_name) - writer = summary_ops_v2.create_file_writer_v2(path) - self._writers[writer_name] = writer - return self._writers[writer_name] - - def _set_default_writer(self, writer_name): + def _push_writer(self, writer, step): """Sets the default writer for custom batch-level summaries.""" if self.update_freq == 'epoch': - # Writer is only used for custom summaries, which are written - # batch-by-batch. return - step = self._total_batches_seen[writer_name] + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + self._prev_summary_state.append({ + 'is_recording': summary_state.is_recording, + 'writer': summary_state.writer, + 'step': summary_state.step + }) - def _should_record(): - return math_ops.equal(step % self.update_freq, 0) + if self.update_freq == 'epoch': + should_record = False + writer = None + else: + should_record = lambda: math_ops.equal(step % self.update_freq, 0) + + summary_state.is_recording = should_record + summary_state.writer = writer + # TODO(b/151339474): Fix deadlock when not using .value() here. + summary_ops_v2.set_step(step.value()) + + def _pop_writer(self): + """Pops the current writer.""" + if self.update_freq == 'epoch': + return + + prev_state = self._prev_summary_state.pop() summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - summary_state.is_recording = _should_record - summary_state.writer = self._get_writer(writer_name) - summary_ops_v2.set_step(step) + summary_state.is_recording = prev_state['is_recording'] + summary_state.writer = prev_state['writer'] + summary_ops_v2.set_step(prev_state['step']) - def _init_batch_steps(self): - """Create the total batch counters.""" - if ops.executing_eagerly_outside_functions(): - # Variables are needed for the `step` value of custom tf.summaries - # to be updated inside a tf.function. - self._total_batches_seen = { - self._train_run_name: variables.Variable(0, dtype='int64'), - self._validation_run_name: variables.Variable(0, dtype='int64') - } - else: - # Custom tf.summaries are not supported in legacy graph mode. - self._total_batches_seen = { - self._train_run_name: 0, - self._validation_run_name: 0 - } - - def _increment_step(self, writer_name): - step = self._total_batches_seen[writer_name] - if isinstance(step, variables.Variable): - step.assign_add(1) - else: - self._total_batches_seen[writer_name] += 1 + def _close_writers(self): + for writer in self._writers.values(): + writer.close() def _init_profile_batch(self, profile_batch): """Validate profile_batch value and set the range of batches to profile. @@ -1926,75 +1900,79 @@ class TensorBoard(Callback): """ profile_batch_error_message = ( - 'profile_batch must be a non-negative integer or a comma separated ' - 'string of pair of positive integers. A pair of positive integers ' - 'signify a range of batches to profile.') - try: - profile_range = [int(i) for i in str(profile_batch).split(',')] - except ValueError: - raise ValueError(profile_batch_error_message) - if len(profile_range) == 1: # single batch - start_batch, stop_batch = profile_range[0], profile_range[0] - if start_batch < 0: - raise ValueError(profile_batch_error_message) - elif len(profile_range) == 2: # (start_batch, stop_batch) - start_batch, stop_batch = profile_range - # [0, 0], [-1, 100], [6, 5] are illegal. 
- if start_batch <= 0 or start_batch > stop_batch: - raise ValueError(profile_batch_error_message) + 'profile_batch must be a non-negative integer or 2-tuple of positive ' + 'integers. A pair of positive integers signifies a range of batches ' + 'to profile. Found: {}'.format(profile_batch)) + + # Support legacy way of specifying "start,stop" or "start" as str. + if isinstance(profile_batch, six.string_types): + profile_batch = str(profile_batch).split(',') + profile_batch = nest.map_structure(int, profile_batch) + + if isinstance(profile_batch, int): + self._start_batch = profile_batch + self._stop_batch = profile_batch + elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2: + self._start_batch, self._stop_batch = profile_batch else: raise ValueError(profile_batch_error_message) - return start_batch, stop_batch + + if self._start_batch < 0 or self._stop_batch < self._start_batch: + raise ValueError(profile_batch_error_message) + + if self._start_batch > 0: + profiler.warmup() # Improve the profiling accuracy. + # True when a trace is running. + self._is_tracing = False + + # Setting `profile_batch=0` disables profiling. + self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0) def on_train_begin(self, logs=None): - self._init_batch_steps() - if self._start_batch == 1: - self._enable_trace() + self._push_writer(self._train_writer, self._train_step) + + def on_train_end(self, logs=None): + self._pop_writer() + + if self._is_tracing: + self._stop_trace() + + self._close_writers() + self._delete_tmp_write_dir() def on_test_begin(self, logs=None): - self._set_default_writer(self._validation_run_name) + self._push_writer(self._val_writer, self._val_step) + + def on_test_end(self, logs=None): + self._pop_writer() + + def on_train_batch_begin(self, batch, logs=None): + if not self._should_trace: + return + + if self._epoch == 0 and batch == self._start_batch: + self._start_trace() def on_train_batch_end(self, batch, logs=None): - """Writes scalar summaries for metrics on every training batch. - - Performs profiling if current batch is in profiler_batches. + """Performs profiling if current batch is in profiler_batches. Arguments: batch: Integer, index of batch within the current epoch. logs: Dict. Metric results for this batch. """ - # TODO(b/150629188): Make TensorBoard callback not use batch hooks - # by default. 
- if self.update_freq == 'epoch' and self._start_batch is None: + if not self._should_trace: return - # Don't output batch_size and batch number as TensorBoard summaries - logs = logs or {} - train_batches = self._total_batches_seen[self._train_run_name] - if self.update_freq != 'epoch' and batch % self.update_freq == 0: - self._log_metrics(logs, prefix='batch_', step=train_batches) - - self._increment_step(self._train_run_name) - if self._is_tracing: - control_flow_ops.cond( - math_ops.greater_equal(train_batches, self._stop_batch), - lambda: self._log_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda - else: - control_flow_ops.cond( - math_ops.equal(train_batches, self._start_batch - 1), - lambda: self._enable_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda - - def on_test_batch_end(self, batch, logs=None): - if self.update_freq == 'epoch': - return - self._increment_step(self._validation_run_name) + if self._is_tracing and batch >= self._stop_batch: + self._stop_trace() def on_epoch_begin(self, epoch, logs=None): - self._set_default_writer(self._train_run_name) + # Keeps track of epoch for profiling. + self._epoch = epoch def on_epoch_end(self, epoch, logs=None): """Runs metrics and histogram summaries at epoch end.""" - self._log_metrics(logs, prefix='epoch_', step=epoch) + self._log_epoch_metrics(epoch, logs) if self.histogram_freq and epoch % self.histogram_freq == 0: self._log_weights(epoch) @@ -2002,124 +1980,57 @@ class TensorBoard(Callback): if self.embeddings_freq and epoch % self.embeddings_freq == 0: self._log_embeddings(epoch) - def on_train_end(self, logs=None): - if self._is_tracing: - self._log_trace() - self._close_writers() - - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - summary_state.is_recording = self._prev_summary_recording - summary_state.writer = self._prev_summary_writer - summary_state.step = self._prev_summary_step - - # In case this callback is used via native Keras, _get_distribution_strategy does not exist. - if hasattr(self.model, '_get_distribution_strategy'): - # Safely remove the unneeded temp files. - distributed_file_utils.remove_temp_dirpath( - self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access - - def _enable_trace(self): - """Starts to collect trace graph to TensorBoard. - - Collects both trace and graph in eager mode, and trace only in graph mode. - """ - if context.executing_eagerly(): - # Graph must be traced in eager mode. - summary_ops_v2.trace_on(graph=True, profiler=False) - profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) + def _start_trace(self): + summary_ops_v2.trace_on(graph=True, profiler=False) + profiler.start(logdir=self._train_dir) self._is_tracing = True - def _enable_trace_return_true(self): - """Starts to collect trace graph to TensorBoard and returns True. - - Returns: - True. - """ - self._enable_trace() - return True - - def _log_trace(self): - """Logs the trace graph to TensorBoard. - - Logs both trace and graph in eager mode, and trace only in graph mode. - """ - profiler.stop() - if context.executing_eagerly(): - # Graph must be traced in eager mode. 
- with self._get_writer(self._train_run_name).as_default(), \ - summary_ops_v2.always_record_summaries(): + def _stop_trace(self, batch=None): + """Logs the trace graph to TensorBoard.""" + if batch is None: + batch = self._stop_batch + with self._train_writer.as_default(): + with summary_ops_v2.always_record_summaries(): # TODO(b/126388999): Remove step info in the summary name. - step = K.get_value(self._total_batches_seen[self._train_run_name]) - summary_ops_v2.trace_export(name='batch_%d' % step, step=step) + summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch) + profiler.stop() self._is_tracing = False - def _log_trace_return_true(self): - """Logs the trace graph to TensorBoard and returns True. - - Returns: - True. - """ - self._log_trace() - return True - - def _log_metrics(self, logs, prefix, step): - """Writes metrics out as custom scalar summaries. + def _log_epoch_metrics(self, epoch, logs): + """Writes epoch metrics out as scalar summaries. Arguments: - logs: Dict. Keys are scalar summary names, values are NumPy scalars. - prefix: String. The prefix to apply to the scalar summary names. - step: Int. The global step to use for TensorBoard. + epoch: Int. The global step to use for TensorBoard. + logs: Dict. Keys are scalar summary names, values are scalars. """ - if logs is None: - logs = {} + if not logs: + return - # Group metrics by the name of their associated file writer. Values - # are lists of metrics, as (name, scalar_value) pairs. - logs_by_writer = { - self._train_run_name: [], - self._validation_run_name: [], - } - validation_prefix = 'val_' - for (name, value) in logs.items(): - if name in ('batch', 'size', 'num_steps'): - # Scrub non-metric items. - continue - if name.startswith(validation_prefix): - name = name[len(validation_prefix):] - writer_name = self._validation_run_name - else: - writer_name = self._train_run_name - name = prefix + name # assign batch or epoch prefix - logs_by_writer[writer_name].append((name, value)) + train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')} + val_logs = {k: v for k, v in logs.items() if k.startswith('val_')} - with context.eager_mode(): - with summary_ops_v2.always_record_summaries(): - for writer_name in logs_by_writer: - these_logs = logs_by_writer[writer_name] - if not these_logs: - # Don't create a "validation" events file if we don't - # actually have any validation data. - continue - writer = self._get_writer(writer_name) - with writer.as_default(): - for (name, value) in these_logs: - summary_ops_v2.scalar(name, value, step=step) + with summary_ops_v2.always_record_summaries(): + if train_logs: + with self._train_writer.as_default(): + for name, value in train_logs.items(): + summary_ops_v2.scalar('epoch_' + name, value, step=epoch) + if val_logs: + with self._val_writer.as_default(): + for name, value in val_logs.items(): + name = name[4:] # Remove 'val_' prefix. 
+ summary_ops_v2.scalar('epoch_' + name, value, step=epoch) def _log_weights(self, epoch): """Logs the weights of the Model to TensorBoard.""" - writer = self._get_writer(self._train_run_name) - with context.eager_mode(), \ - writer.as_default(), \ - summary_ops_v2.always_record_summaries(): - for layer in self.model.layers: - for weight in layer.weights: - weight_name = weight.name.replace(':', '_') - with ops.init_scope(): - weight = K.get_value(weight) - summary_ops_v2.histogram(weight_name, weight, step=epoch) - if self.write_images: - self._log_weight_as_image(weight, weight_name, epoch) - writer.flush() + with self._train_writer.as_default(): + with summary_ops_v2.always_record_summaries(): + for layer in self.model.layers: + for weight in layer.weights: + weight_name = weight.name.replace(':', '_') + summary_ops_v2.histogram(weight_name, weight, step=epoch) + if self.write_images: + self._log_weight_as_image(weight, weight_name, epoch) + self._train_writer.flush() def _log_weight_as_image(self, weight, weight_name, epoch): """Logs a weight as a TensorBoard image.""" @@ -2150,6 +2061,9 @@ class TensorBoard(Callback): 'keras_embedding.ckpt-{}'.format(epoch)) self.model.save_weights(embeddings_ckpt) + def _implements_train_batch_hooks(self): + return not (self._start_batch == 0 and self._stop_batch == 0) + @keras_export('keras.callbacks.ReduceLROnPlateau') class ReduceLROnPlateau(Callback): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index eb62d0b29ee..54f71402177 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -2079,17 +2079,19 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): model.fit( np.zeros((64, 1)), np.zeros((64, 1)), + batch_size=32, callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)], ) # Verifies trace exists in the first train_dir. - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertIsNotNone(self._get_trace_file(logdir=self.logdir)) model.fit( np.zeros((64, 1)), np.zeros((64, 1)), + batch_size=32, callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)], ) # Verifies trace exists in the second train_dir. - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertIsNotNone(self._get_trace_file(logdir=self.logdir)) def test_TensorBoard_autoTrace_profileBatchRange(self): model = self._get_seq_model() diff --git a/tensorflow/python/keras/callbacks_v1.py b/tensorflow/python/keras/callbacks_v1.py index 524e039f597..09af890b76c 100644 --- a/tensorflow/python/keras/callbacks_v1.py +++ b/tensorflow/python/keras/callbacks_v1.py @@ -39,7 +39,7 @@ from tensorflow.python.util.tf_export import keras_export @keras_export(v1=['keras.callbacks.TensorBoard']) -class TensorBoard(callbacks.Callback): +class TensorBoard(callbacks.TensorBoard): # pylint: disable=line-too-long """Enable visualizations for TensorBoard. @@ -127,7 +127,8 @@ class TensorBoard(callbacks.Callback): embeddings_data=None, update_freq='epoch', profile_batch=2): - super(TensorBoard, self).__init__() + # Don't call super's init since it is an eager-only version. 
+ callbacks.Callback.__init__(self) self.log_dir = log_dir self.histogram_freq = histogram_freq if self.histogram_freq and context.executing_eagerly(): @@ -342,6 +343,21 @@ class TensorBoard(callbacks.Callback): self.writer.add_summary(summary, step) self.writer.flush() + def on_train_batch_begin(self, batch, logs=None): + if (not self._is_profiling and + self._total_batches_seen == self._profile_batch - 1): + profiler.start(self.log_dir) + self._is_profiling = True + + def on_train_batch_end(self, batch, logs=None): + return self.on_batch_end(batch, logs) + + def on_test_begin(self, logs=None): + pass + + def on_test_end(self, logs=None): + pass + def on_batch_end(self, batch, logs=None): """Writes scalar summaries for metrics on every training batch. @@ -358,18 +374,13 @@ class TensorBoard(callbacks.Callback): self._write_custom_summaries(self._total_batches_seen, batch_logs) self._samples_seen_at_last_write = self._samples_seen self._total_batches_seen += 1 + if self._is_profiling: profiler.stop() self._is_profiling = False - elif (not self._is_profiling and - self._total_batches_seen == self._profile_batch - 1): - profiler.start(self.log_dir) - self._is_profiling = True def on_train_begin(self, logs=None): - if self._profile_batch == 1: - profiler.start(self.log_dir) - self._is_profiling = True + pass def on_epoch_begin(self, epoch, logs=None): """Add histogram op to Model eval_function callbacks, reset batch count.""" diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 7dcf10a506c..21361f680da 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import copy +import itertools from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import distribute_coordinator_context as dc_context @@ -28,6 +29,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import monitoring +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import callbacks as callbacks_module from tensorflow.python.keras import optimizers @@ -43,6 +45,8 @@ from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops import array_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import summary_ops_v2 +from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.profiler import trace @@ -161,6 +165,9 @@ class Model(network.Network, version_utils.ModelVersionSelector): Checkout [guide](https://www.tensorflow.org/guide/keras/overview) for additional details. 
""" + _TF_MODULE_IGNORED_PROPERTIES = frozenset( + itertools.chain(('_train_counter', '_test_counter', '_predict_counter'), + network.Network._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access def __init__(self, *args, **kwargs): super(Model, self).__init__(*args, **kwargs) @@ -186,6 +193,18 @@ class Model(network.Network, version_utils.ModelVersionSelector): self.compiled_loss = None self.compiled_metrics = None + self._init_batch_counters() + + @trackable.no_automatic_dependency_tracking + def _init_batch_counters(self): + # Untracked Variables, used to keep track of mini-batches seen in `fit`, + # `evaluate`, and `predict`. + agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA + self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg) + self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg) + self._predict_counter = variables.Variable( + 0, dtype='int64', aggregation=agg) + def get_weights(self): """Retrieves the weights of the model. @@ -499,11 +518,18 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.train_function def train_function(iterator): + """Runs one call to `self.train_function`.""" + + def run_step(data): + outputs = self.train_step(data) + self._train_counter.assign_add(1) + return outputs + data = next(iterator) - outputs = self.distribute_strategy.run( - self.train_step, args=(data,)) + outputs = self.distribute_strategy.run(run_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='first') + write_scalar_summaries(outputs, step=self._train_counter) return outputs if not self.run_eagerly: @@ -762,6 +788,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): self.stop_training = False train_function = self.make_train_function() + self._train_counter.assign(0) callbacks.on_train_begin() # Handle fault-tolerance for multi-worker. # TODO(omalleyt): Fix the ordering issues that mean this has to @@ -872,9 +899,15 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.test_function def test_function(iterator): + """Runs one call to `self.test_function`.""" + + def run_step(data): + outputs = self.test_step(data) + self._test_counter.assign_add(1) + return outputs + data = next(iterator) - outputs = self.distribute_strategy.run( - self.test_step, args=(data,)) + outputs = self.distribute_strategy.run(run_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='first') return outputs @@ -1003,6 +1036,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): steps=data_handler.inferred_steps) test_function = self.make_test_function() + self._test_counter.assign(0) callbacks.on_test_begin() for _, iterator in data_handler.enumerate_epochs(): # Single epoch. 
self.reset_metrics() @@ -1075,9 +1109,15 @@ class Model(network.Network, version_utils.ModelVersionSelector): return self.predict_function def predict_function(iterator): + """Runs one call to `self.predict_function`.""" + + def run_step(data): + outputs = self.predict_step(data) + self._predict_counter.assign_add(1) + return outputs + data = next(iterator) - outputs = self.distribute_strategy.run( - self.predict_step, args=(data,)) + outputs = self.distribute_strategy.run(run_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='concat') return outputs @@ -1192,6 +1232,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): steps=data_handler.inferred_steps) predict_function = self.make_predict_function() + self._predict_counter.assign(0) callbacks.on_predict_begin() for _, iterator in data_handler.enumerate_epochs(): # Single epoch. with data_handler.catch_stop_iteration(): @@ -1734,3 +1775,13 @@ def _minimize(tape, optimizer, loss, trainable_variables): all_reduce_sum_gradients=False) else: optimizer.apply_gradients(zip(gradients, trainable_variables)) + + +def _is_scalar(x): + return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0 + + +def write_scalar_summaries(logs, step): + for name, value in logs.items(): + if _is_scalar(value): + summary_ops_v2.scalar('batch_' + name, value, step=step) diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 1c0fea91337..710f9bf3497 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -162,6 +162,9 @@ class Model(training_lib.Model): self._v1_compile_was_called = False + def _init_batch_counters(self): + pass # Batch counters should not be created in legacy graph mode. 
+ @trackable.no_automatic_dependency_tracking def _set_strategy(self, strategy): self._compile_time_distribution_strategy = strategy diff --git a/tensorflow/python/keras/tests/model_subclassing_test.py b/tensorflow/python/keras/tests/model_subclassing_test.py index 761f720cea5..5af1148f4f0 100644 --- a/tensorflow/python/keras/tests/model_subclassing_test.py +++ b/tensorflow/python/keras/tests/model_subclassing_test.py @@ -737,6 +737,21 @@ class CustomCallSignatureTests(test.TestCase, parameterized.TestCase): self.assertLen(new_model.variables, 1) self.assertLen(new_model.layers, 1) + def test_batch_counters_not_in_variables(self): + + class MyModel(keras.Model): + + def __init__(self): + super(MyModel, self).__init__() + self.layer = keras.layers.Dense(4) + + def call(self, obs): + return self.layer(obs) + + model = MyModel() + model(np.ones((10, 10))) + self.assertLen(model.variables, 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/utils/version_utils.py b/tensorflow/python/keras/utils/version_utils.py index cf485e1080d..377f370430c 100644 --- a/tensorflow/python/keras/utils/version_utils.py +++ b/tensorflow/python/keras/utils/version_utils.py @@ -36,6 +36,13 @@ base_layer = lazy_loader.LazyLoader( base_layer_v1 = lazy_loader.LazyLoader( "base_layer_v1", globals(), "tensorflow.python.keras.engine.base_layer_v1") +callbacks = lazy_loader.LazyLoader( + "callbacks", globals(), + "tensorflow.python.keras.callbacks") +callbacks_v1 = lazy_loader.LazyLoader( + "callbacks_v1", globals(), + "tensorflow.python.keras.callbacks_v1") + # pylint: enable=g-inconsistent-quotes @@ -58,6 +65,21 @@ class LayerVersionSelector(object): return super(LayerVersionSelector, cls).__new__(cls) +class TensorBoardVersionSelector(object): + """Chooses between Keras v1 and v2 TensorBoard callback class.""" + + def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument + eager_enabled = ops.executing_eagerly_outside_functions() + start_cls = cls + cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard, + eager_enabled) + if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard: + # Since the v2 class is not a subclass of the v1 class, __init__ has to + # be called manually. 
+ return cls(*args, **kwargs) + return super(TensorBoardVersionSelector, cls).__new__(cls) + + def swap_class(cls, v2_cls, v1_cls, eager_enabled): """Swaps in v2_cls or v1_cls depending on graph mode.""" if cls == object: diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt index 4504633d4a1..2e0c6c97826 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.callbacks.TensorBoard" tf_class { is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt index 24385e2722a..51d6901e936 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.keras.callbacks.TensorBoard" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" From 6dffcddea9a5d5d1aa45213c09c6d6424a93938f Mon Sep 17 00:00:00 2001 From: Jose Baiocchi Date: Tue, 17 Mar 2020 15:19:28 -0700 Subject: [PATCH 103/492] Fix flaky host_tracer_test PiperOrigin-RevId: 301464254 Change-Id: I7f36d5a0c7b21c69e94a40cc6c62c11ed9cf0e5b --- tensorflow/core/profiler/internal/cpu/BUILD | 3 --- tensorflow/core/profiler/internal/cpu/host_tracer_test.cc | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD index 1c229f78c43..fe028d85cf7 100644 --- a/tensorflow/core/profiler/internal/cpu/BUILD +++ b/tensorflow/core/profiler/internal/cpu/BUILD @@ -43,9 +43,6 @@ cc_library( tf_cc_test( name = "host_tracer_test", srcs = ["host_tracer_test.cc"], - tags = [ - "no_oss", - ], deps = [ ":host_tracer", "//tensorflow/core:core_cpu_lib", diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc index 412038df9b1..9944a1062e1 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc @@ -36,7 +36,7 @@ namespace { using ::testing::UnorderedElementsAre; -NodeExecStats MakeNodeStats(const string& name, int32 thread_id, +NodeExecStats MakeNodeStats(const string& name, uint32 thread_id, const string& label = "") { NodeExecStats ns; ns.set_node_name(name); @@ -74,7 +74,7 @@ inline ::testing::PolymorphicMatcher EqualsNodeStats( } TEST(HostTracerTest, CollectsTraceMeEventsAsRunMetadata) { - int32 thread_id = Env::Default()->GetCurrentThreadId(); + uint32 thread_id = Env::Default()->GetCurrentThreadId(); auto tracer = CreateHostTracer(ProfilerOptions()); @@ -106,7 +106,7 @@ TEST(HostTracerTest, CollectsTraceMeEventsAsRunMetadata) { } TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) { - int32 thread_id; + uint32 thread_id; string thread_name = "MyThreadName"; XSpace space; From 7b43b3ce08035b6c502b1aa4caa23ba59e4710f2 Mon Sep 17 00:00:00 2001 From: Jiho Choi Date: Tue, 17 Mar 2020 15:22:35 -0700 Subject: [PATCH 104/492] Bring back the global batch number for the profiler. 
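
The intent, sketched as a minimal standalone example (assumed values only --
profile_batch=(3, 5) and two epochs of four batches are made up for the
illustration, and plain Python stands in for the callback, which calls
profiler.start()/stop() where the comments indicate):

    start_batch, stop_batch = 3, 5        # e.g. profile_batch=(3, 5)
    global_train_batch = 0
    is_tracing = False
    for epoch in range(2):
        for batch in range(4):            # global batches 1..8 across epochs
            global_train_batch += 1
            if not is_tracing and global_train_batch == start_batch:
                is_tracing = True         # profiler.start() would run here
            if is_tracing:
                print('tracing epoch %d, batch %d' % (epoch, batch))
            if is_tracing and global_train_batch >= stop_batch:
                is_tracing = False        # profiler.stop() + trace export here

In this sketch the traced batches are global batches 3, 4 and 5, i.e. the
profiling window crosses the epoch boundary.
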
PiperOrigin-RevId: 301464845 Change-Id: I40f6fefaf0ec70be0edc8751827eb1afe0fcb5d9 --- tensorflow/python/keras/callbacks.py | 14 +++++--------- tensorflow/python/keras/callbacks_test.py | 14 ++++++++------ 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 9177d89c67b..734b833fd62 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1677,7 +1677,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): profile_batch must be a non-negative integer or a tuple of integers. A pair of positive integers signify a range of batches to profile. By default, it will profile the second batch. Set profile_batch=0 - to disable profiling. Must run in TensorFlow eager mode. + to disable profiling. embeddings_freq: frequency (in epochs) at which embedding layers will be visualized. If set to 0, embeddings won't be visualized. embeddings_metadata: a dictionary which maps layer name to a file name in @@ -1715,6 +1715,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): self.embeddings_metadata = embeddings_metadata self._init_profile_batch(profile_batch) self._epoch = 0 + self._global_train_batch = 0 # Lazily initialized in order to avoid creating event files when # not needed. @@ -1947,23 +1948,18 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): self._pop_writer() def on_train_batch_begin(self, batch, logs=None): + self._global_train_batch += 1 if not self._should_trace: return - if self._epoch == 0 and batch == self._start_batch: + if self._global_train_batch == self._start_batch: self._start_trace() def on_train_batch_end(self, batch, logs=None): - """Performs profiling if current batch is in profiler_batches. - - Arguments: - batch: Integer, index of batch within the current epoch. - logs: Dict. Metric results for this batch. - """ if not self._should_trace: return - if self._is_tracing and batch >= self._stop_batch: + if self._is_tracing and self._global_train_batch >= self._stop_batch: self._stop_trace() def on_epoch_begin(self, epoch, logs=None): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 54f71402177..e488835a6c5 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -2076,22 +2076,24 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): model.compile(gradient_descent.SGD(1), 'mse') + logdir = os.path.join(self.get_temp_dir(), 'tb1') model.fit( np.zeros((64, 1)), np.zeros((64, 1)), batch_size=32, - callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)], + callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=1)], ) - # Verifies trace exists in the first train_dir. - self.assertIsNotNone(self._get_trace_file(logdir=self.logdir)) + # Verifies trace exists in the first logdir. + self.assertIsNotNone(self._get_trace_file(logdir=logdir)) + logdir = os.path.join(self.get_temp_dir(), 'tb2') model.fit( np.zeros((64, 1)), np.zeros((64, 1)), batch_size=32, - callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)], + callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=2)], ) - # Verifies trace exists in the second train_dir. - self.assertIsNotNone(self._get_trace_file(logdir=self.logdir)) + # Verifies trace exists in the second logdir. 
+ self.assertIsNotNone(self._get_trace_file(logdir=logdir)) def test_TensorBoard_autoTrace_profileBatchRange(self): model = self._get_seq_model() From 690c33e4b5dd7d4b7a5a36ab020b7d69362aa80a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 15:29:06 -0700 Subject: [PATCH 105/492] Changed OpDefBuilderWrapper::SetShapeFn declaration to use OpShapeInferenceFn to match OpDefBuilder::SetShapeFn. PiperOrigin-RevId: 301466148 Change-Id: I210df42451f0e9cd7c5e65fcf99db3cf440a58b1 --- tensorflow/core/framework/op.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 8481e2f1021..278534e1ab0 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -261,9 +261,8 @@ class OpDefBuilderWrapper { builder_.Doc(std::move(text)); return *this; } - OpDefBuilderWrapper& SetShapeFn( - Status (*fn)(shape_inference::InferenceContext*)) { - builder_.SetShapeFn(fn); + OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) { + builder_.SetShapeFn(std::move(fn)); return *this; } const ::tensorflow::OpDefBuilder& builder() const { return builder_; } From 2bb1aed6314392c971461b6867f608002ec51022 Mon Sep 17 00:00:00 2001 From: Anna R Date: Tue, 17 Mar 2020 15:40:59 -0700 Subject: [PATCH 106/492] For tensorflow::io, on windows use '\' as path separator. PiperOrigin-RevId: 301468420 Change-Id: I9625f2524437447f596f81a9983c3b38ca40fba4 --- tensorflow/core/platform/path.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorflow/core/platform/path.cc b/tensorflow/core/platform/path.cc index 00e3f0eca28..1e88328aace 100644 --- a/tensorflow/core/platform/path.cc +++ b/tensorflow/core/platform/path.cc @@ -38,11 +38,7 @@ namespace io { namespace internal { namespace { -#if defined(PLATFORM_WINDOWS) -const char kPathSep[] = "\\"; -#else const char kPathSep[] = "/"; -#endif // PLATFORM_WINDOWS bool FixBazelEnvPath(const char* path, string* out) { if (path == nullptr) return false; From b7f92d99381182194e4e98125edb96f19dfcc4ea Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Tue, 17 Mar 2020 15:58:24 -0700 Subject: [PATCH 107/492] Call update server def when re-enable collective ops PiperOrigin-RevId: 301471698 Change-Id: Iad7390b4ea9f70dce33cfdc6cddf245a39d7c55d --- .../distribute/multi_worker_continuous_run_test.py | 12 ++++++++---- tensorflow/python/eager/context.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/distribute/multi_worker_continuous_run_test.py b/tensorflow/python/distribute/multi_worker_continuous_run_test.py index 9e406e846b8..df30c3f6e3f 100644 --- a/tensorflow/python/distribute/multi_worker_continuous_run_test.py +++ b/tensorflow/python/distribute/multi_worker_continuous_run_test.py @@ -36,13 +36,14 @@ from tensorflow.python.framework import config from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +NUM_WORKERS = 5 + # TODO(b/143286947): expand the test to cover fault tolerance and elasticity class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine(mode=['eager'])) def testAllReduceContinuousRun(self, mode): - num_workers = 5 tensor_shape = [2, 2] local_device = '/device:CPU:0' if config.list_physical_devices('GPU'): @@ -50,6 +51,9 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): def worker_step_fn(): strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy() + # Make sure 
the processeses are in sync after updating the cluster + multi_process_runner.barrier().wait() + tf_config = json.loads(os.environ['TF_CONFIG']) worker_id = tf_config['task']['index'] @@ -62,7 +66,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): t_out = run_reduce() # Element values from the workers are # 0, 1, ..., (num_workers - 1) - expected_mean = (num_workers - 1) / 2 + expected_mean = (NUM_WORKERS - 1) / 2 expected_out = np.ones(tensor_shape) * expected_mean self.assertAllClose(t_out, expected_out) @@ -78,8 +82,8 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): multi_process_runner.run( worker_fn, - cluster_spec=test_base.create_cluster_spec(num_workers=num_workers)) + cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS)) if __name__ == '__main__': - multi_process_runner.test_main() + multi_process_runner.test_main(barrier_parties=NUM_WORKERS) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 073c33383c3..a36d2142329 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -669,13 +669,19 @@ class Context(object): if not server_def: raise ValueError("server_def is None.") + self._collective_ops_server_def = server_def + # TODO(b/129298253): Allow creating datasets/tensors before enabling # collective ops. if self._context_handle is not None: logging.warning("Enabling collective ops after program startup may cause " "error when accessing previously created tensors.") - - self._collective_ops_server_def = server_def + with self._initialize_lock: + assert self._initialized + server_def_str = self._collective_ops_server_def.SerializeToString() + pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str) + self._initialize_logical_devices() + self._clear_caches() def configure_collective_ops( self, From 743f41fa9602a33e62ce7a40644e8e5ec732c7b7 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Tue, 17 Mar 2020 16:12:28 -0700 Subject: [PATCH 108/492] Extend PromoteResourcesToArgsPass to support resource arguments. For TF2 function graphs, resources are passed in as arguments to a function and are not defined as VarHandleOps in the function body. Support for resource writes that take on an argument value is added. PiperOrigin-RevId: 301474692 Change-Id: I52c31eacbe6d4d7d138ec77126a4226ef3194f7d --- .../tests/promote_resources_to_args.mlir | 192 ++++++++++++++- .../transforms/promote_resources_to_args.cc | 224 +++++++++++------- 2 files changed, 329 insertions(+), 87 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir index 5741e9527b5..c9eb0830a6c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -tf-promote-resources-to-args | FileCheck %s -dump-input-on-failure +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-resources-to-args | FileCheck %s -dump-input-on-failure // One resource, one read. 
// CHECK-LABEL: func @main(%arg0: tensor) -> tensor<2xf32> @@ -112,3 +112,193 @@ func @main() -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outp %6 = "tf.Pack"(%2, %5) {N = 2 : i64, T = f32, axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xf32> return %6 : tensor<2xf32> } + +// ----- + +// Tests resource passed in as an argument is not modified and not returned. + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor +func @main(%arg0: tensor>>) { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + // CHECK-NEXT: "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) + %1 = "tf.AddV2"(%0, %0) : (tensor, tensor) -> tensor + // CHECK-NEXT: return + return +} + +// ----- + +// Tests resource passed in as an argument is modified but not returned. + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: -> tensor +func @main(%arg0: tensor>>) { + // CHECK-NEXT: %[[CONST:[a-z0-9]+]] = "tf.Const" + %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor + "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () + // CHECK-NEXT: return %[[CONST]] : tensor + return +} + +// ----- + +// Tests last resource assign is returned as a result. + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: -> tensor +func @main(%arg0: tensor>>) { + %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor + "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () + // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} + %1 = "tf.Const"() {value = dense<1.050000e+03> : tensor} : () -> tensor + "tf.AssignVariableOp"(%arg0, %1) : (tensor>>, tensor) -> () + // CHECK-NEXT: return %[[CONST]] : tensor + return +} + +// ----- + +// Tests last resource assign is returned even when the original function +// returns the same value prior. + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: -> (tensor, tensor) +func @main(%arg0: tensor>>) -> tensor { + %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor + "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () + // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} + %1 = "tf.Const"() {value = dense<1.050000e+03> : tensor} : () -> tensor + "tf.AssignVariableOp"(%arg0, %1) : (tensor>>, tensor) -> () + // CHECK-NEXT: return %[[CONST]], %[[CONST]] : tensor, tensor + return %1 : tensor +} + +// ----- + +// Tests read interleaved between writes. 
+ +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: -> (tensor, tensor) +func @main(%arg0: tensor>>) -> tensor { + // CHECK-NEXT: %[[CONST_0:[a-z0-9]+]] = "tf.Const"() {value = dense<4.200000e+01> : tensor} + %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor + "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () + %1 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + // CHECK-NEXT: %[[ADD:[a-z0-9]+]] = "tf.AddV2"(%[[CONST_0]], %[[CONST_0]]) + %2 = "tf.AddV2"(%1, %1) : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[CONST_1:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} + %3 = "tf.Const"() {value = dense<1.050000e+03> : tensor} : () -> tensor + "tf.AssignVariableOp"(%arg0, %3) : (tensor>>, tensor) -> () + // CHECK-NEXT: return %[[ADD]], %[[CONST_1]] : tensor, tensor + return %2 : tensor +} + +// ----- + +// Tests resource write takes on value that is from an argument. + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %[[ARG_1:[a-z0-9]+]]: tensor +// CHECK-SAME: -> tensor +func @main(%arg0: tensor>>, %arg1: tensor) { + "tf.AssignVariableOp"(%arg0, %arg1) : (tensor>>, tensor) -> () + // CHECK-NEXT: return %[[ARG_1]] : tensor + return +} + +// ----- + +// Tests first read of one resource is used as a value to write to another +// resource. + +// CHECK-LABEL: func @main +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %[[ARG_1:[a-z0-9]+]]: tensor +// CHECK-SAME: -> tensor +func @main(%arg0: tensor>>, %arg1: tensor>>) { + %1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + "tf.AssignVariableOp"(%arg0, %1) : (tensor>>, tensor) -> () + // CHECK-NEXT: return %[[ARG_1]] : tensor + return +} + +// ----- + +// Tests main function with multiple blocks. + +// expected-error@+1 {{expects 'main' function to have 1 block, got 2}} +func @main() { + br ^bb1 +^bb1: + return +} + +// ----- + +// Tests main function is terminated with a non MLIR ReturnOp. + +// expected-error@+1 {{expects 'main' function to have a MLIR ReturnOp}} +func @main() { +^bb0: + tf_device.return +} + +// ----- + +// Tests non main function with resource arguments. + +func @main() { + return +} + +// expected-error@+1 {{potential nested resource accesses in function}} +func @other(%arg0: tensor>>) { + return +} + +// ----- + +// Tests main function with invalid resource argument subtype. + +// expected-error@+1 {{expects resource type of argument 0 to have one subtype, got '!tf.resource'}} +func @main(%arg0: tensor) { + return +} + +// ----- + +// Tests main function with invalid VarHandleOp resource subtype. + +func @main() { + // expected-error@+1 {{expects resource type to have one subtype, got '!tf.resource'}} + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + return +} + +// ----- + +// Tests resource argument has users that are not ReadVariableOp or +// AssignVariableOp. + +// expected-error@+1 {{expects users of resource argument 0 to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp'}} +func @main(%arg0: tensor>>) -> tensor { + %0 = "tf.VarIsInitializedOp"(%arg0) : (tensor>>) -> tensor + return %0 : tensor +} + +// ----- + +// Tests VarHandleOp has users that are not removed. 
+ +func @main() -> tensor { + // expected-error@+1 {{expects no uses}} + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.VarIsInitializedOp"(%0) : (tensor>>) -> tensor + return %1 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index d3cc508a490..97661bb204b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -25,14 +25,19 @@ limitations under the License. // . Dead functions have already been removed, as resource arguments in dead // functions can cause the pass to fail. -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/FormatVariadic.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -41,6 +46,11 @@ namespace mlir { namespace TF { namespace { +constexpr char kResourceFunctionMsg[] = + "expects function level resource argument"; +constexpr char kInvalidResourceMsg[] = + "expects resource to be a VarHandleOp or function argument"; + // Records the input argument index and the current live value for a resource // variable. struct ResourceInfo { @@ -48,135 +58,171 @@ struct ResourceInfo { Value live_value; }; -using ResourceMap = llvm::SmallDenseMap; +using ArgOrName = llvm::PointerUnion; +using ResourceMap = llvm::SmallDenseMap; LogicalResult VerifyNoPotentialNestedResourceAccesses(ModuleOp module) { - LogicalResult result = success(); - module.walk([&](FuncOp func) { - for (auto type : func.getType().getInputs()) { - if (getElementTypeOrSelf(type).isa()) { - result = - func.emitError("potential nested resource accesses in function"); - break; - } - } + auto result = module.walk([&](FuncOp func) -> WalkResult { + // Skip main function as resources can be passed in as arguments. + if (func.getName() == "main") return WalkResult::advance(); + + for (auto type : func.getType().getInputs()) + if (getElementTypeOrSelf(type).isa()) + return func.emitError("potential nested resource accesses in function"); + + return WalkResult::advance(); }); - return result; + return failure(result.wasInterrupted()); } LogicalResult PromoteResourcesToArguments(FuncOp function) { - // This routine should only be called when control flow operations are still - // represented with TF IfOp and WhileOp operations. In this case, there should - // be only one basic blocks in the MLIR representation. 
- if (!has_single_element(function.getBlocks())) { - return function.emitError() - << "expect the function to have 1 block while it has " - << function.getBlocks().size(); - } + Block& block = function.front(); + + auto return_op = llvm::dyn_cast_or_null(block.getTerminator()); + if (!return_op) + return function.emitError( + "expects 'main' function to have a MLIR ReturnOp"); ResourceMap resource_map; - std::vector new_input_types = function.getType().getInputs().vec(); - int64_t input_num = function.getNumArguments(); + auto argument_types = llvm::to_vector<4>(function.getType().getInputs()); + + // Loop through the resource arguments in the function and store a mapping + // from that argument to its index and itself as the current live value. + for (BlockArgument& func_arg : function.getArguments()) { + auto resource_type = + getElementTypeOrSelf(func_arg.getType()).dyn_cast(); + if (!resource_type) continue; + if (resource_type.getSubtypes().size() != 1) + return function.emitError() + << "expects resource type of argument " << func_arg.getArgNumber() + << " to have one subtype, got " << resource_type; + + for (auto* user : func_arg.getUsers()) + if (!llvm::isa(user) && + !llvm::isa(user)) + return function.emitError() + << "expects users of resource argument " + << func_arg.getArgNumber() + << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp'"; + + Type arg_type = resource_type.getSubtypes().front(); + func_arg.setType(arg_type); + resource_map[func_arg] = {func_arg.getArgNumber(), func_arg}; + argument_types[func_arg.getArgNumber()] = arg_type; + } // Loop through the VarHandleOp in the function. When the first VarHandleOp // for a resource variable is encountered, create a new function argument and // add an entry to the resource_map to record the information. - for (auto var_handle_op : function.front().getOps()) { - if (resource_map.count(var_handle_op.shared_name())) { - continue; - } + for (auto var_handle_op : block.getOps()) { + if (resource_map.count(var_handle_op.shared_nameAttr())) continue; auto resource_type = getElementTypeOrSelf(var_handle_op.getType()).cast(); - if (!resource_type || resource_type.getSubtypes().size() != 1) { - return var_handle_op.emitError("unrecognized resource type"); - } + if (resource_type.getSubtypes().size() != 1) + return var_handle_op.emitOpError() + << "expects resource type to have one subtype, got " + << resource_type; + Type arg_type = resource_type.getSubtypes().front(); - BlockArgument arg = function.front().addArgument(arg_type); - new_input_types.push_back(arg_type); - resource_map[var_handle_op.shared_name()] = {input_num++, arg}; + BlockArgument arg = block.addArgument(arg_type); + resource_map[var_handle_op.shared_nameAttr()] = { + static_cast(argument_types.size()), arg}; + argument_types.push_back(arg_type); } - if (resource_map.empty()) { - return success(); - } + if (resource_map.empty()) return success(); // We initially assign the argument for a resource as the live value for the // resource. We then walk through the operations in the function in their // lexical order, to update the live value for the resource when we see a // store to the resource and replace reads of the resource with uses of its // live value. 
- for (Operation& op : llvm::make_early_inc_range(function.front())) { + for (Operation& op : llvm::make_early_inc_range(block)) { if (auto read_op = llvm::dyn_cast(&op)) { - auto var_handle_op = - llvm::dyn_cast(read_op.resource().getDefiningOp()); - if (!var_handle_op) { - return read_op.emitError("resource is not VarHandleOp"); + if (auto func_arg = read_op.resource().dyn_cast()) { + if (func_arg.getOwner() != &block) + return read_op.emitOpError(kResourceFunctionMsg); + + read_op.value().replaceAllUsesWith(resource_map[func_arg].live_value); + } else if (auto var_handle_op = llvm::dyn_cast( + read_op.resource().getDefiningOp())) { + read_op.value().replaceAllUsesWith( + resource_map[var_handle_op.shared_nameAttr()].live_value); + } else { + return read_op.emitOpError(kInvalidResourceMsg); } - read_op.value().replaceAllUsesWith( - resource_map[var_handle_op.shared_name()].live_value); + read_op.erase(); } else if (auto write_op = llvm::dyn_cast(&op)) { - auto var_handle_op = - llvm::dyn_cast(write_op.resource().getDefiningOp()); - if (!var_handle_op) { - return write_op.emitError("resource is not VarHandleOp"); + if (auto func_arg = write_op.resource().dyn_cast()) { + if (func_arg.getOwner() != &block) + return write_op.emitOpError(kResourceFunctionMsg); + + resource_map[func_arg].live_value = write_op.value(); + } else if (auto var_handle_op = llvm::dyn_cast( + write_op.resource().getDefiningOp())) { + resource_map[var_handle_op.shared_nameAttr()].live_value = + write_op.value(); + } else { + return read_op.emitOpError(kInvalidResourceMsg); } - resource_map[var_handle_op.shared_name()].live_value = write_op.value(); + write_op.erase(); } } - auto return_op = llvm::dyn_cast(function.front().getTerminator()); - if (!return_op) { - return function.emitError("the function doesn't have an MLIR ReturnOp"); - } + const int64_t num_results_before = function.getNumResults(); + auto return_operands = llvm::to_vector<4>(return_op.getOperands()); + return_operands.reserve(num_results_before + resource_map.size()); + auto result_types = llvm::to_vector<4>(return_op.getOperandTypes()); + result_types.reserve(num_results_before + resource_map.size()); + llvm::SmallVector, 4> input_output_alias; + input_output_alias.reserve(resource_map.size()); - int64_t output_num = return_op.getNumOperands(); - llvm::SmallVector new_return_operands(return_op.getOperands()); - std::vector> input_output_alias; - std::vector new_return_types = function.getType().getResults().vec(); + // Collect new return values and mapping from resource input index to output + // alias. If the last live value is itself (argument), then that live value + // will not be returned as the resource is unmodified. + for (auto& resource : resource_map) { + int64_t input_index = resource.getSecond().input_index; + Value live_value = resource.getSecond().live_value; + auto live_arg = live_value.dyn_cast(); + if (live_arg && live_arg.getOwner() == &block && + live_arg.getArgNumber() == input_index) + continue; - // If the live value of a resource is not an argument, then the resource is - // updated by the function. Add the resource live value to the ReturnOp of the - // function and record the input-output aliasing. 
- for (Operation& op : function.front()) { - if (auto var_handle_op = llvm::dyn_cast(&op)) { - ResourceInfo& resource_info = resource_map[var_handle_op.shared_name()]; - Value live_value = resource_info.live_value; - if (!live_value.isa()) { - new_return_operands.push_back(live_value); - input_output_alias.push_back( - std::make_pair(resource_info.input_index, output_num++)); - new_return_types.push_back(live_value.getType()); - } - } + return_operands.push_back(live_value); + result_types.push_back(live_value.getType()); + input_output_alias.push_back( + {input_index, num_results_before + input_output_alias.size()}); } // Erase all VarHandleOp. for (Operation& op : llvm::make_early_inc_range(function.front())) { - if (llvm::isa(&op)) { - op.erase(); - } + auto var_handle_op = llvm::dyn_cast(op); + if (!var_handle_op) continue; + if (!var_handle_op.use_empty()) + return var_handle_op.emitOpError() << "expects no uses"; + + op.erase(); } + // Rewrite return if more results need to be returned by the function. OpBuilder builder(return_op); - function.setType(builder.getFunctionType(new_input_types, new_return_types)); - - if (input_output_alias.empty()) { - return success(); + if (!input_output_alias.empty()) { + builder.create(return_op.getLoc(), return_operands); + return_op.erase(); } - builder.create(return_op.getLoc(), new_return_operands); - return_op.erase(); + // Update function argument and result types with new resource subtypes. + function.setType(builder.getFunctionType(argument_types, result_types)); // Add aliasing_output attribute to the input argument for the resources that // are updated by the function. - for (auto input_output : input_output_alias) { + for (auto& input_output : input_output_alias) function.setArgAttr(input_output.first, "tf.aliasing_output", builder.getI64IntegerAttr(input_output.second)); - } return success(); } @@ -190,15 +236,21 @@ class PromoteResourcesToArgsPass void PromoteResourcesToArgsPass::runOnModule() { ModuleOp module = getModule(); FuncOp main_func = module.lookupSymbol("main"); - if (!main_func) { - return; + if (!main_func) return; + + // This routine should only be called when control flow operations are still + // represented with TF IfOp and WhileOp operations. In this case, there should + // be only one basic blocks in the MLIR representation. + if (!has_single_element(main_func.getBlocks())) { + main_func.emitError() << "expects 'main' function to have 1 block, got " + << main_func.getBlocks().size(); + return signalPassFailure(); } if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) || failed(VerifyNoPotentialNestedResourceAccesses(module)) || - failed(PromoteResourcesToArguments(main_func))) { + failed(PromoteResourcesToArguments(main_func))) return signalPassFailure(); - } } } // namespace From 8df7014814bd563adca85ea23be15dc6e0e743d4 Mon Sep 17 00:00:00 2001 From: Ayush Dubey Date: Tue, 17 Mar 2020 16:30:18 -0700 Subject: [PATCH 109/492] Add a test for changing input shape of a collective instance across continuous runs. cl/301471698 changed the behavior of `enable_collective_ops` so that it invokes `GrpcServer::UpdateServerDef`, which in turn creates a new `CollectiveParamResolverLocal` instance. This change adds a test that verifies that because there is a new param resolver, we can reuse collective instance keys with a shape that is different from the previous run. 
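
As a rough NumPy stand-in for what the multi-worker tests in this file assert
(assumed illustration only: np.mean replaces the actual collective MEAN
all-reduce, and the changing shape mirrors the new continuous-run scenario):

    import numpy as np

    num_workers = 5
    for num_dims in range(1, 4):
        shape = [2] * num_dims                     # a different shape each run
        per_worker = [np.ones(shape) * w for w in range(num_workers)]
        mean = np.mean(per_worker, axis=0)         # stand-in for the all-reduce
        # Worker values are 0, 1, ..., num_workers - 1, so the mean is 2.0.
        assert np.allclose(mean, np.ones(shape) * (num_workers - 1) / 2)
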
PiperOrigin-RevId: 301477942 Change-Id: I4e46abceb0b01e3583f5491f3c18b8face88d827 --- .../multi_worker_continuous_run_test.py | 70 ++++++++++++++----- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/tensorflow/python/distribute/multi_worker_continuous_run_test.py b/tensorflow/python/distribute/multi_worker_continuous_run_test.py index df30c3f6e3f..90484a12423 100644 --- a/tensorflow/python/distribute/multi_worker_continuous_run_test.py +++ b/tensorflow/python/distribute/multi_worker_continuous_run_test.py @@ -35,6 +35,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import config from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope NUM_WORKERS = 5 @@ -42,43 +43,78 @@ NUM_WORKERS = 5 # TODO(b/143286947): expand the test to cover fault tolerance and elasticity class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): + def setUp(self): + self._gpus = config.list_physical_devices('GPU') + self._local_device = '/device:GPU:0' if self._gpus else '/device:CPU:0' + super(MultiWorkerContinuousRunTest, self).setUp() + + def _maybe_setup_gpus(self): + if self._gpus: + # Set virtual GPU with memory limit of 64MB so that multiple worker + # processes can share the physical GPU + config.set_logical_device_configuration( + self._gpus[0], [context.LogicalDeviceConfiguration(64)]) + @combinations.generate(combinations.combine(mode=['eager'])) def testAllReduceContinuousRun(self, mode): tensor_shape = [2, 2] - local_device = '/device:CPU:0' - if config.list_physical_devices('GPU'): - local_device = '/device:GPU:0' - def worker_step_fn(): + def worker_step_fn(worker_id): strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy() # Make sure the processeses are in sync after updating the cluster multi_process_runner.barrier().wait() - tf_config = json.loads(os.environ['TF_CONFIG']) - worker_id = tf_config['task']['index'] - @def_function.function def run_reduce(): - with ops.device(local_device): + with ops.device(self._local_device): t_in = array_ops.ones(tensor_shape) * worker_id return strategy.reduce(reduce_util.ReduceOp.MEAN, t_in, axis=None) t_out = run_reduce() # Element values from the workers are - # 0, 1, ..., (num_workers - 1) + # 0, 1, ..., (NUM_WORKERS - 1) expected_mean = (NUM_WORKERS - 1) / 2 expected_out = np.ones(tensor_shape) * expected_mean self.assertAllClose(t_out, expected_out) def worker_fn(): - gpus = config.list_physical_devices('GPU') - if gpus: - # Set virtual GPU with memory limit of 64MB so that multiple worker - # processes can share the physical GPU - config.set_logical_device_configuration( - gpus[0], [context.LogicalDeviceConfiguration(64)]) - for _ in range(100): - worker_step_fn() + self._maybe_setup_gpus() + tf_config = json.loads(os.environ['TF_CONFIG']) + worker_id = tf_config['task']['index'] + for _ in range(20): + worker_step_fn(worker_id) + + multi_process_runner.run( + worker_fn, + cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS)) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testVariableInitializationWithChangingShape(self, mode): + + def worker_step_fn(worker_id, num_dims): + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy() + # Make sure the processeses are in sync after updating the cluster + multi_process_runner.barrier().wait() + tensor_shape = [2] * num_dims + + def variable_fn(): + with ops.device(self._local_device): + # The initial 
value will be broadcasted from worker 0 to others. + initial_value = (array_ops.ones(tensor_shape) if worker_id == 0 else + array_ops.zeros(tensor_shape)) + var = variable_scope.get_variable(name='x', initializer=initial_value) + return array_ops.identity(var) + + t_out = strategy.extended.call_for_each_replica(variable_fn) + expected_out = np.ones(tensor_shape) + self.assertAllClose(t_out, expected_out) + + def worker_fn(): + self._maybe_setup_gpus() + tf_config = json.loads(os.environ['TF_CONFIG']) + worker_id = tf_config['task']['index'] + for i in range(20): + worker_step_fn(worker_id, num_dims=(i + 1)) multi_process_runner.run( worker_fn, From 151a62f3d891d071cbbad7faeb477c2233ec1f31 Mon Sep 17 00:00:00 2001 From: Anna R Date: Tue, 17 Mar 2020 16:36:04 -0700 Subject: [PATCH 110/492] Remove underscore in front of _list_to_tuple.py in nest namespace. There is no need to have both nest.list_to_tuple and nest._list_to_tuple anymore. PiperOrigin-RevId: 301479033 Change-Id: I894c0d07dab768f6bf7d2a13bc96907d1c1c6ec5 --- tensorflow/python/keras/engine/compile_utils.py | 10 ++++------ tensorflow/python/keras/engine/data_adapter.py | 4 ++-- tensorflow/python/util/nest.py | 6 +----- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/tensorflow/python/keras/engine/compile_utils.py b/tensorflow/python/keras/engine/compile_utils.py index 82aed487b85..fd792e0ee8c 100644 --- a/tensorflow/python/keras/engine/compile_utils.py +++ b/tensorflow/python/keras/engine/compile_utils.py @@ -304,12 +304,10 @@ class MetricsContainer(Container): self._weighted_metrics) # Standardize on tuple since `tf.data` turns lists into `Tensor`s. - # pylint: disable=protected-access - y_pred = nest._list_to_tuple(y_pred) - y_true = nest._list_to_tuple(y_true) - self._metrics = nest._list_to_tuple(self._metrics) - self._weighted_metrics = nest._list_to_tuple(self._weighted_metrics) - # pylint: enable=protected-access + y_pred = nest.list_to_tuple(y_pred) + y_true = nest.list_to_tuple(y_true) + self._metrics = nest.list_to_tuple(self._metrics) + self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics) # Convert to `Metric` objects, potentially disambiguating based on output # properties. diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 5f1e8e2de64..43eefd75c36 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -830,7 +830,7 @@ class GeneratorDataAdapter(DataAdapter): x, y, sample_weight = unpack_x_y_sample_weight(data) data = pack_x_y_sample_weight(x, y, sample_weight) - data = nest._list_to_tuple(data) # pylint: disable=protected-access + data = nest.list_to_tuple(data) def _convert_dtype(t): if (isinstance(t, np.ndarray) and issubclass(t.dtype.type, np.floating)): @@ -1023,7 +1023,7 @@ def _process_tensorlike(inputs): return x inputs = nest.map_structure(_convert_numpy_and_scipy, inputs) - return nest._list_to_tuple(inputs) # pylint: disable=protected-access + return nest.list_to_tuple(inputs) def is_none_or_empty(inputs): diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index f69179d80dd..d215fb632b3 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -1375,7 +1375,7 @@ def flatten_with_tuple_paths(structure, expand_composites=False): flatten(structure, expand_composites=expand_composites))) -def _list_to_tuple(structure): +def list_to_tuple(structure): """Replace all lists with tuples. 
The fork of nest that tf.data uses treats lists as single elements, while @@ -1398,10 +1398,6 @@ def _list_to_tuple(structure): sequence_fn=sequence_fn) -# TODO(b/143287251): Only have `list_to_tuple` -list_to_tuple = _list_to_tuple - - _pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping) _pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping) _pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence) From 1eb5b2c61f81a6cd95823b7e81fc81b98cd15460 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 17:47:41 -0700 Subject: [PATCH 111/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301491007 Change-Id: Ic6a724db9937a9433e760de7c657a52cfe742515 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 52a9bf9551b..6456f104ad3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From 248bc00ab1bae639a49a462c539d5cac03ca762b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 17 Mar 2020 17:55:54 -0700 Subject: [PATCH 112/492] [TF:MLIR] Enable MLIR graph optimizations via tf.options Add MLIR based layout optimization to MLIR graph optimization pass. 
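With this change, the optimizations can be toggled from Python through the new experimental config API (a minimal usage sketch; the surrounding model code is omitted):

    import tensorflow as tf

    # Opt in to the MLIR-based graph optimization pipeline (dev/testing only).
    tf.config.experimental.enable_mlir_graph_optimization()

    # ... build and run tf.functions as usual; their graphs now go through the
    # MLIR pass pipeline (island coarsening, shape inference, layout
    # optimization) added in this change ...

    # Opt back out.
    tf.config.experimental.disable_mlir_graph_optimization()
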
PiperOrigin-RevId: 301492207 Change-Id: I205bb08e3d071839f785be81250e992e59c6d129 --- .../transforms/graph_optimization_pass.cc | 42 ++++++++++++++++--- .../transforms/graph_optimization_pass.h | 13 +++--- .../graph_optimization_pass_registration.cc | 2 +- .../transforms/layout_optimization.cc | 6 +-- tensorflow/python/eager/context.py | 12 ++++++ tensorflow/python/framework/config.py | 23 ++++++++++ tensorflow/python/framework/config_test.py | 16 +++++++ .../v1/tensorflow.config.experimental.pbtxt | 8 ++++ .../v2/tensorflow.config.experimental.pbtxt | 8 ++++ 9 files changed, 116 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index 281a6011af6..c563a98d8c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -15,19 +15,51 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h" -namespace tensorflow { +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" + +namespace mlir { +namespace TF { +namespace { +using Status = ::tensorflow::Status; +using ConfigProto = ::tensorflow::ConfigProto; +} // namespace Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto, - mlir::ModuleOp module) { + ModuleOp module) { if (!config_proto.experimental().enable_mlir_graph_optimization()) { VLOG(1) << "Skipping MLIR Graph Optimization Pass" << ", session flag not enabled"; return Status::OK(); } - // TODO(ezhulenev): Add something here. + VLOG(1) << "Run MLIR Graph Optimization Passes"; + PassManager pm(module.getContext()); - return Status::OK(); + // Run island coarsening before shape inference to allow more exact shape + // inference using constant folding within islands. + pm.addNestedPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); + pm.addPass(CreateTFShapeInferencePass()); + + // Assign optimal data layout to layout sensitive operations and delete + // redundant transposes from the IR. + LayoutOptimizationPipelineOptions layout_optimization_options; + CreateLayoutOptimizationPipeline(pm, layout_optimization_options); + + // Prepare IR for exporting. + pm.addNestedPass(CreateBreakUpIslandsPass()); + + // In case of failure, the `diag_handler` converts MLIR errors emitted to the + // MLIRContext into a tensorflow::Status. + StatusScopedDiagnosticHandler diag_handler(module.getContext()); + LogicalResult result = pm.run(module); + (void)result; + return diag_handler.ConsumeStatus(); } -} // namespace tensorflow +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h index 955da470494..5bab0ffab7e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h @@ -18,21 +18,24 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" -namespace tensorflow { +namespace mlir { +namespace TF { // Bundle generic MLIR graph optimization passes (some derived from TF Grappler // graph optimizers) into a single MLIR optimization pass. -class MlirGraphOptimizationPass : public MlirOptimizationPass { +class MlirGraphOptimizationPass : public ::tensorflow::MlirOptimizationPass { public: llvm::StringRef name() const override { return "graph_optimization"; } - bool IsEnabled(const ConfigProto& config_proto) const override { + bool IsEnabled(const ::tensorflow::ConfigProto& config_proto) const override { return config_proto.experimental().enable_mlir_graph_optimization(); } - Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override; + ::tensorflow::Status Run(const ::tensorflow::ConfigProto& config_proto, + ModuleOp module) override; }; -} // namespace tensorflow +} // namespace TF +} // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc index 4681f8a0f33..ba72f0a0966 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc @@ -25,6 +25,6 @@ constexpr int kMlirGraphOptimizationPriority = 0; static mlir_pass_registration::MlirOptimizationPassRegistration register_mlir_graph_optimization_pass( kMlirGraphOptimizationPriority, - std::make_unique()); + std::make_unique()); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index 7d65d16e42d..cb84be5748c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -95,7 +95,9 @@ void LayoutAssignmentPass::runOnFunction() { // Get runtime devices information from the closest parent module. RuntimeDevices devices; - ::tensorflow::GetDevicesFromOp(func.getParentOfType(), &devices); + if (failed(::tensorflow::GetDevicesFromOp(func.getParentOfType(), + &devices))) + return signalPassFailure(); // If there is no runtime device information and data format is not explicitly // forced, there is nothing to do. @@ -419,8 +421,6 @@ void CreateLayoutOptimizationPipeline( const LayoutOptimizationPipelineOptions& options) { using Direction = MoveTransposesPass::Direction; - if (options.force_data_format.empty()) return; - // Assign optimal layout for layout sensitive ops. 
pm.addPass(std::make_unique(options.force_data_format)); diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index a36d2142329..ab2e18ed99d 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -430,6 +430,7 @@ class Context(object): self._soft_device_placement = None self._log_device_placement = None self._enable_mlir_bridge = None + self._enable_mlir_graph_optimization = None self._optimizer_experimental_options = {} _python_eager_context_create_counter.get_cell().increase_by(1) @@ -908,6 +909,9 @@ class Context(object): if self._enable_mlir_bridge is not None: config.experimental.enable_mlir_bridge = self._enable_mlir_bridge + if self._enable_mlir_graph_optimization is not None: + config.experimental.enable_mlir_graph_optimization = ( + self._enable_mlir_graph_optimization) def rewriter_toggle(option): toggle = self._optimizer_experimental_options.get(option, None) @@ -1376,10 +1380,18 @@ class Context(object): def enable_mlir_bridge(self): return self._enable_mlir_bridge + @property + def enable_mlir_graph_optimization(self): + return self._enable_mlir_graph_optimization + @enable_mlir_bridge.setter def enable_mlir_bridge(self, enabled): self._enable_mlir_bridge = enabled + self._thread_local_data.function_call_options = None + @enable_mlir_graph_optimization.setter + def enable_mlir_graph_optimization(self, enabled): + self._enable_mlir_graph_optimization = enabled self._thread_local_data.function_call_options = None @property diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index c696675fed8..5361d7290e8 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -626,7 +626,30 @@ def enable_mlir_bridge(): context.context().enable_mlir_bridge = True +@tf_export('config.experimental.enable_mlir_graph_optimization') +def enable_mlir_graph_optimization(): + """Enables experimental MLIR-Based TensorFlow Compiler Optimizations. + + DO NOT USE, DEV AND TESTING ONLY AT THE MOMENT. + + NOTE: MLIR-Based TensorFlow Compiler is under active development and has + missing features, please refrain from using. This API exists for development + and testing only. + + TensorFlow Compiler Optimizations are responsible general graph level + optimizations that in the current stack mostly done by Grappler graph + optimizers. + """ + context.context().enable_mlir_graph_optimization = True + + @tf_export('config.experimental.disable_mlir_bridge') def disable_mlir_bridge(): """Disables experimental MLIR-Based TensorFlow Compiler Bridge.""" context.context().enable_mlir_bridge = False + + +@tf_export('config.experimental.disable_mlir_graph_optimization') +def disable_mlir_graph_optimization(): + """Disables experimental MLIR-Based TensorFlow Compiler Optimizations.""" + context.context().enable_mlir_graph_optimization = False diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index 2ef7d737d73..b07bb874385 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -223,6 +223,22 @@ class ConfigTest(test.TestCase, parameterized.TestCase): config.disable_mlir_bridge() self.assertFalse(context.context().config.experimental.enable_mlir_bridge) + @reset_eager + def testEnableMlirGraphOptimization(self): + # Default value of enable_mlir_graph_optimization is false. 
+ self.assertFalse( + context.context().config.experimental.enable_mlir_graph_optimization) + + # Tests enabling mlir graph optimization. + config.enable_mlir_graph_optimization() + self.assertTrue( + context.context().config.experimental.enable_mlir_graph_optimization) + + # Tests disabling mlir graph optimization. + config.disable_mlir_graph_optimization() + self.assertFalse( + context.context().config.experimental.enable_mlir_graph_optimization) + @test_util.run_gpu_only @reset_eager def testJit(self): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt index b8f92b30099..0f3558e844e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt @@ -12,10 +12,18 @@ tf_module { name: "disable_mlir_bridge" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "disable_mlir_graph_optimization" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "enable_mlir_bridge" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "enable_mlir_graph_optimization" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_device_policy" argspec: "args=[], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt index b8f92b30099..0f3558e844e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt @@ -12,10 +12,18 @@ tf_module { name: "disable_mlir_bridge" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "disable_mlir_graph_optimization" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "enable_mlir_bridge" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "enable_mlir_graph_optimization" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_device_policy" argspec: "args=[], varargs=None, keywords=None, defaults=None" From 3404ca7b5da431d95089e96cb6a749f97cb9de10 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Tue, 17 Mar 2020 17:58:35 -0700 Subject: [PATCH 113/492] Creates `framework_lib` target. PiperOrigin-RevId: 301492628 Change-Id: I5f819766e3a1437cfd1ffa28acb57958900e7e39 --- tensorflow/lite/BUILD | 66 ++++++++++++++++++++++++++--------- tensorflow/lite/kernels/BUILD | 4 +-- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 5e22b1fed5c..9c4740b8c0a 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -62,6 +62,22 @@ TFLITE_DEFAULT_COPTS = if_not_windows([ "-Wno-extern-c-compat", ]) +FRAMEWORK_LIB_HDRS = [ + "allocation.h", + "context.h", + "context_util.h", + "core/macros.h", + "core/subgraph.h", + "error_reporter.h", + "graph_info.h", + "interpreter.h", + "model.h", + "mutable_op_resolver.h", + "op_resolver.h", + "optional_debug_tools.h", + "stderr_reporter.h", +] + cc_library( name = "version", hdrs = ["version.h"], @@ -200,9 +216,8 @@ cc_library( ], ) -# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. 
cc_library( - name = "framework", + name = "framework_lib", srcs = [ "core/subgraph.cc", "graph_info.cc", @@ -212,23 +227,42 @@ cc_library( "optional_debug_tools.cc", "stderr_reporter.cc", ], - hdrs = [ - "allocation.h", - "context.h", - "context_util.h", - "core/macros.h", - "core/subgraph.h", - "error_reporter.h", - "graph_info.h", - "interpreter.h", - "model.h", - "mutable_op_resolver.h", - "op_resolver.h", - "optional_debug_tools.h", - "stderr_reporter.h", + hdrs = FRAMEWORK_LIB_HDRS, + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//tensorflow/lite:__subpackages__", ], + deps = [ + ":allocation", + ":arena_planner", + ":external_cpu_backend_context", + ":graph_info", + ":memory_planner", + ":minimal_logging", + ":simple_memory_arena", + ":string", + ":type_to_tflitetype", + ":util", + ":version", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + "//tensorflow/lite/experimental/resource", + "//tensorflow/lite/nnapi:nnapi_implementation", + "//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) + +# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. +cc_library( + name = "framework", + srcs = [ + ], + hdrs = FRAMEWORK_LIB_HDRS, copts = tflite_copts() + TFLITE_DEFAULT_COPTS, deps = [ + ":framework_lib", ":allocation", ":arena_planner", ":external_cpu_backend_context", diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 57e9b876ec1..1f04cc3ee47 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -526,7 +526,7 @@ cc_library( ":lstm_shared", ":op_macros", ":padding", - "//tensorflow/lite:framework", + "//tensorflow/lite:framework_lib", "//tensorflow/lite:minimal_logging", "//tensorflow/lite:string_util", "//tensorflow/lite/c:common", @@ -660,7 +660,7 @@ cc_library( ], deps = [ ":builtin_op_kernels", - "//tensorflow/lite:framework", + "//tensorflow/lite:framework_lib", "//tensorflow/lite/c:common", ], ) From cfc31e324c8de6b52f752a39cb161d99d853ca99 Mon Sep 17 00:00:00 2001 From: Marat Dukhan Date: Tue, 17 Mar 2020 18:11:52 -0700 Subject: [PATCH 114/492] Update XNNPACK and cpuinfo dependencies PiperOrigin-RevId: 301494671 Change-Id: I7444d1e1c0562994de775d171aae30f352259831 --- tensorflow/workspace.bzl | 8 +- third_party/cpuinfo/BUILD.bazel | 1 - third_party/cpuinfo/cpuinfo.patch | 3016 +++++++++++++++++++++++++++++ third_party/cpuinfo/workspace.bzl | 14 +- 4 files changed, 3030 insertions(+), 9 deletions(-) create mode 100644 third_party/cpuinfo/cpuinfo.patch diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 2ff181e1b25..1277a72416f 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -148,11 +148,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "XNNPACK", - sha256 = "100f675c099c74da46dea8da025f6f9b5e0307370f3dde506d11bd78b2b7d171", - strip_prefix = "XNNPACK-4ea95bef8cdd942895f23f5cc09c778d10500551", + sha256 = "190e61e50af3497bb46b8d936bd2d2d551a9aeedb02ff66388918408a54e216a", + strip_prefix = "XNNPACK-b18783570f0643560be641b193367d3906955141", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/4ea95bef8cdd942895f23f5cc09c778d10500551.zip", - "https://github.com/google/XNNPACK/archive/4ea95bef8cdd942895f23f5cc09c778d10500551.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/b18783570f0643560be641b193367d3906955141.zip", + 
"https://github.com/google/XNNPACK/archive/b18783570f0643560be641b193367d3906955141.zip", ], ) diff --git a/third_party/cpuinfo/BUILD.bazel b/third_party/cpuinfo/BUILD.bazel index cea88aafbd9..afa0b9798a5 100644 --- a/third_party/cpuinfo/BUILD.bazel +++ b/third_party/cpuinfo/BUILD.bazel @@ -42,7 +42,6 @@ ARM_SRCS = [ # Platform-specific sources and headers LINUX_SRCS = [ "src/linux/cpulist.c", - "src/linux/current.c", "src/linux/multiline.c", "src/linux/processors.c", "src/linux/smallfile.c", diff --git a/third_party/cpuinfo/cpuinfo.patch b/third_party/cpuinfo/cpuinfo.patch new file mode 100644 index 00000000000..a9fa0dde0eb --- /dev/null +++ b/third_party/cpuinfo/cpuinfo.patch @@ -0,0 +1,3016 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index de319ef..fefb60b 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -179,7 +179,6 @@ IF(CPUINFO_SUPPORTED_PLATFORM) + LIST(APPEND CPUINFO_SRCS + src/linux/smallfile.c + src/linux/multiline.c +- src/linux/current.c + src/linux/cpulist.c + src/linux/processors.c) + ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +diff --git a/CMakeLists.txt.orig b/CMakeLists.txt.orig +deleted file mode 100644 +index a71aede..0000000 +--- a/CMakeLists.txt.orig ++++ /dev/null +@@ -1,819 +0,0 @@ +-CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) +- +-INCLUDE(GNUInstallDirs) +- +-# ---[ Project and semantic versioning. +-PROJECT(cpuinfo C CXX) +- +-# ---[ Options. +-SET(CPUINFO_LIBRARY_TYPE "default" CACHE STRING "Type of cpuinfo library (shared, static, or default) to build") +-SET_PROPERTY(CACHE CPUINFO_LIBRARY_TYPE PROPERTY STRINGS default static shared) +-SET(CPUINFO_RUNTIME_TYPE "default" CACHE STRING "Type of runtime library (shared, static, or default) to use") +-SET_PROPERTY(CACHE CPUINFO_RUNTIME_TYPE PROPERTY STRINGS default static shared) +-SET(CPUINFO_LOG_LEVEL "default" CACHE STRING "Minimum logging level (info with lower severity will be ignored)") +-SET_PROPERTY(CACHE CPUINFO_LOG_LEVEL PROPERTY STRINGS default debug info warning error fatal none) +-OPTION(CPUINFO_BUILD_TOOLS "Build command-line tools" ON) +-OPTION(CPUINFO_BUILD_UNIT_TESTS "Build cpuinfo unit tests" ON) +-OPTION(CPUINFO_BUILD_MOCK_TESTS "Build cpuinfo mock tests" ON) +-OPTION(CPUINFO_BUILD_BENCHMARKS "Build cpuinfo micro-benchmarks" ON) +- +-# ---[ CMake options +-IF(CPUINFO_BUILD_UNIT_TESTS OR CPUINFO_BUILD_MOCK_TESTS) +- ENABLE_TESTING() +-ENDIF() +- +-MACRO(CPUINFO_TARGET_ENABLE_C99 target) +- IF(${CMAKE_VERSION} VERSION_LESS "3.1") +- IF(NOT MSVC) +- TARGET_COMPILE_OPTIONS(${target} PRIVATE -std=c99) +- ENDIF() +- ELSE() +- SET_TARGET_PROPERTIES(${target} PROPERTIES +- C_STANDARD 99 +- C_EXTENSIONS NO) +- ENDIF() +-ENDMACRO() +- +-MACRO(CPUINFO_TARGET_ENABLE_CXX11 target) +- IF(${CMAKE_VERSION} VERSION_LESS "3.1") +- IF(NOT MSVC) +- TARGET_COMPILE_OPTIONS(${target} PRIVATE -std=c++11) +- ENDIF() +- ELSE() +- SET_TARGET_PROPERTIES(${target} PROPERTIES +- CXX_STANDARD 11 +- CXX_EXTENSIONS NO) +- ENDIF() +-ENDMACRO() +- +-MACRO(CPUINFO_TARGET_RUNTIME_LIBRARY target) +- IF(MSVC AND NOT CPUINFO_RUNTIME_TYPE STREQUAL "default") +- IF(CPUINFO_RUNTIME_TYPE STREQUAL "shared") +- TARGET_COMPILE_OPTIONS(${target} PRIVATE +- "/MD$<$:d>") +- ELSEIF(CPUINFO_RUNTIME_TYPE STREQUAL "static") +- TARGET_COMPILE_OPTIONS(${target} PRIVATE +- "/MT$<$:d>") +- ENDIF() +- ENDIF() +-ENDMACRO() +- +-# ---[ Build flags +-SET(CPUINFO_SUPPORTED_PLATFORM TRUE) +-IF(NOT CMAKE_SYSTEM_PROCESSOR) +- IF(NOT IOS) +- MESSAGE(WARNING +- "Target processor architecture is not specified. 
" +- "cpuinfo will compile, but cpuinfo_initialize() will always fail.") +- SET(CPUINFO_SUPPORTED_PLATFORM FALSE) +- ENDIF() +-ELSEIF(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64)$") +- MESSAGE(WARNING +- "Target processor architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in cpuinfo. " +- "cpuinfo will compile, but cpuinfo_initialize() will always fail.") +- SET(CPUINFO_SUPPORTED_PLATFORM FALSE) +-ENDIF() +- +-IF(NOT CMAKE_SYSTEM_NAME) +- MESSAGE(WARNING +- "Target operating system is not specified. " +- "cpuinfo will compile, but cpuinfo_initialize() will always fail.") +- SET(CPUINFO_SUPPORTED_PLATFORM FALSE) +-ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Windows|Darwin|Linux|Android)$") +- IF(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.14" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") +- MESSAGE(WARNING +- "Target operating system \"${CMAKE_SYSTEM_NAME}\" is not supported in cpuinfo. " +- "cpuinfo will compile, but cpuinfo_initialize() will always fail.") +- SET(CPUINFO_SUPPORTED_PLATFORM FALSE) +- ENDIF() +-ENDIF() +- +-# ---[ Download deps +-SET(CONFU_DEPENDENCIES_SOURCE_DIR ${CMAKE_SOURCE_DIR}/deps +- CACHE PATH "Confu-style dependencies source directory") +-SET(CONFU_DEPENDENCIES_BINARY_DIR ${CMAKE_BINARY_DIR}/deps +- CACHE PATH "Confu-style dependencies binary directory") +- +-IF(CPUINFO_BUILD_MOCK_TESTS OR CPUINFO_BUILD_UNIT_TESTS) +- IF(CPUINFO_SUPPORTED_PLATFORM AND NOT DEFINED GOOGLETEST_SOURCE_DIR) +- MESSAGE(STATUS "Downloading Google Test to ${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest (define GOOGLETEST_SOURCE_DIR to avoid it)") +- CONFIGURE_FILE(cmake/DownloadGoogleTest.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download/CMakeLists.txt") +- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . +- WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download") +- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . +- WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download") +- SET(GOOGLETEST_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest" CACHE STRING "Google Test source directory") +- ENDIF() +-ENDIF() +- +-IF(CPUINFO_BUILD_BENCHMARKS) +- IF(CPUINFO_SUPPORTED_PLATFORM AND NOT DEFINED GOOGLEBENCHMARK_SOURCE_DIR) +- MESSAGE(STATUS "Downloading Google Benchmark to ${CONFU_DEPENDENCIES_SOURCE_DIR}/googlebenchmark (define GOOGLEBENCHMARK_SOURCE_DIR to avoid it)") +- CONFIGURE_FILE(cmake/DownloadGoogleBenchmark.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download/CMakeLists.txt") +- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . +- WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download") +- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . 
+- WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download") +- SET(GOOGLEBENCHMARK_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googlebenchmark" CACHE STRING "Google Benchmark source directory") +- ENDIF() +-ENDIF() +- +-# ---[ cpuinfo library +-SET(CPUINFO_SRCS +- src/init.c +- src/api.c) +- +-IF(CPUINFO_SUPPORTED_PLATFORM) +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?)$" OR IOS_ARCH MATCHES "^(i386|x86_64)$") +- LIST(APPEND CPUINFO_SRCS +- src/x86/init.c +- src/x86/info.c +- src/x86/vendor.c +- src/x86/uarch.c +- src/x86/name.c +- src/x86/topology.c +- src/x86/isa.c +- src/x86/cache/init.c +- src/x86/cache/descriptor.c +- src/x86/cache/deterministic.c) +- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") +- LIST(APPEND CPUINFO_SRCS +- src/x86/linux/init.c +- src/x86/linux/cpuinfo.c) +- ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +- LIST(APPEND CPUINFO_SRCS src/x86/mach/init.c) +- ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Windows") +- LIST(APPEND CPUINFO_SRCS src/x86/windows/init.c) +- ENDIF() +- ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$" OR IOS_ARCH MATCHES "^(armv7.*|arm64.*)$") +- LIST(APPEND CPUINFO_SRCS +- src/arm/uarch.c +- src/arm/cache.c) +- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") +- LIST(APPEND CPUINFO_SRCS +- src/arm/linux/init.c +- src/arm/linux/cpuinfo.c +- src/arm/linux/clusters.c +- src/arm/linux/chipset.c +- src/arm/linux/midr.c +- src/arm/linux/hwcap.c) +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") +- LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch32-isa.c) +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND ANDROID_ABI STREQUAL "armeabi") +- SET_SOURCE_FILES_PROPERTIES(src/arm/linux/aarch32-isa.c PROPERTIES COMPILE_FLAGS -marm) +- ENDIF() +- ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") +- LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch64-isa.c) +- ENDIF() +- ELSEIF(IOS) +- LIST(APPEND CPUINFO_SRCS src/arm/mach/init.c) +- ENDIF() +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android") +- LIST(APPEND CPUINFO_SRCS +- src/arm/android/properties.c) +- ENDIF() +- ENDIF() +- +- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") +- LIST(APPEND CPUINFO_SRCS +- src/linux/smallfile.c +- src/linux/multiline.c +- src/linux/current.c +- src/linux/cpulist.c +- src/linux/processors.c) +- ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +- LIST(APPEND CPUINFO_SRCS src/mach/topology.c) +- ENDIF() +- +- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") +- SET(CMAKE_THREAD_PREFER_PTHREAD TRUE) +- SET(THREADS_PREFER_PTHREAD_FLAG TRUE) +- FIND_PACKAGE(Threads REQUIRED) +- ENDIF() +-ENDIF() +- +-IF(CPUINFO_LIBRARY_TYPE STREQUAL "default") +- ADD_LIBRARY(cpuinfo ${CPUINFO_SRCS}) +-ELSEIF(CPUINFO_LIBRARY_TYPE STREQUAL "shared") +- ADD_LIBRARY(cpuinfo SHARED ${CPUINFO_SRCS}) +-ELSEIF(CPUINFO_LIBRARY_TYPE STREQUAL "static") +- ADD_LIBRARY(cpuinfo STATIC ${CPUINFO_SRCS}) +-ELSE() +- MESSAGE(FATAL_ERROR "Unsupported library type ${CPUINFO_LIBRARY_TYPE}") +-ENDIF() +-ADD_LIBRARY(cpuinfo_internals STATIC ${CPUINFO_SRCS}) +-CPUINFO_TARGET_ENABLE_C99(cpuinfo) +-CPUINFO_TARGET_ENABLE_C99(cpuinfo_internals) +-CPUINFO_TARGET_RUNTIME_LIBRARY(cpuinfo) +-SET_TARGET_PROPERTIES(cpuinfo PROPERTIES PUBLIC_HEADER include/cpuinfo.h) +-TARGET_INCLUDE_DIRECTORIES(cpuinfo BEFORE PUBLIC include) +-TARGET_INCLUDE_DIRECTORIES(cpuinfo BEFORE PRIVATE src) 
+-TARGET_INCLUDE_DIRECTORIES(cpuinfo_internals BEFORE PUBLIC include src) +-IF(CPUINFO_LOG_LEVEL STREQUAL "default") +- # default logging level: error (subject to change) +- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=2) +-ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "debug") +- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=5) +-ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "info") +- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=4) +-ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "warning") +- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=3) +-ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "error") +- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=2) +-ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "fatal") +- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=1) +-ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "none") +- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=0) +-ELSE() +- MESSAGE(FATAL_ERROR "Unsupported logging level ${CPUINFO_LOG_LEVEL}") +-ENDIF() +-TARGET_COMPILE_DEFINITIONS(cpuinfo_internals PRIVATE CPUINFO_LOG_LEVEL=0) +- +-IF(CPUINFO_SUPPORTED_PLATFORM) +- TARGET_COMPILE_DEFINITIONS(cpuinfo INTERFACE CPUINFO_SUPPORTED_PLATFORM=1) +- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") +- TARGET_LINK_LIBRARIES(cpuinfo PUBLIC ${CMAKE_THREAD_LIBS_INIT}) +- TARGET_LINK_LIBRARIES(cpuinfo_internals PUBLIC ${CMAKE_THREAD_LIBS_INIT}) +- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE _GNU_SOURCE=1) +- TARGET_COMPILE_DEFINITIONS(cpuinfo_internals PRIVATE _GNU_SOURCE=1) +- ENDIF() +-ELSE() +- TARGET_COMPILE_DEFINITIONS(cpuinfo INTERFACE CPUINFO_SUPPORTED_PLATFORM=0) +-ENDIF() +- +-# ---[ cpuinfo dependencies: clog +-IF(NOT DEFINED CLOG_SOURCE_DIR) +- SET(CLOG_SOURCE_DIR "${PROJECT_SOURCE_DIR}/deps/clog") +-ENDIF() +-IF(NOT TARGET clog) +- SET(CLOG_BUILD_TESTS OFF CACHE BOOL "") +- SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "") +- ADD_SUBDIRECTORY( +- "${CLOG_SOURCE_DIR}") +- # We build static version of clog but a dynamic library may indirectly depend on it +- SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON) +-ENDIF() +-TARGET_LINK_LIBRARIES(cpuinfo PRIVATE clog) +-TARGET_LINK_LIBRARIES(cpuinfo_internals PRIVATE clog) +- +-INSTALL(TARGETS cpuinfo +- LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +- ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} +- PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) +- +-# ---[ cpuinfo micro-benchmarks +-IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_BENCHMARKS) +- # ---[ Build google benchmark +- IF(NOT TARGET benchmark) +- SET(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "") +- ADD_SUBDIRECTORY( +- "${GOOGLEBENCHMARK_SOURCE_DIR}" +- "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark") +- ENDIF() +- +- IF(CMAKE_SYSTEM_NAME MATCHES "^(Linux|Android)$") +- ADD_EXECUTABLE(get-current-bench bench/get-current.cc) +- TARGET_LINK_LIBRARIES(get-current-bench cpuinfo benchmark) +- ENDIF() +- +- ADD_EXECUTABLE(init-bench bench/init.cc) +- TARGET_LINK_LIBRARIES(init-bench cpuinfo benchmark) +-ENDIF() +- +-IF(CPUINFO_SUPPORTED_PLATFORM) +- IF(CPUINFO_BUILD_MOCK_TESTS OR CPUINFO_BUILD_UNIT_TESTS) +- # ---[ Build google test +- IF(NOT TARGET gtest) +- IF(MSVC AND NOT CPUINFO_RUNTIME_TYPE STREQUAL "static") +- SET(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +- ENDIF() +- ADD_SUBDIRECTORY( +- "${GOOGLETEST_SOURCE_DIR}" +- "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest") +- ENDIF() +- ENDIF() +-ENDIF() +- +-# ---[ cpuinfo mock library and mock tests +-IF(CPUINFO_SUPPORTED_PLATFORM AND 
CPUINFO_BUILD_MOCK_TESTS) +- SET(CPUINFO_MOCK_SRCS "${CPUINFO_SRCS}") +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86_64)$") +- LIST(APPEND CPUINFO_MOCK_SRCS src/x86/mockcpuid.c) +- ENDIF() +- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") +- LIST(APPEND CPUINFO_MOCK_SRCS src/linux/mockfile.c) +- ENDIF() +- +- ADD_LIBRARY(cpuinfo_mock STATIC ${CPUINFO_MOCK_SRCS}) +- CPUINFO_TARGET_ENABLE_C99(cpuinfo_mock) +- CPUINFO_TARGET_RUNTIME_LIBRARY(cpuinfo_mock) +- SET_TARGET_PROPERTIES(cpuinfo_mock PROPERTIES PUBLIC_HEADER include/cpuinfo.h) +- TARGET_INCLUDE_DIRECTORIES(cpuinfo_mock BEFORE PUBLIC include) +- TARGET_INCLUDE_DIRECTORIES(cpuinfo_mock BEFORE PRIVATE src) +- TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PUBLIC CPUINFO_MOCK=1) +- TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE CLOG_LOG_TO_STDIO=1) +- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") +- TARGET_LINK_LIBRARIES(cpuinfo_mock PUBLIC ${CMAKE_THREAD_LIBS_INIT}) +- TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE _GNU_SOURCE=1) +- ENDIF() +- TARGET_LINK_LIBRARIES(cpuinfo_mock PRIVATE clog) +- +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a)$") +- ADD_EXECUTABLE(atm7029b-tablet-test test/mock/atm7029b-tablet.cc) +- TARGET_INCLUDE_DIRECTORIES(atm7029b-tablet-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(atm7029b-tablet-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(atm7029b-tablet-test atm7029b-tablet-test) +- +- ADD_EXECUTABLE(blu-r1-hd-test test/mock/blu-r1-hd.cc) +- TARGET_INCLUDE_DIRECTORIES(blu-r1-hd-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(blu-r1-hd-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(blu-r1-hd-test blu-r1-hd-test) +- +- ADD_EXECUTABLE(galaxy-a3-2016-eu-test test/mock/galaxy-a3-2016-eu.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-a3-2016-eu-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-a3-2016-eu-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-a3-2016-eu-test galaxy-a3-2016-eu-test) +- +- ADD_EXECUTABLE(galaxy-a8-2016-duos-test test/mock/galaxy-a8-2016-duos.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-a8-2016-duos-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-a8-2016-duos-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-a8-2016-duos-test galaxy-a8-2016-duos-test) +- +- ADD_EXECUTABLE(galaxy-grand-prime-value-edition-test test/mock/galaxy-grand-prime-value-edition.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-grand-prime-value-edition-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-grand-prime-value-edition-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-grand-prime-value-edition-test galaxy-grand-prime-value-edition-test) +- +- ADD_EXECUTABLE(galaxy-j1-2016-test test/mock/galaxy-j1-2016.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-j1-2016-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-j1-2016-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-j1-2016-test galaxy-j1-2016-test) +- +- ADD_EXECUTABLE(galaxy-j5-test test/mock/galaxy-j5.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-j5-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-j5-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-j5-test galaxy-j5-test) +- +- ADD_EXECUTABLE(galaxy-j7-prime-test test/mock/galaxy-j7-prime.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-j7-prime-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-j7-prime-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-j7-prime-test galaxy-j7-prime-test) +- +- 
ADD_EXECUTABLE(galaxy-j7-tmobile-test test/mock/galaxy-j7-tmobile.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-j7-tmobile-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-j7-tmobile-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-j7-tmobile-test galaxy-j7-tmobile-test) +- +- ADD_EXECUTABLE(galaxy-j7-uae-test test/mock/galaxy-j7-uae.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-j7-uae-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-j7-uae-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-j7-uae-test galaxy-j7-uae-test) +- +- ADD_EXECUTABLE(galaxy-s3-us-test test/mock/galaxy-s3-us.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s3-us-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s3-us-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s3-us-test galaxy-s3-us-test) +- +- ADD_EXECUTABLE(galaxy-s4-us-test test/mock/galaxy-s4-us.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s4-us-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s4-us-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s4-us-test galaxy-s4-us-test) +- +- ADD_EXECUTABLE(galaxy-s5-global-test test/mock/galaxy-s5-global.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s5-global-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s5-global-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s5-global-test galaxy-s5-global-test) +- +- ADD_EXECUTABLE(galaxy-s5-us-test test/mock/galaxy-s5-us.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s5-us-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s5-us-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s5-us-test galaxy-s5-us-test) +- +- ADD_EXECUTABLE(galaxy-tab-3-7.0-test test/mock/galaxy-tab-3-7.0.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-tab-3-7.0-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-tab-3-7.0-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-tab-3-7.0-test galaxy-tab-3-7.0-test) +- +- ADD_EXECUTABLE(galaxy-tab-3-lite-test test/mock/galaxy-tab-3-lite.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-tab-3-lite-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-tab-3-lite-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-tab-3-lite-test galaxy-tab-3-lite-test) +- +- ADD_EXECUTABLE(galaxy-win-duos-test test/mock/galaxy-win-duos.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-win-duos-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-win-duos-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-win-duos-test galaxy-win-duos-test) +- +- ADD_EXECUTABLE(huawei-ascend-p7-test test/mock/huawei-ascend-p7.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-ascend-p7-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-ascend-p7-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-ascend-p7-test huawei-ascend-p7-test) +- +- ADD_EXECUTABLE(huawei-honor-6-test test/mock/huawei-honor-6.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-honor-6-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-honor-6-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-honor-6-test huawei-honor-6-test) +- +- ADD_EXECUTABLE(lenovo-a6600-plus-test test/mock/lenovo-a6600-plus.cc) +- TARGET_INCLUDE_DIRECTORIES(lenovo-a6600-plus-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(lenovo-a6600-plus-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(lenovo-a6600-plus-test lenovo-a6600-plus-test) +- +- ADD_EXECUTABLE(lenovo-vibe-x2-test test/mock/lenovo-vibe-x2.cc) +- TARGET_INCLUDE_DIRECTORIES(lenovo-vibe-x2-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(lenovo-vibe-x2-test PRIVATE cpuinfo_mock gtest) +- 
ADD_TEST(lenovo-vibe-x2-test lenovo-vibe-x2-test) +- +- ADD_EXECUTABLE(lg-k10-eu-test test/mock/lg-k10-eu.cc) +- TARGET_INCLUDE_DIRECTORIES(lg-k10-eu-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(lg-k10-eu-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(lg-k10-eu-test lg-k10-eu-test) +- +- ADD_EXECUTABLE(lg-optimus-g-pro-test test/mock/lg-optimus-g-pro.cc) +- TARGET_INCLUDE_DIRECTORIES(lg-optimus-g-pro-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(lg-optimus-g-pro-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(lg-optimus-g-pro-test lg-optimus-g-pro-test) +- +- ADD_EXECUTABLE(moto-e-gen1-test test/mock/moto-e-gen1.cc) +- TARGET_INCLUDE_DIRECTORIES(moto-e-gen1-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(moto-e-gen1-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(moto-e-gen1-test moto-e-gen1-test) +- +- ADD_EXECUTABLE(moto-g-gen1-test test/mock/moto-g-gen1.cc) +- TARGET_INCLUDE_DIRECTORIES(moto-g-gen1-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(moto-g-gen1-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(moto-g-gen1-test moto-g-gen1-test) +- +- ADD_EXECUTABLE(moto-g-gen2-test test/mock/moto-g-gen2.cc) +- TARGET_INCLUDE_DIRECTORIES(moto-g-gen2-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(moto-g-gen2-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(moto-g-gen2-test moto-g-gen2-test) +- +- ADD_EXECUTABLE(moto-g-gen3-test test/mock/moto-g-gen3.cc) +- TARGET_INCLUDE_DIRECTORIES(moto-g-gen3-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(moto-g-gen3-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(moto-g-gen3-test moto-g-gen3-test) +- +- ADD_EXECUTABLE(moto-g-gen4-test test/mock/moto-g-gen4.cc) +- TARGET_INCLUDE_DIRECTORIES(moto-g-gen4-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(moto-g-gen4-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(moto-g-gen4-test moto-g-gen4-test) +- +- ADD_EXECUTABLE(moto-g-gen5-test test/mock/moto-g-gen5.cc) +- TARGET_INCLUDE_DIRECTORIES(moto-g-gen5-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(moto-g-gen5-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(moto-g-gen5-test moto-g-gen5-test) +- +- ADD_EXECUTABLE(nexus-s-test test/mock/nexus-s.cc) +- TARGET_INCLUDE_DIRECTORIES(nexus-s-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(nexus-s-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(nexus-s-test nexus-s-test) +- +- ADD_EXECUTABLE(nexus4-test test/mock/nexus4.cc) +- TARGET_INCLUDE_DIRECTORIES(nexus4-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(nexus4-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(nexus4-test nexus4-test) +- +- ADD_EXECUTABLE(nexus6-test test/mock/nexus6.cc) +- TARGET_INCLUDE_DIRECTORIES(nexus6-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(nexus6-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(nexus6-test nexus6-test) +- +- ADD_EXECUTABLE(nexus10-test test/mock/nexus10.cc) +- TARGET_INCLUDE_DIRECTORIES(nexus10-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(nexus10-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(nexus10-test nexus10-test) +- +- ADD_EXECUTABLE(padcod-10.1-test test/mock/padcod-10.1.cc) +- TARGET_INCLUDE_DIRECTORIES(padcod-10.1-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(padcod-10.1-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(padcod-10.1-test padcod-10.1-test) +- +- ADD_EXECUTABLE(xiaomi-redmi-2a-test test/mock/xiaomi-redmi-2a.cc) +- TARGET_INCLUDE_DIRECTORIES(xiaomi-redmi-2a-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(xiaomi-redmi-2a-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(xiaomi-redmi-2a-test xiaomi-redmi-2a-test) +- +- 
ADD_EXECUTABLE(xperia-sl-test test/mock/xperia-sl.cc) +- TARGET_INCLUDE_DIRECTORIES(xperia-sl-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(xperia-sl-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(xperia-sl-test xperia-sl-test) +- ENDIF() +- +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") +- ADD_EXECUTABLE(alcatel-revvl-test test/mock/alcatel-revvl.cc) +- TARGET_INCLUDE_DIRECTORIES(alcatel-revvl-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(alcatel-revvl-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(alcatel-revvl-test alcatel-revvl-test) +- +- ADD_EXECUTABLE(galaxy-a8-2018-test test/mock/galaxy-a8-2018.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-a8-2018-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-a8-2018-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-a8-2018-test galaxy-a8-2018-test) +- +- ADD_EXECUTABLE(galaxy-c9-pro-test test/mock/galaxy-c9-pro.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-c9-pro-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-c9-pro-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-c9-pro-test galaxy-c9-pro-test) +- +- ADD_EXECUTABLE(galaxy-s6-test test/mock/galaxy-s6.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s6-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s6-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s6-test galaxy-s6-test) +- +- ADD_EXECUTABLE(galaxy-s7-us-test test/mock/galaxy-s7-us.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s7-us-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s7-us-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s7-us-test galaxy-s7-us-test) +- +- ADD_EXECUTABLE(galaxy-s7-global-test test/mock/galaxy-s7-global.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s7-global-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s7-global-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s7-global-test galaxy-s7-global-test) +- +- ADD_EXECUTABLE(galaxy-s8-us-test test/mock/galaxy-s8-us.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s8-us-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s8-us-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s8-us-test galaxy-s8-us-test) +- +- ADD_EXECUTABLE(galaxy-s8-global-test test/mock/galaxy-s8-global.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s8-global-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s8-global-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s8-global-test galaxy-s8-global-test) +- +- ADD_EXECUTABLE(galaxy-s9-us-test test/mock/galaxy-s9-us.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s9-us-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s9-us-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s9-us-test galaxy-s9-us-test) +- +- ADD_EXECUTABLE(galaxy-s9-global-test test/mock/galaxy-s9-global.cc) +- TARGET_INCLUDE_DIRECTORIES(galaxy-s9-global-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(galaxy-s9-global-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(galaxy-s9-global-test galaxy-s9-global-test) +- +- ADD_EXECUTABLE(huawei-mate-8-test test/mock/huawei-mate-8.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-mate-8-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-mate-8-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-mate-8-test huawei-mate-8-test) +- +- ADD_EXECUTABLE(huawei-mate-9-test test/mock/huawei-mate-9.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-mate-9-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-mate-9-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-mate-9-test 
huawei-mate-9-test) +- +- ADD_EXECUTABLE(huawei-mate-10-test test/mock/huawei-mate-10.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-mate-10-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-mate-10-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-mate-10-test huawei-mate-10-test) +- +- ADD_EXECUTABLE(huawei-mate-20-test test/mock/huawei-mate-20.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-mate-20-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-mate-20-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-mate-20-test huawei-mate-20-test) +- +- ADD_EXECUTABLE(huawei-p8-lite-test test/mock/huawei-p8-lite.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-p8-lite-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-p8-lite-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-p8-lite-test huawei-p8-lite-test) +- +- ADD_EXECUTABLE(huawei-p9-lite-test test/mock/huawei-p9-lite.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-p9-lite-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-p9-lite-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-p9-lite-test huawei-p9-lite-test) +- +- ADD_EXECUTABLE(huawei-p20-pro-test test/mock/huawei-p20-pro.cc) +- TARGET_INCLUDE_DIRECTORIES(huawei-p20-pro-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(huawei-p20-pro-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(huawei-p20-pro-test huawei-p20-pro-test) +- +- ADD_EXECUTABLE(iconia-one-10-test test/mock/iconia-one-10.cc) +- TARGET_INCLUDE_DIRECTORIES(iconia-one-10-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(iconia-one-10-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(iconia-one-10-test iconia-one-10-test) +- +- ADD_EXECUTABLE(meizu-pro-6-test test/mock/meizu-pro-6.cc) +- TARGET_INCLUDE_DIRECTORIES(meizu-pro-6-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(meizu-pro-6-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(meizu-pro-6-test meizu-pro-6-test) +- +- ADD_EXECUTABLE(meizu-pro-6s-test test/mock/meizu-pro-6s.cc) +- TARGET_INCLUDE_DIRECTORIES(meizu-pro-6s-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(meizu-pro-6s-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(meizu-pro-6s-test meizu-pro-6s-test) +- +- ADD_EXECUTABLE(meizu-pro-7-plus-test test/mock/meizu-pro-7-plus.cc) +- TARGET_INCLUDE_DIRECTORIES(meizu-pro-7-plus-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(meizu-pro-7-plus-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(meizu-pro-7-plus-test meizu-pro-7-plus-test) +- +- ADD_EXECUTABLE(nexus5x-test test/mock/nexus5x.cc) +- TARGET_INCLUDE_DIRECTORIES(nexus5x-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(nexus5x-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(nexus5x-test nexus5x-test) +- +- ADD_EXECUTABLE(nexus6p-test test/mock/nexus6p.cc) +- TARGET_INCLUDE_DIRECTORIES(nexus6p-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(nexus6p-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(nexus6p-test nexus6p-test) +- +- ADD_EXECUTABLE(nexus9-test test/mock/nexus9.cc) +- TARGET_INCLUDE_DIRECTORIES(nexus9-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(nexus9-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(nexus9-test nexus9-test) +- +- ADD_EXECUTABLE(oneplus-3t-test test/mock/oneplus-3t.cc) +- TARGET_INCLUDE_DIRECTORIES(oneplus-3t-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(oneplus-3t-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(oneplus-3t-test oneplus-3t-test) +- +- ADD_EXECUTABLE(oneplus-5-test test/mock/oneplus-5.cc) +- TARGET_INCLUDE_DIRECTORIES(oneplus-5-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(oneplus-5-test PRIVATE 
cpuinfo_mock gtest) +- ADD_TEST(oneplus-5-test oneplus-5-test) +- +- ADD_EXECUTABLE(oneplus-5t-test test/mock/oneplus-5t.cc) +- TARGET_INCLUDE_DIRECTORIES(oneplus-5t-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(oneplus-5t-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(oneplus-5t-test oneplus-5t-test) +- +- ADD_EXECUTABLE(oppo-a37-test test/mock/oppo-a37.cc) +- TARGET_INCLUDE_DIRECTORIES(oppo-a37-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(oppo-a37-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(oppo-a37-test oppo-a37-test) +- +- ADD_EXECUTABLE(oppo-r9-test test/mock/oppo-r9.cc) +- TARGET_INCLUDE_DIRECTORIES(oppo-r9-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(oppo-r9-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(oppo-r9-test oppo-r9-test) +- +- ADD_EXECUTABLE(oppo-r15-test test/mock/oppo-r15.cc) +- TARGET_INCLUDE_DIRECTORIES(oppo-r15-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(oppo-r15-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(oppo-r15-test oppo-r15-test) +- +- ADD_EXECUTABLE(pixel-test test/mock/pixel.cc) +- TARGET_INCLUDE_DIRECTORIES(pixel-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(pixel-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(pixel-test pixel-test) +- +- ADD_EXECUTABLE(pixel-c-test test/mock/pixel-c.cc) +- TARGET_INCLUDE_DIRECTORIES(pixel-c-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(pixel-c-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(pixel-c-test pixel-c-test) +- +- ADD_EXECUTABLE(pixel-xl-test test/mock/pixel-xl.cc) +- TARGET_INCLUDE_DIRECTORIES(pixel-xl-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(pixel-xl-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(pixel-xl-test pixel-xl-test) +- +- ADD_EXECUTABLE(pixel-2-xl-test test/mock/pixel-2-xl.cc) +- TARGET_INCLUDE_DIRECTORIES(pixel-2-xl-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(pixel-2-xl-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(pixel-2-xl-test pixel-2-xl-test) +- +- ADD_EXECUTABLE(xiaomi-mi-5c-test test/mock/xiaomi-mi-5c.cc) +- TARGET_INCLUDE_DIRECTORIES(xiaomi-mi-5c-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(xiaomi-mi-5c-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(xiaomi-mi-5c-test xiaomi-mi-5c-test) +- +- ADD_EXECUTABLE(xiaomi-redmi-note-3-test test/mock/xiaomi-redmi-note-3.cc) +- TARGET_INCLUDE_DIRECTORIES(xiaomi-redmi-note-3-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(xiaomi-redmi-note-3-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(xiaomi-redmi-note-3-test xiaomi-redmi-note-3-test) +- +- ADD_EXECUTABLE(xiaomi-redmi-note-4-test test/mock/xiaomi-redmi-note-4.cc) +- TARGET_INCLUDE_DIRECTORIES(xiaomi-redmi-note-4-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(xiaomi-redmi-note-4-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(xiaomi-redmi-note-4-test xiaomi-redmi-note-4-test) +- +- ADD_EXECUTABLE(xperia-c4-dual-test test/mock/xperia-c4-dual.cc) +- TARGET_INCLUDE_DIRECTORIES(xperia-c4-dual-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(xperia-c4-dual-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(xperia-c4-dual-test xperia-c4-dual-test) +- ENDIF() +- +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64)$") +- ADD_EXECUTABLE(alldocube-iwork8-test test/mock/alldocube-iwork8.cc) +- TARGET_INCLUDE_DIRECTORIES(alldocube-iwork8-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(alldocube-iwork8-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(alldocube-iwork8-test alldocube-iwork8-test) +- +- ADD_EXECUTABLE(leagoo-t5c-test test/mock/leagoo-t5c.cc) +- 
TARGET_INCLUDE_DIRECTORIES(leagoo-t5c-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(leagoo-t5c-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(leagoo-t5c-test leagoo-t5c-test) +- +- ADD_EXECUTABLE(memo-pad-7-test test/mock/memo-pad-7.cc) +- TARGET_INCLUDE_DIRECTORIES(memo-pad-7-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(memo-pad-7-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(memo-pad-7-test memo-pad-7-test) +- +- ADD_EXECUTABLE(zenfone-c-test test/mock/zenfone-c.cc) +- TARGET_INCLUDE_DIRECTORIES(zenfone-c-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(zenfone-c-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(zenfone-c-test zenfone-c-test) +- +- ADD_EXECUTABLE(zenfone-2-test test/mock/zenfone-2.cc) +- TARGET_INCLUDE_DIRECTORIES(zenfone-2-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(zenfone-2-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(zenfone-2-test zenfone-2-test) +- +- ADD_EXECUTABLE(zenfone-2e-test test/mock/zenfone-2e.cc) +- TARGET_INCLUDE_DIRECTORIES(zenfone-2e-test BEFORE PRIVATE test/mock) +- TARGET_LINK_LIBRARIES(zenfone-2e-test PRIVATE cpuinfo_mock gtest) +- ADD_TEST(zenfone-2e-test zenfone-2e-test) +- ENDIF() +-ENDIF() +- +-# ---[ cpuinfo unit tests +-IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_UNIT_TESTS) +- ADD_EXECUTABLE(init-test test/init.cc) +- CPUINFO_TARGET_ENABLE_CXX11(init-test) +- CPUINFO_TARGET_RUNTIME_LIBRARY(init-test) +- TARGET_LINK_LIBRARIES(init-test PRIVATE cpuinfo gtest gtest_main) +- ADD_TEST(init-test init-test) +- +- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") +- ADD_EXECUTABLE(get-current-test test/get-current.cc) +- CPUINFO_TARGET_ENABLE_CXX11(get-current-test) +- CPUINFO_TARGET_RUNTIME_LIBRARY(get-current-test) +- TARGET_LINK_LIBRARIES(get-current-test PRIVATE cpuinfo gtest gtest_main) +- ADD_TEST(get-current-test get-current-test) +- ENDIF() +- +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86_64)$") +- ADD_EXECUTABLE(brand-string-test test/name/brand-string.cc) +- CPUINFO_TARGET_ENABLE_CXX11(brand-string-test) +- CPUINFO_TARGET_RUNTIME_LIBRARY(brand-string-test) +- TARGET_LINK_LIBRARIES(brand-string-test PRIVATE cpuinfo_internals gtest gtest_main) +- ADD_TEST(brand-string-test brand-string-test) +- ENDIF() +- +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") +- ADD_LIBRARY(android_properties_interface STATIC test/name/android-properties-interface.c) +- CPUINFO_TARGET_ENABLE_C99(android_properties_interface) +- CPUINFO_TARGET_RUNTIME_LIBRARY(android_properties_interface) +- TARGET_LINK_LIBRARIES(android_properties_interface PRIVATE cpuinfo_internals) +- +- ADD_EXECUTABLE(chipset-test +- test/name/proc-cpuinfo-hardware.cc +- test/name/ro-product-board.cc +- test/name/ro-board-platform.cc +- test/name/ro-mediatek-platform.cc +- test/name/ro-arch.cc +- test/name/ro-chipname.cc +- test/name/android-properties.cc) +- CPUINFO_TARGET_ENABLE_CXX11(chipset-test) +- CPUINFO_TARGET_RUNTIME_LIBRARY(chipset-test) +- TARGET_LINK_LIBRARIES(chipset-test PRIVATE android_properties_interface gtest gtest_main) +- ADD_TEST(chipset-test chipset-test) +- +- ADD_EXECUTABLE(cache-test test/arm-cache.cc) +- CPUINFO_TARGET_ENABLE_CXX11(cache-test) +- CPUINFO_TARGET_RUNTIME_LIBRARY(cache-test) +- TARGET_COMPILE_DEFINITIONS(cache-test PRIVATE __STDC_LIMIT_MACROS=1 __STDC_CONSTANT_MACROS=1) +- TARGET_LINK_LIBRARIES(cache-test PRIVATE cpuinfo_internals gtest gtest_main) +- ADD_TEST(cache-test, cache-test) +- ENDIF() +-ENDIF() +- +-# ---[ Helper 
and debug tools +-IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_TOOLS) +- ADD_EXECUTABLE(isa-info tools/isa-info.c) +- CPUINFO_TARGET_ENABLE_C99(isa-info) +- CPUINFO_TARGET_RUNTIME_LIBRARY(isa-info) +- TARGET_LINK_LIBRARIES(isa-info PRIVATE cpuinfo) +- INSTALL(TARGETS isa-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +- +- ADD_EXECUTABLE(cpu-info tools/cpu-info.c) +- CPUINFO_TARGET_ENABLE_C99(cpu-info) +- CPUINFO_TARGET_RUNTIME_LIBRARY(cpu-info) +- TARGET_LINK_LIBRARIES(cpu-info PRIVATE cpuinfo) +- INSTALL(TARGETS cpu-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +- +- ADD_EXECUTABLE(cache-info tools/cache-info.c) +- CPUINFO_TARGET_ENABLE_C99(cache-info) +- CPUINFO_TARGET_RUNTIME_LIBRARY(cache-info) +- TARGET_LINK_LIBRARIES(cache-info PRIVATE cpuinfo) +- INSTALL(TARGETS cache-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +- +- IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") +- ADD_EXECUTABLE(auxv-dump tools/auxv-dump.c) +- CPUINFO_TARGET_ENABLE_C99(auxv-dump) +- CPUINFO_TARGET_RUNTIME_LIBRARY(auxv-dump) +- TARGET_LINK_LIBRARIES(auxv-dump PRIVATE ${CMAKE_DL_LIBS} cpuinfo) +- +- ADD_EXECUTABLE(cpuinfo-dump tools/cpuinfo-dump.c) +- CPUINFO_TARGET_ENABLE_C99(cpuinfo-dump) +- CPUINFO_TARGET_RUNTIME_LIBRARY(cpuinfo-dump) +- ENDIF() +- +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86_64)$") +- ADD_EXECUTABLE(cpuid-dump tools/cpuid-dump.c) +- CPUINFO_TARGET_ENABLE_C99(cpuid-dump) +- CPUINFO_TARGET_RUNTIME_LIBRARY(cpuid-dump) +- TARGET_INCLUDE_DIRECTORIES(cpuid-dump BEFORE PRIVATE src) +- TARGET_INCLUDE_DIRECTORIES(cpuid-dump BEFORE PRIVATE include) +- INSTALL(TARGETS cpuid-dump RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +- ENDIF() +-ENDIF() +diff --git a/README.md b/README.md +index 7d383ff..ee5fb82 100644 +--- a/README.md ++++ b/README.md +@@ -152,21 +152,20 @@ pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpu_set); + - [x] Using `ro.chipname`, `ro.board.platform`, `ro.product.board`, `ro.mediatek.platform`, `ro.arch` properties (Android) + - [ ] Using kernel log (`dmesg`) on ARM Linux + - Vendor and microarchitecture detection +- - [x] Intel-designed x86/x86-64 cores (up to Kaby Lake, Airmont, and Knights Mill) +- - [x] AMD-designed x86/x86-64 cores (up to Puma/Jaguar and Zen) ++ - [x] Intel-designed x86/x86-64 cores (up to Sunny Cove, Goldmont Plus, and Knights Mill) ++ - [x] AMD-designed x86/x86-64 cores (up to Puma/Jaguar and Zen 2) + - [ ] VIA-designed x86/x86-64 cores + - [ ] Other x86 cores (DM&P, RDC, Transmeta, Cyrix, Rise) +- - [x] ARM-designed ARM cores (up to Cortex-A55 and Cortex-A75) +- - [x] Qualcomm-designed ARM cores (up to Kryo, Kryo-280, and Kryo-385) +- - [x] Nvidia-designed ARM cores (Denver) ++ - [x] ARM-designed ARM cores (up to Cortex-A55, Cortex-A77, and Neoverse E1/N1) ++ - [x] Qualcomm-designed ARM cores (Scorpion, Krait, and Kryo) ++ - [x] Nvidia-designed ARM cores (Denver and Carmel) + - [x] Samsung-designed ARM cores (Exynos) + - [x] Intel-designed ARM cores (XScale up to 3rd-gen) +- - [x] Apple-designed ARM cores (up to Hurricane) ++ - [x] Apple-designed ARM cores (up to Lightning and Thunder) + - [x] Cavium-designed ARM cores (ThunderX) + - [x] AppliedMicro-designed ARM cores (X-Gene) + - Instruction set detection + - [x] Using CPUID (x86/x86-64) +- - [x] Using dynamic code generation validator (Native Client/x86-64) + - [x] Using `/proc/cpuinfo` on 32-bit ARM EABI (Linux) + - [x] Using microarchitecture heuristics on (32-bit ARM) + - [x] Using `FPSID` and `WCID` 
registers (32-bit ARM) +diff --git a/bench/get-current.cc b/bench/get-current.cc +index 91b35a0..b547df0 100644 +--- a/bench/get-current.cc ++++ b/bench/get-current.cc +@@ -21,4 +21,13 @@ static void cpuinfo_get_current_core(benchmark::State& state) { + } + BENCHMARK(cpuinfo_get_current_core)->Unit(benchmark::kNanosecond); + ++static void cpuinfo_get_current_uarch_index(benchmark::State& state) { ++ cpuinfo_initialize(); ++ while (state.KeepRunning()) { ++ const uint32_t uarch_index = cpuinfo_get_current_uarch_index(); ++ benchmark::DoNotOptimize(uarch_index); ++ } ++} ++BENCHMARK(cpuinfo_get_current_uarch_index)->Unit(benchmark::kNanosecond); ++ + BENCHMARK_MAIN(); +diff --git a/cmake/DownloadGoogleTest.cmake b/cmake/DownloadGoogleTest.cmake +index d69d19a..dc86c9c 100644 +--- a/cmake/DownloadGoogleTest.cmake ++++ b/cmake/DownloadGoogleTest.cmake +@@ -4,8 +4,8 @@ PROJECT(googletest-download NONE) + + INCLUDE(ExternalProject) + ExternalProject_Add(googletest +- URL https://github.com/google/googletest/archive/release-1.8.0.zip +- URL_HASH SHA256=f3ed3b58511efd272eb074a3a6d6fb79d7c2e6a0e374323d1e6bcbcc1ef141bf ++ URL https://github.com/google/googletest/archive/release-1.10.0.zip ++ URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91 + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest" + CONFIGURE_COMMAND "" +diff --git a/configure.py b/configure.py +index a340c4c..0e58dba 100755 +--- a/configure.py ++++ b/configure.py +@@ -26,8 +26,8 @@ def main(args): + sources = ["init.c", "api.c"] + if build.target.is_x86 or build.target.is_x86_64: + sources += [ +- "x86/init.c", "x86/info.c", "x86/vendor.c", "x86/uarch.c", "x86/name.c", +- "x86/topology.c", ++ "x86/init.c", "x86/info.c", "x86/isa.c", "x86/vendor.c", ++ "x86/uarch.c", "x86/name.c", "x86/topology.c", + "x86/cache/init.c", "x86/cache/descriptor.c", "x86/cache/deterministic.c", + ] + if build.target.is_macos: +@@ -37,7 +37,6 @@ def main(args): + "x86/linux/init.c", + "x86/linux/cpuinfo.c", + ] +- sources.append("x86/isa.c" if not build.target.is_nacl else "x86/nacl/isa.c") + if build.target.is_arm or build.target.is_arm64: + sources += ["arm/uarch.c", "arm/cache.c"] + if build.target.is_linux or build.target.is_android: +diff --git a/include/cpuinfo.h b/include/cpuinfo.h +index 9938d2b..e4d2d0c 100644 +--- a/include/cpuinfo.h ++++ b/include/cpuinfo.h +@@ -34,10 +34,6 @@ + #define CPUINFO_ARCH_PPC64 1 + #endif + +-#if defined(__pnacl__) +- #define CPUINFO_ARCH_PNACL 1 +-#endif +- + #if defined(__asmjs__) + #define CPUINFO_ARCH_ASMJS 1 + #endif +@@ -80,10 +76,6 @@ + #define CPUINFO_ARCH_PPC64 0 + #endif + +-#ifndef CPUINFO_ARCH_PNACL +- #define CPUINFO_ARCH_PNACL 0 +-#endif +- + #ifndef CPUINFO_ARCH_ASMJS + #define CPUINFO_ARCH_ASMJS 0 + #endif +@@ -190,6 +182,12 @@ enum cpuinfo_vendor { + * Processors are designed by HiSilicon, a subsidiary of Huawei. + */ + cpuinfo_vendor_huawei = 15, ++ /** ++ * Hygon (Chengdu Haiguang Integrated Circuit Design Co., Ltd), Vendor of x86-64 processor microarchitectures. ++ * ++ * Processors are variants of AMD cores. ++ */ ++ cpuinfo_vendor_hygon = 16, + + /* Active vendors of embedded CPUs */ + +@@ -401,6 +399,8 @@ enum cpuinfo_uarch { + cpuinfo_uarch_cortex_a35 = 0x00300335, + /** ARM Cortex-A53. */ + cpuinfo_uarch_cortex_a53 = 0x00300353, ++ /** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities compared to revision 1+). */ ++ cpuinfo_uarch_cortex_a55r0 = 0x00300354, + /** ARM Cortex-A55. 
*/ + cpuinfo_uarch_cortex_a55 = 0x00300355, + /** ARM Cortex-A57. */ +@@ -478,6 +478,10 @@ enum cpuinfo_uarch { + cpuinfo_uarch_vortex = 0x00700107, + /** Apple A12 processor (little cores). */ + cpuinfo_uarch_tempest = 0x00700108, ++ /** Apple A13 processor (big cores). */ ++ cpuinfo_uarch_lightning = 0x00700109, ++ /** Apple A13 processor (little cores). */ ++ cpuinfo_uarch_thunder = 0x0070010A, + + /** Cavium ThunderX. */ + cpuinfo_uarch_thunderx = 0x00800100, +@@ -494,6 +498,9 @@ enum cpuinfo_uarch { + + /** Applied Micro X-Gene. */ + cpuinfo_uarch_xgene = 0x00B00100, ++ ++ /* Hygon Dhyana (a modification of AMD Zen for Chinese market). */ ++ cpuinfo_uarch_dhyana = 0x01000100, + }; + + struct cpuinfo_processor { +@@ -613,6 +620,22 @@ struct cpuinfo_package { + uint32_t cluster_count; + }; + ++struct cpuinfo_uarch_info { ++ /** Type of CPU microarchitecture */ ++ enum cpuinfo_uarch uarch; ++#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 ++ /** Value of CPUID leaf 1 EAX register for the microarchitecture */ ++ uint32_t cpuid; ++#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 ++ /** Value of Main ID Register (MIDR) for the microarchitecture */ ++ uint32_t midr; ++#endif ++ /** Number of logical processors with the microarchitecture */ ++ uint32_t processor_count; ++ /** Number of cores with the microarchitecture */ ++ uint32_t core_count; ++}; ++ + #ifdef __cplusplus + extern "C" { + #endif +@@ -1721,6 +1744,7 @@ const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processors(void); + const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_cores(void); + const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_clusters(void); + const struct cpuinfo_package* CPUINFO_ABI cpuinfo_get_packages(void); ++const struct cpuinfo_uarch_info* CPUINFO_ABI cpuinfo_get_uarchs(void); + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_caches(void); + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_caches(void); + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_caches(void); +@@ -1731,6 +1755,7 @@ const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processor(uint32_t index + const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_core(uint32_t index); + const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_cluster(uint32_t index); + const struct cpuinfo_package* CPUINFO_ABI cpuinfo_get_package(uint32_t index); ++const struct cpuinfo_uarch_info* CPUINFO_ABI cpuinfo_get_uarch(uint32_t index); + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_cache(uint32_t index); + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_cache(uint32_t index); + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_cache(uint32_t index); +@@ -1741,6 +1766,7 @@ uint32_t CPUINFO_ABI cpuinfo_get_processors_count(void); + uint32_t CPUINFO_ABI cpuinfo_get_cores_count(void); + uint32_t CPUINFO_ABI cpuinfo_get_clusters_count(void); + uint32_t CPUINFO_ABI cpuinfo_get_packages_count(void); ++uint32_t CPUINFO_ABI cpuinfo_get_uarchs_count(void); + uint32_t CPUINFO_ABI cpuinfo_get_l1i_caches_count(void); + uint32_t CPUINFO_ABI cpuinfo_get_l1d_caches_count(void); + uint32_t CPUINFO_ABI cpuinfo_get_l2_caches_count(void); +@@ -1752,9 +1778,31 @@ uint32_t CPUINFO_ABI cpuinfo_get_l4_caches_count(void); + */ + uint32_t CPUINFO_ABI cpuinfo_get_max_cache_size(void); + ++/** ++ * Identify the logical processor that executes the current thread. ++ * ++ * There is no guarantee that the thread will stay on the same logical processor for any time. 
++ * Callers should treat the result as only a hint, and be prepared to handle NULL return value. ++ */ + const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_current_processor(void); ++ ++/** ++ * Identify the core that executes the current thread. ++ * ++ * There is no guarantee that the thread will stay on the same core for any time. ++ * Callers should treat the result as only a hint, and be prepared to handle NULL return value. ++ */ + const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_current_core(void); + ++/** ++ * Identify the microarchitecture index of the core that executes the current thread. ++ * If the system does not support such identification, the function return 0. ++ * ++ * There is no guarantee that the thread will stay on the same type of core for any time. ++ * Callers should treat the result as only a hint. ++ */ ++uint32_t CPUINFO_ABI cpuinfo_get_current_uarch_index(void); ++ + #ifdef __cplusplus + } /* extern "C" */ + #endif +diff --git a/src/api.c b/src/api.c +index b180d80..0cc5d4e 100644 +--- a/src/api.c ++++ b/src/api.c +@@ -1,9 +1,16 @@ ++#include + #include + + #include + #include + #include + ++#ifdef __linux__ ++ #include ++ ++ #include ++ #include ++#endif + + bool cpuinfo_is_initialized = false; + +@@ -20,235 +27,347 @@ uint32_t cpuinfo_packages_count = 0; + uint32_t cpuinfo_cache_count[cpuinfo_cache_level_max] = { 0 }; + uint32_t cpuinfo_max_cache_size = 0; + ++#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 ++ struct cpuinfo_uarch_info* cpuinfo_uarchs = NULL; ++ uint32_t cpuinfo_uarchs_count = 0; ++#else ++ struct cpuinfo_uarch_info cpuinfo_global_uarch = { cpuinfo_uarch_unknown }; ++#endif ++ ++#ifdef __linux__ ++ uint32_t cpuinfo_linux_cpu_max = 0; ++ const struct cpuinfo_processor** cpuinfo_linux_cpu_to_processor_map = NULL; ++ const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map = NULL; ++ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 ++ const uint32_t* cpuinfo_linux_cpu_to_uarch_index_map = NULL; ++ #endif ++#endif ++ + + const struct cpuinfo_processor* cpuinfo_get_processors(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "processors"); + } + return cpuinfo_processors; + } + + const struct cpuinfo_core* cpuinfo_get_cores(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "core"); + } + return cpuinfo_cores; + } + + const struct cpuinfo_cluster* cpuinfo_get_clusters(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "clusters"); + } + return cpuinfo_clusters; + } + + const struct cpuinfo_package* cpuinfo_get_packages(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "packages"); + } + return cpuinfo_packages; + } + +-const struct cpuinfo_processor* cpuinfo_get_processor(uint32_t index) { ++const struct cpuinfo_uarch_info* cpuinfo_get_uarchs() { + if (!cpuinfo_is_initialized) { ++ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "uarchs"); ++ } ++ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 ++ return cpuinfo_uarchs; ++ #else ++ return &cpuinfo_global_uarch; ++ #endif ++} ++ ++const struct cpuinfo_processor* cpuinfo_get_processor(uint32_t index) { ++ if 
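As an illustrative aside, a minimal sketch of how a caller might enumerate the per-microarchitecture descriptors declared above; it assumes only the cpuinfo_uarch_info fields and the cpuinfo_get_uarchs_count()/cpuinfo_get_uarch() accessors introduced in this change, plus the existing cpuinfo_initialize() entry point:

    #include <stdio.h>
    #include <stdint.h>
    #include <cpuinfo.h>

    int main(void) {
        if (!cpuinfo_initialize()) {
            fprintf(stderr, "failed to initialize cpuinfo\n");
            return 1;
        }
        /* One cpuinfo_uarch_info entry per distinct microarchitecture in the system. */
        const uint32_t count = cpuinfo_get_uarchs_count();
        for (uint32_t i = 0; i < count; i++) {
            const struct cpuinfo_uarch_info* info = cpuinfo_get_uarch(i);
            if (info == NULL) {
                continue; /* out-of-range index; cannot happen for i < count */
            }
            printf("uarch %u: id 0x%08x, %u cores, %u logical processors\n",
                   (unsigned) i, (unsigned) info->uarch,
                   (unsigned) info->core_count, (unsigned) info->processor_count);
        }
        return 0;
    }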
CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "processor"); + } +- if (index < cpuinfo_processors_count) { +- return cpuinfo_processors + index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_processors_count) { + return NULL; + } ++ return &cpuinfo_processors[index]; + } + + const struct cpuinfo_core* cpuinfo_get_core(uint32_t index) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "core"); + } +- if (index < cpuinfo_cores_count) { +- return cpuinfo_cores + index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_cores_count) { + return NULL; + } ++ return &cpuinfo_cores[index]; + } + + const struct cpuinfo_cluster* cpuinfo_get_cluster(uint32_t index) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "cluster"); + } +- if (index < cpuinfo_clusters_count) { +- return cpuinfo_clusters + index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_clusters_count) { + return NULL; + } ++ return &cpuinfo_clusters[index]; + } + + const struct cpuinfo_package* cpuinfo_get_package(uint32_t index) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "package"); + } +- if (index < cpuinfo_packages_count) { +- return cpuinfo_packages + index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_packages_count) { + return NULL; + } ++ return &cpuinfo_packages[index]; + } + +-uint32_t cpuinfo_get_processors_count(void) { ++const struct cpuinfo_uarch_info* cpuinfo_get_uarch(uint32_t index) { + if (!cpuinfo_is_initialized) { ++ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "uarch"); ++ } ++ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 ++ if CPUINFO_UNLIKELY(index >= cpuinfo_uarchs_count) { ++ return NULL; ++ } ++ return &cpuinfo_uarchs[index]; ++ #else ++ if CPUINFO_UNLIKELY(index != 0) { ++ return NULL; ++ } ++ return &cpuinfo_global_uarch; ++ #endif ++} ++ ++uint32_t cpuinfo_get_processors_count(void) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "processors_count"); + } + return cpuinfo_processors_count; + } + + uint32_t cpuinfo_get_cores_count(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "cores_count"); + } + return cpuinfo_cores_count; + } + + uint32_t cpuinfo_get_clusters_count(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "clusters_count"); + } + return cpuinfo_clusters_count; + } + + uint32_t cpuinfo_get_packages_count(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "packages_count"); + } + return cpuinfo_packages_count; + } + +-const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_caches(void) { ++uint32_t cpuinfo_get_uarchs_count(void) { + if (!cpuinfo_is_initialized) { ++ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "uarchs_count"); ++ } ++ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 ++ return 
cpuinfo_uarchs_count; ++ #else ++ return 1; ++ #endif ++} ++ ++const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_caches(void) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1i_caches"); + } + return cpuinfo_cache[cpuinfo_cache_level_1i]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_caches(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1d_caches"); + } + return cpuinfo_cache[cpuinfo_cache_level_1d]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_caches(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l2_caches"); + } + return cpuinfo_cache[cpuinfo_cache_level_2]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l3_caches(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l3_caches"); + } + return cpuinfo_cache[cpuinfo_cache_level_3]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l4_caches(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l4_caches"); + } + return cpuinfo_cache[cpuinfo_cache_level_4]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_cache(uint32_t index) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1i_cache"); + } +- if (index < cpuinfo_cache_count[cpuinfo_cache_level_1i]) { +- return cpuinfo_cache[cpuinfo_cache_level_1i] + index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_1i]) { + return NULL; + } ++ return &cpuinfo_cache[cpuinfo_cache_level_1i][index]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_cache(uint32_t index) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1d_cache"); + } +- if (index < cpuinfo_cache_count[cpuinfo_cache_level_1d]) { +- return cpuinfo_cache[cpuinfo_cache_level_1d] + index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_1d]) { + return NULL; + } ++ return &cpuinfo_cache[cpuinfo_cache_level_1d][index]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_cache(uint32_t index) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l2_cache"); + } +- if (index < cpuinfo_cache_count[cpuinfo_cache_level_2]) { +- return cpuinfo_cache[cpuinfo_cache_level_2] + index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_2]) { + return NULL; + } ++ return &cpuinfo_cache[cpuinfo_cache_level_2][index]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l3_cache(uint32_t index) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l3_cache"); + } +- if (index < cpuinfo_cache_count[cpuinfo_cache_level_3]) { +- return cpuinfo_cache[cpuinfo_cache_level_3] 
+ index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_3]) { + return NULL; + } ++ return &cpuinfo_cache[cpuinfo_cache_level_3][index]; + } + + const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l4_cache(uint32_t index) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l4_cache"); + } +- if (index < cpuinfo_cache_count[cpuinfo_cache_level_4]) { +- return cpuinfo_cache[cpuinfo_cache_level_4] + index; +- } else { ++ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_4]) { + return NULL; + } ++ return &cpuinfo_cache[cpuinfo_cache_level_4][index]; + } + + uint32_t CPUINFO_ABI cpuinfo_get_l1i_caches_count(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1i_caches_count"); + } + return cpuinfo_cache_count[cpuinfo_cache_level_1i]; + } + + uint32_t CPUINFO_ABI cpuinfo_get_l1d_caches_count(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1d_caches_count"); + } + return cpuinfo_cache_count[cpuinfo_cache_level_1d]; + } + + uint32_t CPUINFO_ABI cpuinfo_get_l2_caches_count(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l2_caches_count"); + } + return cpuinfo_cache_count[cpuinfo_cache_level_2]; + } + + uint32_t CPUINFO_ABI cpuinfo_get_l3_caches_count(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l3_caches_count"); + } + return cpuinfo_cache_count[cpuinfo_cache_level_3]; + } + + uint32_t CPUINFO_ABI cpuinfo_get_l4_caches_count(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l4_caches_count"); + } + return cpuinfo_cache_count[cpuinfo_cache_level_4]; + } + + uint32_t CPUINFO_ABI cpuinfo_get_max_cache_size(void) { +- if (!cpuinfo_is_initialized) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { + cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "max_cache_size"); + } + return cpuinfo_max_cache_size; + } ++ ++const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_current_processor(void) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { ++ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_processor"); ++ } ++ #ifdef __linux__ ++ unsigned cpu; ++ if CPUINFO_UNLIKELY(syscall(__NR_getcpu, &cpu, NULL, NULL) != 0) { ++ return 0; ++ } ++ if CPUINFO_UNLIKELY((uint32_t) cpu >= cpuinfo_linux_cpu_max) { ++ return 0; ++ } ++ return cpuinfo_linux_cpu_to_processor_map[cpu]; ++ #else ++ return NULL; ++ #endif ++} ++ ++const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_current_core(void) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { ++ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_core"); ++ } ++ #ifdef __linux__ ++ unsigned cpu; ++ if CPUINFO_UNLIKELY(syscall(__NR_getcpu, &cpu, NULL, NULL) != 0) { ++ return 0; ++ } ++ if CPUINFO_UNLIKELY((uint32_t) cpu >= cpuinfo_linux_cpu_max) { ++ return 0; ++ } ++ return cpuinfo_linux_cpu_to_core_map[cpu]; ++ #else ++ return 
NULL; ++ #endif ++} ++ ++uint32_t CPUINFO_ABI cpuinfo_get_current_uarch_index(void) { ++ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { ++ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_uarch_index"); ++ } ++ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 ++ #ifdef __linux__ ++ if (cpuinfo_linux_cpu_to_uarch_index_map == NULL) { ++ /* Special case: avoid syscall on systems with only a single type of cores */ ++ return 0; ++ } ++ ++ /* General case */ ++ unsigned cpu; ++ if CPUINFO_UNLIKELY(syscall(__NR_getcpu, &cpu, NULL, NULL) != 0) { ++ return 0; ++ } ++ if CPUINFO_UNLIKELY((uint32_t) cpu >= cpuinfo_linux_cpu_max) { ++ return 0; ++ } ++ return cpuinfo_linux_cpu_to_uarch_index_map[cpu]; ++ #else ++ /* Fallback: pretend to be on the big core. */ ++ return 0; ++ #endif ++ #else ++ /* Only ARM/ARM64 processors may include cores of different types in the same package. */ ++ return 0; ++ #endif ++} +diff --git a/src/arm/cache.c b/src/arm/cache.c +index ccadeb4..c2bc7d2 100644 +--- a/src/arm/cache.c ++++ b/src/arm/cache.c +@@ -659,6 +659,7 @@ void cpuinfo_arm_decode_cache( + }; + } + break; ++ case cpuinfo_uarch_cortex_a55r0: + case cpuinfo_uarch_cortex_a55: + /* + * ARM Cortex-A55 Core Technical Reference Manual +diff --git a/src/arm/linux/api.h b/src/arm/linux/api.h +index 275d072..f99da66 100644 +--- a/src/arm/linux/api.h ++++ b/src/arm/linux/api.h +@@ -153,6 +153,7 @@ struct cpuinfo_arm_linux_processor { + uint32_t midr; + enum cpuinfo_vendor vendor; + enum cpuinfo_uarch uarch; ++ uint32_t uarch_index; + /** + * ID of the physical package which includes this logical processor. + * The value is parsed from /sys/devices/system/cpu/cpu/topology/physical_package_id +@@ -346,3 +347,6 @@ CPUINFO_INTERNAL uint32_t cpuinfo_arm_linux_detect_cluster_midr( + uint32_t max_processors, + uint32_t usable_processors, + struct cpuinfo_arm_linux_processor processors[restrict static max_processors]); ++ ++extern CPUINFO_INTERNAL const uint32_t* cpuinfo_linux_cpu_to_uarch_index_map; ++extern CPUINFO_INTERNAL uint32_t cpuinfo_linux_cpu_to_uarch_index_map_entries; +diff --git a/src/arm/linux/init.c b/src/arm/linux/init.c +index f0c432c..6272abf 100644 +--- a/src/arm/linux/init.c ++++ b/src/arm/linux/init.c +@@ -106,12 +106,14 @@ void cpuinfo_arm_linux_init(void) { + struct cpuinfo_processor* processors = NULL; + struct cpuinfo_core* cores = NULL; + struct cpuinfo_cluster* clusters = NULL; +- const struct cpuinfo_processor** linux_cpu_to_processor_map = NULL; +- const struct cpuinfo_core** linux_cpu_to_core_map = NULL; ++ struct cpuinfo_uarch_info* uarchs = NULL; + struct cpuinfo_cache* l1i = NULL; + struct cpuinfo_cache* l1d = NULL; + struct cpuinfo_cache* l2 = NULL; + struct cpuinfo_cache* l3 = NULL; ++ const struct cpuinfo_processor** linux_cpu_to_processor_map = NULL; ++ const struct cpuinfo_core** linux_cpu_to_core_map = NULL; ++ uint32_t* linux_cpu_to_uarch_index_map = NULL; + + const uint32_t max_processors_count = cpuinfo_linux_get_max_processors_count(); + cpuinfo_log_debug("system maximum processors count: %"PRIu32, max_processors_count); +@@ -400,6 +402,18 @@ void cpuinfo_arm_linux_init(void) { + } + } + ++ uint32_t uarchs_count = 0; ++ enum cpuinfo_uarch last_uarch; ++ for (uint32_t i = 0; i < arm_linux_processors_count; i++) { ++ if (bitmask_all(arm_linux_processors[i].flags, CPUINFO_LINUX_FLAG_VALID)) { ++ if (uarchs_count == 0 || arm_linux_processors[i].uarch != last_uarch) { ++ last_uarch = arm_linux_processors[i].uarch; ++ uarchs_count += 1; ++ } ++ 
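Another illustrative sketch, this time of the dispatch pattern that the cpuinfo_get_current_uarch_index() implementation above is meant to enable: a table with one entry per microarchitecture, indexed by the (hint-only) current uarch index. The scale_generic/scale_tuned variants and the table size are hypothetical placeholders; only the cpuinfo_* calls come from this patch and the public header:

    #include <stddef.h>
    #include <stdint.h>
    #include <cpuinfo.h>

    /* Two hypothetical variants of the same routine; a real library would tune
     * each one for a particular core type (e.g. big vs. little cores). */
    static void scale_generic(float* data, size_t size, float factor) {
        for (size_t i = 0; i < size; i++) data[i] *= factor;
    }
    static void scale_tuned(float* data, size_t size, float factor) {
        for (size_t i = 0; i < size; i++) data[i] *= factor; /* placeholder for a tuned loop */
    }

    typedef void (*scale_fn)(float*, size_t, float);

    int main(void) {
        if (!cpuinfo_initialize()) {
            return 1;
        }

        /* One table slot per microarchitecture; slot 0 must always hold a valid
         * kernel because cpuinfo_get_current_uarch_index() falls back to 0 when
         * identification is unsupported. */
        enum { MAX_UARCHS = 16 };
        scale_fn table[MAX_UARCHS];
        const uint32_t uarchs = cpuinfo_get_uarchs_count();
        for (uint32_t i = 0; i < uarchs && i < MAX_UARCHS; i++) {
            /* A real implementation would key this off cpuinfo_get_uarch(i)->uarch. */
            table[i] = (i == 0) ? scale_tuned : scale_generic;
        }

        float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        /* The index is only a hint: the thread may migrate between cores at any
         * moment, so every table entry must be functionally correct on every core. */
        uint32_t index = cpuinfo_get_current_uarch_index();
        if (index >= uarchs || index >= MAX_UARCHS) {
            index = 0;
        }
        table[index](data, 4, 2.0f);
        return 0;
    }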
arm_linux_processors[i].uarch_index = uarchs_count - 1; ++ } ++ } ++ + /* + * Assumptions: + * - No SMP (i.e. each core supports only one hardware thread). +@@ -432,6 +446,13 @@ void cpuinfo_arm_linux_init(void) { + goto cleanup; + } + ++ uarchs = calloc(uarchs_count, sizeof(struct cpuinfo_uarch_info)); ++ if (uarchs == NULL) { ++ cpuinfo_log_error("failed to allocate %zu bytes for descriptions of %"PRIu32" microarchitectures", ++ uarchs_count * sizeof(struct cpuinfo_uarch_info), uarchs_count); ++ goto cleanup; ++ } ++ + linux_cpu_to_processor_map = calloc(arm_linux_processors_count, sizeof(struct cpuinfo_processor*)); + if (linux_cpu_to_processor_map == NULL) { + cpuinfo_log_error("failed to allocate %zu bytes for %"PRIu32" logical processor mapping entries", +@@ -446,6 +467,15 @@ void cpuinfo_arm_linux_init(void) { + goto cleanup; + } + ++ if (uarchs_count > 1) { ++ linux_cpu_to_uarch_index_map = calloc(arm_linux_processors_count, sizeof(uint32_t)); ++ if (linux_cpu_to_uarch_index_map == NULL) { ++ cpuinfo_log_error("failed to allocate %zu bytes for %"PRIu32" uarch index mapping entries", ++ arm_linux_processors_count * sizeof(uint32_t), arm_linux_processors_count); ++ goto cleanup; ++ } ++ } ++ + l1i = calloc(valid_processors, sizeof(struct cpuinfo_cache)); + if (l1i == NULL) { + cpuinfo_log_error("failed to allocate %zu bytes for descriptions of %"PRIu32" L1I caches", +@@ -460,6 +490,22 @@ void cpuinfo_arm_linux_init(void) { + goto cleanup; + } + ++ uint32_t uarchs_index = 0; ++ for (uint32_t i = 0; i < arm_linux_processors_count; i++) { ++ if (bitmask_all(arm_linux_processors[i].flags, CPUINFO_LINUX_FLAG_VALID)) { ++ if (uarchs_index == 0 || arm_linux_processors[i].uarch != last_uarch) { ++ last_uarch = arm_linux_processors[i].uarch; ++ uarchs[uarchs_index] = (struct cpuinfo_uarch_info) { ++ .uarch = arm_linux_processors[i].uarch, ++ .midr = arm_linux_processors[i].midr, ++ }; ++ uarchs_index += 1; ++ } ++ uarchs[uarchs_index - 1].processor_count += 1; ++ uarchs[uarchs_index - 1].core_count += 1; ++ } ++ } ++ + uint32_t l2_count = 0, l3_count = 0, big_l3_size = 0, cluster_id = UINT32_MAX; + /* Indication whether L3 (if it exists) is shared between all cores */ + bool shared_l3 = true; +@@ -499,6 +545,11 @@ void cpuinfo_arm_linux_init(void) { + cores[i].midr = arm_linux_processors[i].midr; + linux_cpu_to_core_map[arm_linux_processors[i].system_processor_id] = &cores[i]; + ++ if (linux_cpu_to_uarch_index_map != NULL) { ++ linux_cpu_to_uarch_index_map[arm_linux_processors[i].system_processor_id] = ++ arm_linux_processors[i].uarch_index; ++ } ++ + struct cpuinfo_cache temp_l2 = { 0 }, temp_l3 = { 0 }; + cpuinfo_arm_decode_cache( + arm_linux_processors[i].uarch, +@@ -658,12 +709,11 @@ void cpuinfo_arm_linux_init(void) { + } + + /* Commit */ +- cpuinfo_linux_cpu_to_processor_map = linux_cpu_to_processor_map; +- cpuinfo_linux_cpu_to_core_map = linux_cpu_to_core_map; + cpuinfo_processors = processors; + cpuinfo_cores = cores; + cpuinfo_clusters = clusters; + cpuinfo_packages = &package; ++ cpuinfo_uarchs = uarchs; + cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; + cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; + cpuinfo_cache[cpuinfo_cache_level_2] = l2; +@@ -673,33 +723,42 @@ void cpuinfo_arm_linux_init(void) { + cpuinfo_cores_count = valid_processors; + cpuinfo_clusters_count = cluster_count; + cpuinfo_packages_count = 1; ++ cpuinfo_uarchs_count = uarchs_count; + cpuinfo_cache_count[cpuinfo_cache_level_1i] = valid_processors; + cpuinfo_cache_count[cpuinfo_cache_level_1d] = valid_processors; + 
cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; + cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; +- + cpuinfo_max_cache_size = cpuinfo_arm_compute_max_cache_size(&processors[0]); + ++ cpuinfo_linux_cpu_max = arm_linux_processors_count; ++ cpuinfo_linux_cpu_to_processor_map = linux_cpu_to_processor_map; ++ cpuinfo_linux_cpu_to_core_map = linux_cpu_to_core_map; ++ cpuinfo_linux_cpu_to_uarch_index_map = linux_cpu_to_uarch_index_map; ++ + __sync_synchronize(); + + cpuinfo_is_initialized = true; + +- linux_cpu_to_processor_map = NULL; +- linux_cpu_to_core_map = NULL; + processors = NULL; + cores = NULL; + clusters = NULL; ++ uarchs = NULL; + l1i = l1d = l2 = l3 = NULL; ++ linux_cpu_to_processor_map = NULL; ++ linux_cpu_to_core_map = NULL; ++ linux_cpu_to_uarch_index_map = NULL; + + cleanup: + free(arm_linux_processors); +- free(linux_cpu_to_processor_map); +- free(linux_cpu_to_core_map); + free(processors); + free(cores); + free(clusters); ++ free(uarchs); + free(l1i); + free(l1d); + free(l2); + free(l3); ++ free(linux_cpu_to_processor_map); ++ free(linux_cpu_to_core_map); ++ free(linux_cpu_to_uarch_index_map); + } +diff --git a/src/arm/mach/init.c b/src/arm/mach/init.c +index e64cc18..bd27259 100644 +--- a/src/arm/mach/init.c ++++ b/src/arm/mach/init.c +@@ -14,6 +14,16 @@ + #include + #include + ++/* Polyfill recent CPUFAMILY_ARM_* values for older SDKs */ ++#ifndef CPUFAMILY_ARM_MONSOON_MISTRAL ++ #define CPUFAMILY_ARM_MONSOON_MISTRAL 0xE81E7EF6 ++#endif ++#ifndef CPUFAMILY_ARM_VORTEX_TEMPEST ++ #define CPUFAMILY_ARM_VORTEX_TEMPEST 0x07D34B9F ++#endif ++#ifndef CPUFAMILY_ARM_LIGHTNING_THUNDER ++ #define CPUFAMILY_ARM_LIGHTNING_THUNDER 0x462504D2 ++#endif + + struct cpuinfo_arm_isa cpuinfo_isa = { + #if CPUINFO_ARCH_ARM +@@ -82,37 +92,34 @@ static enum cpuinfo_uarch decode_uarch(uint32_t cpu_family, uint32_t cpu_subtype + return cpuinfo_uarch_twister; + case CPUFAMILY_ARM_HURRICANE: + return cpuinfo_uarch_hurricane; +-#ifdef CPUFAMILY_ARM_MONSOON_MISTRAL + case CPUFAMILY_ARM_MONSOON_MISTRAL: +-#else +- case 0xe81e7ef6: +- /* Hard-coded value for older SDKs which do not define CPUFAMILY_ARM_MONSOON_MISTRAL */ +-#endif + /* 2x Monsoon + 4x Mistral cores */ + return core_index < 2 ? cpuinfo_uarch_monsoon : cpuinfo_uarch_mistral; +-#ifdef CPUFAMILY_ARM_VORTEX_TEMPEST + case CPUFAMILY_ARM_VORTEX_TEMPEST: +-#else +- case 0x07d34b9f: +- /* Hard-coded value for older SDKs which do not define CPUFAMILY_ARM_VORTEX_TEMPEST */ +-#endif + /* Hexa-core: 2x Vortex + 4x Tempest; Octa-core: 4x Cortex + 4x Tempest */ + return core_index + 4 < core_count ? cpuinfo_uarch_vortex : cpuinfo_uarch_tempest; ++ case CPUFAMILY_ARM_LIGHTNING_THUNDER: ++ /* Hexa-core: 2x Lightning + 4x Thunder; Octa-core (presumed): 4x Lightning + 4x Thunder */ ++ return core_index + 4 < core_count ? 
cpuinfo_uarch_lightning : cpuinfo_uarch_thunder; + default: + /* Use hw.cpusubtype for detection */ + break; + } + +- switch (cpu_subtype) { +- case CPU_SUBTYPE_ARM_V7: +- return cpuinfo_uarch_cortex_a8; +- case CPU_SUBTYPE_ARM_V7F: +- return cpuinfo_uarch_cortex_a9; +- case CPU_SUBTYPE_ARM_V7K: +- return cpuinfo_uarch_cortex_a7; +- default: +- return cpuinfo_uarch_unknown; +- } ++ #if CPUINFO_ARCH_ARM ++ switch (cpu_subtype) { ++ case CPU_SUBTYPE_ARM_V7: ++ return cpuinfo_uarch_cortex_a8; ++ case CPU_SUBTYPE_ARM_V7F: ++ return cpuinfo_uarch_cortex_a9; ++ case CPU_SUBTYPE_ARM_V7K: ++ return cpuinfo_uarch_cortex_a7; ++ default: ++ return cpuinfo_uarch_unknown; ++ } ++ #else ++ return cpuinfo_uarch_unknown; ++ #endif + } + + static void decode_package_name(char* package_name) { +@@ -244,6 +251,7 @@ void cpuinfo_arm_mach_init(void) { + struct cpuinfo_core* cores = NULL; + struct cpuinfo_cluster* clusters = NULL; + struct cpuinfo_package* packages = NULL; ++ struct cpuinfo_uarch_info* uarchs = NULL; + struct cpuinfo_cache* l1i = NULL; + struct cpuinfo_cache* l1d = NULL; + struct cpuinfo_cache* l2 = NULL; +@@ -330,21 +338,12 @@ void cpuinfo_arm_mach_init(void) { + * Thus, we whitelist CPUs known to support these instructions. + */ + switch (cpu_family) { +-#ifdef CPUFAMILY_ARM_MONSOON_MISTRAL + case CPUFAMILY_ARM_MONSOON_MISTRAL: +-#else +- case 0xe81e7ef6: +- /* Hard-coded value for older SDKs which do not define CPUFAMILY_ARM_MONSOON_MISTRAL */ +-#endif +-#ifdef CPUFAMILY_ARM_VORTEX_TEMPEST + case CPUFAMILY_ARM_VORTEX_TEMPEST: +-#else +- case 0x07d34b9f: +- /* Hard-coded value for older SDKs which do not define CPUFAMILY_ARM_VORTEX_TEMPEST */ +-#endif +-#if CPUINFO_ARCH_ARM64 +- cpuinfo_isa.atomics = true; +-#endif ++ case CPUFAMILY_ARM_LIGHTNING_THUNDER: ++ #if CPUINFO_ARCH_ARM64 ++ cpuinfo_isa.atomics = true; ++ #endif + cpuinfo_isa.fp16arith = true; + } + +@@ -379,10 +378,22 @@ void cpuinfo_arm_mach_init(void) { + num_clusters * sizeof(struct cpuinfo_cluster), num_clusters); + goto cleanup; + } ++ uarchs = calloc(num_clusters, sizeof(struct cpuinfo_uarch_info)); ++ if (uarchs == NULL) { ++ cpuinfo_log_error( ++ "failed to allocate %zu bytes for descriptions of %"PRIu32" uarchs", ++ num_clusters * sizeof(enum cpuinfo_uarch), num_clusters); ++ goto cleanup; ++ } + uint32_t cluster_idx = UINT32_MAX; + for (uint32_t i = 0; i < mach_topology.cores; i++) { + if (i == 0 || cores[i].uarch != cores[i - 1].uarch) { + cluster_idx++; ++ uarchs[cluster_idx] = (struct cpuinfo_uarch_info) { ++ .uarch = cores[i].uarch, ++ .processor_count = 1, ++ .core_count = 1, ++ }; + clusters[cluster_idx] = (struct cpuinfo_cluster) { + .processor_start = i * threads_per_core, + .processor_count = 1, +@@ -394,6 +405,8 @@ void cpuinfo_arm_mach_init(void) { + .uarch = cores[i].uarch, + }; + } else { ++ uarchs[cluster_idx].processor_count++; ++ uarchs[cluster_idx].core_count++; + clusters[cluster_idx].processor_count++; + clusters[cluster_idx].core_count++; + } +@@ -542,26 +555,25 @@ void cpuinfo_arm_mach_init(void) { + } + + /* Commit changes */ +- cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; +- cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; +- cpuinfo_cache[cpuinfo_cache_level_2] = l2; +- cpuinfo_cache[cpuinfo_cache_level_3] = l3; +- + cpuinfo_processors = processors; + cpuinfo_cores = cores; + cpuinfo_clusters = clusters; + cpuinfo_packages = packages; +- +- cpuinfo_cache_count[cpuinfo_cache_level_1i] = l1_count; +- cpuinfo_cache_count[cpuinfo_cache_level_1d] = l1_count; +- cpuinfo_cache_count[cpuinfo_cache_level_2] = 
l2_count; +- cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; ++ cpuinfo_uarchs = uarchs; ++ cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; ++ cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; ++ cpuinfo_cache[cpuinfo_cache_level_2] = l2; ++ cpuinfo_cache[cpuinfo_cache_level_3] = l3; + + cpuinfo_processors_count = mach_topology.threads; + cpuinfo_cores_count = mach_topology.cores; + cpuinfo_clusters_count = num_clusters; + cpuinfo_packages_count = mach_topology.packages; +- ++ cpuinfo_uarchs_count = num_clusters; ++ cpuinfo_cache_count[cpuinfo_cache_level_1i] = l1_count; ++ cpuinfo_cache_count[cpuinfo_cache_level_1d] = l1_count; ++ cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; ++ cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; + cpuinfo_max_cache_size = cpuinfo_compute_max_cache_size(&processors[0]); + + __sync_synchronize(); +@@ -572,6 +584,7 @@ void cpuinfo_arm_mach_init(void) { + cores = NULL; + clusters = NULL; + packages = NULL; ++ uarchs = NULL; + l1i = l1d = l2 = l3 = NULL; + + cleanup: +@@ -579,6 +592,7 @@ cleanup: + free(cores); + free(clusters); + free(packages); ++ free(uarchs); + free(l1i); + free(l1d); + free(l2); +diff --git a/src/arm/uarch.c b/src/arm/uarch.c +index a38250a..2aef9e7 100644 +--- a/src/arm/uarch.c ++++ b/src/arm/uarch.c +@@ -58,7 +58,9 @@ void cpuinfo_arm_decode_vendor_uarch( + *uarch = cpuinfo_uarch_cortex_a35; + break; + case 0xD05: +- *uarch = cpuinfo_uarch_cortex_a55; ++ // Note: use Variant, not Revision, field ++ *uarch = (midr & CPUINFO_ARM_MIDR_VARIANT_MASK) == 0 ? ++ cpuinfo_uarch_cortex_a55r0 : cpuinfo_uarch_cortex_a55; + break; + case 0xD06: + *uarch = cpuinfo_uarch_cortex_a65; +@@ -257,9 +259,9 @@ void cpuinfo_arm_decode_vendor_uarch( + *vendor = cpuinfo_vendor_arm; + *uarch = cpuinfo_uarch_cortex_a75; + break; +- case 0x803: /* Low-power Kryo 385 "Silver" -> Cortex-A55 */ ++ case 0x803: /* Low-power Kryo 385 "Silver" -> Cortex-A55r0 */ + *vendor = cpuinfo_vendor_arm; +- *uarch = cpuinfo_uarch_cortex_a55; ++ *uarch = cpuinfo_uarch_cortex_a55r0; + break; + case 0x804: /* High-performance Kryo 485 "Gold" / "Gold Prime" -> Cortex-A76 */ + *vendor = cpuinfo_vendor_arm; +diff --git a/src/cpuinfo/common.h b/src/cpuinfo/common.h +index 6ba746e..b2b404d 100644 +--- a/src/cpuinfo/common.h ++++ b/src/cpuinfo/common.h +@@ -12,29 +12,29 @@ + #define CPUINFO_COUNT_OF(array) (sizeof(array) / sizeof(0[array])) + + #if defined(__GNUC__) +- #define CPUINFO_LIKELY(condition) (__builtin_expect(!!(condition), 1)) +- #define CPUINFO_UNLIKELY(condition) (__builtin_expect(!!(condition), 0)) ++ #define CPUINFO_LIKELY(condition) (__builtin_expect(!!(condition), 1)) ++ #define CPUINFO_UNLIKELY(condition) (__builtin_expect(!!(condition), 0)) + #else +- #define CPUINFO_LIKELY(condition) (!!(condition)) +- #define CPUINFO_UNLIKELY(condition) (!!(condition)) ++ #define CPUINFO_LIKELY(condition) (!!(condition)) ++ #define CPUINFO_UNLIKELY(condition) (!!(condition)) + #endif + + #ifndef CPUINFO_INTERNAL +- #if defined(__ELF__) +- #define CPUINFO_INTERNAL __attribute__((__visibility__("internal"))) +- #elif defined(__MACH__) +- #define CPUINFO_INTERNAL __attribute__((__visibility__("hidden"))) +- #else +- #define CPUINFO_INTERNAL +- #endif ++ #if defined(__ELF__) ++ #define CPUINFO_INTERNAL __attribute__((__visibility__("internal"))) ++ #elif defined(__MACH__) ++ #define CPUINFO_INTERNAL __attribute__((__visibility__("hidden"))) ++ #else ++ #define CPUINFO_INTERNAL ++ #endif + #endif + + #ifndef CPUINFO_PRIVATE +- #if defined(__ELF__) +- #define CPUINFO_PRIVATE 
__attribute__((__visibility__("hidden"))) +- #elif defined(__MACH__) +- #define CPUINFO_PRIVATE __attribute__((__visibility__("hidden"))) +- #else +- #define CPUINFO_PRIVATE +- #endif ++ #if defined(__ELF__) ++ #define CPUINFO_PRIVATE __attribute__((__visibility__("hidden"))) ++ #elif defined(__MACH__) ++ #define CPUINFO_PRIVATE __attribute__((__visibility__("hidden"))) ++ #else ++ #define CPUINFO_PRIVATE ++ #endif + #endif +diff --git a/src/cpuinfo/internal-api.h b/src/cpuinfo/internal-api.h +index f12c48d..c6eed0b 100644 +--- a/src/cpuinfo/internal-api.h ++++ b/src/cpuinfo/internal-api.h +@@ -21,11 +21,13 @@ enum cpuinfo_cache_level { + }; + + extern CPUINFO_INTERNAL bool cpuinfo_is_initialized; ++ + extern CPUINFO_INTERNAL struct cpuinfo_processor* cpuinfo_processors; + extern CPUINFO_INTERNAL struct cpuinfo_core* cpuinfo_cores; + extern CPUINFO_INTERNAL struct cpuinfo_cluster* cpuinfo_clusters; + extern CPUINFO_INTERNAL struct cpuinfo_package* cpuinfo_packages; + extern CPUINFO_INTERNAL struct cpuinfo_cache* cpuinfo_cache[cpuinfo_cache_level_max]; ++ + extern CPUINFO_INTERNAL uint32_t cpuinfo_processors_count; + extern CPUINFO_INTERNAL uint32_t cpuinfo_cores_count; + extern CPUINFO_INTERNAL uint32_t cpuinfo_clusters_count; +@@ -33,6 +35,19 @@ extern CPUINFO_INTERNAL uint32_t cpuinfo_packages_count; + extern CPUINFO_INTERNAL uint32_t cpuinfo_cache_count[cpuinfo_cache_level_max]; + extern CPUINFO_INTERNAL uint32_t cpuinfo_max_cache_size; + ++#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 ++ extern CPUINFO_INTERNAL struct cpuinfo_uarch_info* cpuinfo_uarchs; ++ extern CPUINFO_INTERNAL uint32_t cpuinfo_uarchs_count; ++#else ++ extern CPUINFO_INTERNAL struct cpuinfo_uarch_info cpuinfo_global_uarch; ++#endif ++ ++#ifdef __linux__ ++ extern CPUINFO_INTERNAL uint32_t cpuinfo_linux_cpu_max; ++ extern CPUINFO_INTERNAL const struct cpuinfo_processor** cpuinfo_linux_cpu_to_processor_map; ++ extern CPUINFO_INTERNAL const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map; ++#endif ++ + CPUINFO_PRIVATE void cpuinfo_x86_mach_init(void); + CPUINFO_PRIVATE void cpuinfo_x86_linux_init(void); + #ifdef _WIN32 +diff --git a/src/linux/current.c b/src/linux/current.c +deleted file mode 100644 +index 472a4c9..0000000 +--- a/src/linux/current.c ++++ /dev/null +@@ -1,41 +0,0 @@ +-#include +-#include +-#include +-#include +-#include +- +-#include +- +-#include +-#include +-#include +-#include +- +- +-const struct cpuinfo_processor** cpuinfo_linux_cpu_to_processor_map = NULL; +-const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map = NULL; +- +- +-const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_current_processor(void) { +- if (!cpuinfo_is_initialized) { +- cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_processor"); +- } +- const int cpu = sched_getcpu(); +- if (cpu >= 0) { +- return cpuinfo_linux_cpu_to_processor_map[cpu]; +- } else { +- return &cpuinfo_processors[0]; +- } +-} +- +-const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_current_core(void) { +- if (!cpuinfo_is_initialized) { +- cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_core"); +- } +- const int cpu = sched_getcpu(); +- if (cpu >= 0) { +- return cpuinfo_linux_cpu_to_core_map[cpu]; +- } else { +- return &cpuinfo_cores[0]; +- } +-} +diff --git a/src/x86/api.h b/src/x86/api.h +index 5f5e76d..213c2d8 100644 +--- a/src/x86/api.h ++++ b/src/x86/api.h +@@ -93,7 +93,6 @@ CPUINFO_INTERNAL struct cpuinfo_x86_isa cpuinfo_x86_detect_isa( + const struct cpuid_regs basic_info, const 
struct cpuid_regs extended_info, + uint32_t max_base_index, uint32_t max_extended_index, + enum cpuinfo_vendor vendor, enum cpuinfo_uarch uarch); +-CPUINFO_INTERNAL struct cpuinfo_x86_isa cpuinfo_x86_nacl_detect_isa(void); + + CPUINFO_INTERNAL void cpuinfo_x86_detect_topology( + uint32_t max_base_index, +diff --git a/src/x86/cache/init.c b/src/x86/cache/init.c +index d581016..dd1f1ea 100644 +--- a/src/x86/cache/init.c ++++ b/src/x86/cache/init.c +@@ -65,7 +65,7 @@ iterate_descriptors: + } + } + +- if (vendor != cpuinfo_vendor_amd && max_base_index >= 4) { ++ if (vendor != cpuinfo_vendor_amd && vendor != cpuinfo_vendor_hygon && max_base_index >= 4) { + struct cpuid_regs leaf4; + uint32_t input_ecx = 0; + uint32_t package_cores_max = 0; +diff --git a/src/x86/cpuid.h b/src/x86/cpuid.h +index 829ec21..9e9e013 100644 +--- a/src/x86/cpuid.h ++++ b/src/x86/cpuid.h +@@ -67,18 +67,13 @@ + } + #endif + +-/* +- * This instruction may be not supported by Native Client validator, +- * make sure it doesn't appear in the binary +- */ +-#ifndef __native_client__ +- static inline uint64_t xgetbv(uint32_t ext_ctrl_reg) { +- #ifdef _MSC_VER +- return (uint64_t)_xgetbv((unsigned int)ext_ctrl_reg); +- #else +- uint32_t lo, hi; +- __asm__(".byte 0x0F, 0x01, 0xD0" : "=a" (lo), "=d" (hi) : "c" (ext_ctrl_reg)); +- return ((uint64_t) hi << 32) | (uint64_t) lo; +- #endif +- } +-#endif ++static inline uint64_t xgetbv(uint32_t ext_ctrl_reg) { ++ #ifdef _MSC_VER ++ return (uint64_t)_xgetbv((unsigned int)ext_ctrl_reg); ++ #else ++ uint32_t lo, hi; ++ __asm__(".byte 0x0F, 0x01, 0xD0" : "=a" (lo), "=d" (hi) : "c" (ext_ctrl_reg)); ++ return ((uint64_t) hi << 32) | (uint64_t) lo; ++ #endif ++} ++ +diff --git a/src/x86/init.c b/src/x86/init.c +index d736578..244359c 100644 +--- a/src/x86/init.c ++++ b/src/x86/init.c +@@ -61,12 +61,8 @@ void cpuinfo_x86_init_processor(struct cpuinfo_x86_processor* processor) { + + cpuinfo_x86_detect_topology(max_base_index, max_extended_index, leaf1, &processor->topology); + +- #ifdef __native_client__ +- cpuinfo_isa = cpuinfo_x86_nacl_detect_isa(); +- #else +- cpuinfo_isa = cpuinfo_x86_detect_isa(leaf1, leaf0x80000001, +- max_base_index, max_extended_index, vendor, uarch); +- #endif ++ cpuinfo_isa = cpuinfo_x86_detect_isa(leaf1, leaf0x80000001, ++ max_base_index, max_extended_index, vendor, uarch); + } + if (max_extended_index >= UINT32_C(0x80000004)) { + struct cpuid_regs brand_string[3]; +diff --git a/src/x86/isa.c b/src/x86/isa.c +index d27dbca..f2e5a28 100644 +--- a/src/x86/isa.c ++++ b/src/x86/isa.c +@@ -244,6 +244,7 @@ struct cpuinfo_x86_isa cpuinfo_x86_detect_isa( + */ + break; + case cpuinfo_vendor_amd: ++ case cpuinfo_vendor_hygon: + isa.prefetch = !!((extended_info.ecx & UINT32_C(0x00000100)) | (extended_info.edx & UINT32_C(0xE0000000))); + break; + default: +@@ -265,6 +266,7 @@ struct cpuinfo_x86_isa cpuinfo_x86_detect_isa( + */ + switch (vendor) { + case cpuinfo_vendor_amd: ++ case cpuinfo_vendor_hygon: + isa.prefetchw = !!((extended_info.ecx & UINT32_C(0x00000100)) | (extended_info.edx & UINT32_C(0xE0000000))); + break; + default: +diff --git a/src/x86/linux/init.c b/src/x86/linux/init.c +index c096336..f565789 100644 +--- a/src/x86/linux/init.c ++++ b/src/x86/linux/init.c +@@ -569,9 +569,6 @@ void cpuinfo_x86_linux_init(void) { + } + + /* Commit changes */ +- cpuinfo_linux_cpu_to_processor_map = linux_cpu_to_processor_map; +- cpuinfo_linux_cpu_to_core_map = linux_cpu_to_core_map; +- + cpuinfo_processors = processors; + cpuinfo_cores = cores; + cpuinfo_clusters = clusters; +@@ 
-591,24 +588,32 @@ void cpuinfo_x86_linux_init(void) { + cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; + cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; + cpuinfo_cache_count[cpuinfo_cache_level_4] = l4_count; +- + cpuinfo_max_cache_size = cpuinfo_compute_max_cache_size(&processors[0]); + ++ cpuinfo_global_uarch = (struct cpuinfo_uarch_info) { ++ .uarch = x86_processor.uarch, ++ .cpuid = x86_processor.cpuid, ++ .processor_count = processors_count, ++ .core_count = cores_count, ++ }; ++ ++ cpuinfo_linux_cpu_max = x86_linux_processors_count; ++ cpuinfo_linux_cpu_to_processor_map = linux_cpu_to_processor_map; ++ cpuinfo_linux_cpu_to_core_map = linux_cpu_to_core_map; ++ + __sync_synchronize(); + + cpuinfo_is_initialized = true; + +- linux_cpu_to_processor_map = NULL; +- linux_cpu_to_core_map = NULL; + processors = NULL; + cores = NULL; + clusters = NULL; + packages = NULL; + l1i = l1d = l2 = l3 = l4 = NULL; ++ linux_cpu_to_processor_map = NULL; ++ linux_cpu_to_core_map = NULL; + + cleanup: +- free(linux_cpu_to_processor_map); +- free(linux_cpu_to_core_map); + free(x86_linux_processors); + free(processors); + free(cores); +@@ -619,4 +624,6 @@ cleanup: + free(l2); + free(l3); + free(l4); ++ free(linux_cpu_to_processor_map); ++ free(linux_cpu_to_core_map); + } +diff --git a/src/x86/mach/init.c b/src/x86/mach/init.c +index ae2be33..b44d3ad 100644 +--- a/src/x86/mach/init.c ++++ b/src/x86/mach/init.c +@@ -305,30 +305,34 @@ void cpuinfo_x86_mach_init(void) { + } + + /* Commit changes */ ++ cpuinfo_processors = processors; ++ cpuinfo_cores = cores; ++ cpuinfo_clusters = clusters; ++ cpuinfo_packages = packages; + cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; + cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; + cpuinfo_cache[cpuinfo_cache_level_2] = l2; + cpuinfo_cache[cpuinfo_cache_level_3] = l3; + cpuinfo_cache[cpuinfo_cache_level_4] = l4; + +- cpuinfo_processors = processors; +- cpuinfo_cores = cores; +- cpuinfo_clusters = clusters; +- cpuinfo_packages = packages; +- ++ cpuinfo_processors_count = mach_topology.threads; ++ cpuinfo_cores_count = mach_topology.cores; ++ cpuinfo_clusters_count = mach_topology.packages; ++ cpuinfo_packages_count = mach_topology.packages; + cpuinfo_cache_count[cpuinfo_cache_level_1i] = l1_count; + cpuinfo_cache_count[cpuinfo_cache_level_1d] = l1_count; + cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; + cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; + cpuinfo_cache_count[cpuinfo_cache_level_4] = l4_count; +- +- cpuinfo_processors_count = mach_topology.threads; +- cpuinfo_cores_count = mach_topology.cores; +- cpuinfo_clusters_count = mach_topology.packages; +- cpuinfo_packages_count = mach_topology.packages; +- + cpuinfo_max_cache_size = cpuinfo_compute_max_cache_size(&processors[0]); + ++ cpuinfo_global_uarch = (struct cpuinfo_uarch_info) { ++ .uarch = x86_processor.uarch, ++ .cpuid = x86_processor.cpuid, ++ .processor_count = mach_topology.threads, ++ .core_count = mach_topology.cores, ++ }; ++ + __sync_synchronize(); + + cpuinfo_is_initialized = true; +diff --git a/src/x86/nacl/isa.c b/src/x86/nacl/isa.c +deleted file mode 100644 +index 662be33..0000000 +--- a/src/x86/nacl/isa.c ++++ /dev/null +@@ -1,306 +0,0 @@ +-#include +-#include +-#include +- +-#include +- +-#define NACL_CODE_BUNDLE_SIZE 32 +-#include +-#include +- +-static const uint8_t cmpxchg16b_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* MOV edi, edi */ +- 0x89, 0xFF, +- /* CMPXCHG16B [r15 + rdi * 1] */ +- 0x49, 0x0F, 0xC7, 0x0C, 0x3F, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 
0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t lzcnt_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* LZCNT eax, ecx */ +- 0xF3, 0x0F, 0xBD, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t popcnt_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* POPCNT eax, ecx */ +- 0xF3, 0x0F, 0xB8, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t movbe_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* MOV ecx, ecx */ +- 0x89, 0xC9, +- /* MOVBE eax, [r15 + rcx * 1] */ +- 0x41, 0x0F, 0x38, 0xF0, 0x04, 0x0F, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t bmi_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* ANDN eax, ecx, edx */ +- 0xC4, 0xE2, 0x70, 0xF2, 0xC2, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t tbm_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* BLCS eax, ecx */ +- 0x8F, 0xE9, 0x78, 0x01, 0xD9, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t three_d_now_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* PFADD mm0, mm1 */ +- 0x0F, 0x0F, 0xC1, 0x9E, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t three_d_now_plus_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* PFNACC mm0, mm1 */ +- 0x0F, 0x0F, 0xC1, 0x8A, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t sse3_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* HADDPS xmm0, xmm1 */ +- 0xF2, 0x0F, 0x7C, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t ssse3_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* PSHUFB xmm0, xmm1 */ +- 0x66, 0x0F, 0x38, 0x00, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t sse4_1_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* PMULLD xmm0, xmm1 */ +- 0x66, 0x0F, 0x38, 0x40, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t 
sse4_2_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* PCMPGTQ xmm0, xmm1 */ +- 0x66, 0x0F, 0x38, 0x37, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t sse4a_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* EXTRQ xmm0, xmm1 */ +- 0x66, 0x0F, 0x79, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t aes_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* AESENC xmm0, xmm1 */ +- 0x66, 0x0F, 0x38, 0xDC, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t pclmulqdq_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* PCLMULQDQ xmm0, xmm1, 0 */ +- 0x66, 0x0F, 0x3A, 0x44, 0xC1, 0x00, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t avx_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* VPERMILPS ymm0, ymm1, 0xAA */ +- 0xC4, 0xE3, 0x7D, 0x04, 0xC1, 0xAA, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t fma3_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* VFMADDSUB213PS ymm0, ymm1, ymm2 */ +- 0xC4, 0xE2, 0x75, 0xA6, 0xC2, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t fma4_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* VFMADDPS ymm0, ymm1, ymm2, ymm3 */ +- 0xC4, 0xE3, 0xF5, 0x68, 0xC3, 0x20, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t xop_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* VPHADDBQ xmm0, xmm1 */ +- 0x8F, 0xE9, 0x78, 0xC3, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t f16c_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* VCVTPH2PS ymm0, xmm1 */ +- 0xC4, 0xE2, 0x7D, 0x13, 0xC1, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +-static const uint8_t avx2_bundle[NACL_CODE_BUNDLE_SIZE] = { +- /* VPERMPS ymm0, ymm1, ymm2 */ +- 0xC4, 0xE2, 0x75, 0x16, 0xC2, +- /* Fill remainder with HLTs */ +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, +-}; +- +- +-struct cpuinfo_x86_isa cpuinfo_x86_nacl_detect_isa(void) { +- /* +- * Under Native Client sandbox we can't just ask the CPU: +- * - First, some instructions (XGETBV) necessary to query AVX 
support are not white-listed in the validator. +- * - Secondly, even if CPU supports some instruction, but validator doesn't know about it (e.g. due a bug in the +- * ISA detection in the validator), all instructions from the "unsupported" ISA extensions will be replaced by +- * HLTs when the module is loaded. +- * Thus, instead of quering the CPU about supported ISA extensions, we query the validator: we pass bundles with +- * instructions from ISA extensions to dynamic code generation APIs, and test if they are accepted. +- */ +- +- struct cpuinfo_x86_isa isa = { 0 }; +- +- struct nacl_irt_code_data_alloc nacl_irt_code_data_alloc = { 0 }; +- struct nacl_irt_dyncode nacl_irt_dyncode = { 0 }; +- if (sizeof(nacl_irt_code_data_alloc) != nacl_interface_query(NACL_IRT_CODE_DATA_ALLOC_v0_1, +- &nacl_irt_code_data_alloc, +- sizeof(nacl_irt_code_data_alloc))) +- { +- goto finish; +- } +- +- if (sizeof(nacl_irt_dyncode) != nacl_interface_query(NACL_IRT_DYNCODE_v0_1, +- &nacl_irt_dyncode, +- sizeof(nacl_irt_dyncode))) +- { +- goto finish; +- } +- +- const size_t allocation_size = 65536; +- uintptr_t code_segment = 0; +- if (0 != nacl_irt_code_data_alloc.allocate_code_data(0, allocation_size, 0, 0, &code_segment)) +- { +- goto finish; +- } +- +- isa.cmpxchg16b = !nacl_irt_dyncode.dyncode_create((void*) code_segment, cmpxchg16b_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.lzcnt = !nacl_irt_dyncode.dyncode_create((void*) code_segment, lzcnt_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.popcnt = !nacl_irt_dyncode.dyncode_create((void*) code_segment, popcnt_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.movbe = !nacl_irt_dyncode.dyncode_create((void*) code_segment, movbe_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.bmi = !nacl_irt_dyncode.dyncode_create((void*) code_segment, bmi_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.tbm = !nacl_irt_dyncode.dyncode_create((void*) code_segment, tbm_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.three_d_now = !nacl_irt_dyncode.dyncode_create((void*) code_segment, three_d_now_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.three_d_now_plus = +- !nacl_irt_dyncode.dyncode_create((void*) code_segment, three_d_now_plus_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.sse3 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, sse3_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.ssse3 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, ssse3_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.sse4_1 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, sse4_1_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.sse4_2 = 
!nacl_irt_dyncode.dyncode_create((void*) code_segment, sse4_2_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.sse4a = !nacl_irt_dyncode.dyncode_create((void*) code_segment, sse4a_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.aes = !nacl_irt_dyncode.dyncode_create((void*) code_segment, aes_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.pclmulqdq = !nacl_irt_dyncode.dyncode_create((void*) code_segment, pclmulqdq_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.avx = !nacl_irt_dyncode.dyncode_create((void*) code_segment, avx_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.fma3 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, fma3_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.fma4 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, fma4_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.xop = !nacl_irt_dyncode.dyncode_create((void*) code_segment, xop_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.f16c = !nacl_irt_dyncode.dyncode_create((void*) code_segment, f16c_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- code_segment += NACL_CODE_BUNDLE_SIZE; +- +- isa.avx2 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, avx2_bundle, NACL_CODE_BUNDLE_SIZE) && +- (*((const uint8_t*) code_segment) != 0xF4); +- +-finish: +- return isa; +-} +diff --git a/src/x86/name.c b/src/x86/name.c +index 708be1d..e0d5a5b 100644 +--- a/src/x86/name.c ++++ b/src/x86/name.c +@@ -671,6 +671,7 @@ static const char* vendor_string_map[] = { + [cpuinfo_vendor_intel] = "Intel", + [cpuinfo_vendor_amd] = "AMD", + [cpuinfo_vendor_via] = "VIA", ++ [cpuinfo_vendor_hygon] = "Hygon", + [cpuinfo_vendor_rdc] = "RDC", + [cpuinfo_vendor_dmp] = "DM&P", + [cpuinfo_vendor_transmeta] = "Transmeta", +diff --git a/src/x86/uarch.c b/src/x86/uarch.c +index ba72d8a..ecaa762 100644 +--- a/src/x86/uarch.c ++++ b/src/x86/uarch.c +@@ -79,6 +79,8 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( + case 0x5E: // Sky Lake Client DT/H/S + case 0x8E: // Kaby/Whiskey/Amber/Comet Lake Y/U + case 0x9E: // Kaby/Coffee Lake DT/H/S ++ case 0xA5: // Comet Lake H/S ++ case 0xA6: // Comet Lake U/Y + return cpuinfo_uarch_sky_lake; + case 0x66: // Cannon Lake (Core i3-8121U) + return cpuinfo_uarch_palm_cove; +@@ -94,7 +96,7 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( + return cpuinfo_uarch_bonnell; + case 0x27: // Medfield + case 0x35: // Cloverview +- case 0x36: // Cedarview, Centerton ++ case 0x36: // Cedarview, Centerton + return cpuinfo_uarch_saltwell; + case 0x37: // Bay Trail + case 0x4A: // Merrifield +@@ -110,6 +112,7 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( + return cpuinfo_uarch_goldmont; + case 0x7A: // Gemini Lake + return cpuinfo_uarch_goldmont_plus; ++ + /* Knights-series cores */ + case 0x57: + return cpuinfo_uarch_knights_landing; +@@ -173,7 +176,7 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( + case 0x38: // 
Godavari + case 0x30: // Kaveri + return cpuinfo_uarch_steamroller; +- case 0x60: // Carrizo ++ case 0x60: // Carrizo + case 0x65: // Bristol Ridge + case 0x70: // Stoney Ridge + return cpuinfo_uarch_excavator; +@@ -201,14 +204,22 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( + switch (model_info->model) { + case 0x01: // 14 nm Naples, Whitehaven, Summit Ridge, Snowy Owl + case 0x08: // 12 nm Pinnacle Ridge +- case 0x11: // 14 nm Raven Ridge ++ case 0x11: // 14 nm Raven Ridge, Great Horned Owl + case 0x18: // 12 nm Picasso + return cpuinfo_uarch_zen; ++ case 0x31: // Rome, Castle Peak ++ case 0x60: // Renoir + case 0x71: // Matisse + return cpuinfo_uarch_zen2; + } + } + break; ++ case cpuinfo_vendor_hygon: ++ switch (model_info->family) { ++ case 0x00: ++ return cpuinfo_uarch_dhyana; ++ } ++ break; + default: + break; + } +diff --git a/src/x86/vendor.c b/src/x86/vendor.c +index 3f3c753..2bba90d 100644 +--- a/src/x86/vendor.c ++++ b/src/x86/vendor.c +@@ -26,6 +26,11 @@ + #define auls UINT32_C(0x736C7561) + #define VIA UINT32_C(0x20414956) + ++/* Hygon vendor string: "HygonGenuine" */ ++#define Hygo UINT32_C(0x6F677948) ++#define nGen UINT32_C(0x6E65476E) ++#define uine UINT32_C(0x656E6975) ++ + /* Transmeta vendor strings: "GenuineTMx86", "TransmetaCPU" */ + #define ineT UINT32_C(0x54656E69) + #define Mx86 UINT32_C(0x3638784D) +@@ -105,6 +110,12 @@ enum cpuinfo_vendor cpuinfo_x86_decode_vendor(uint32_t ebx, uint32_t ecx, uint32 + return cpuinfo_vendor_via; + } + break; ++ case Hygo: ++ if (edx == nGen && ecx == uine) { ++ /* "HygonGenuine" */ ++ return cpuinfo_vendor_hygon; ++ } ++ break; + #if CPUINFO_ARCH_X86 + case AMDi: + if (edx == sbet && ecx == ter) { +diff --git a/src/x86/windows/init.c b/src/x86/windows/init.c +index 7a2090e..2c7e3cd 100644 +--- a/src/x86/windows/init.c ++++ b/src/x86/windows/init.c +@@ -417,9 +417,6 @@ BOOL CALLBACK cpuinfo_x86_windows_init(PINIT_ONCE init_once, PVOID parameter, PV + for (uint32_t i = 0; i < processors_count; i++) { + const uint32_t apic_id = processors[i].apic_id; + +- //linux_cpu_to_processor_map[x86_linux_processors[i].linux_id] = processors + processor_index; +- //linux_cpu_to_core_map[x86_linux_processors[i].linux_id] = cores + core_index; +- + if (x86_processor.cache.l1i.size != 0) { + const uint32_t l1i_id = apic_id & ~bit_mask(x86_processor.cache.l1i.apic_bits); + processors[i].cache.l1i = &l1i[l1i_index]; +@@ -549,30 +546,34 @@ BOOL CALLBACK cpuinfo_x86_windows_init(PINIT_ONCE init_once, PVOID parameter, PV + + + /* Commit changes */ ++ cpuinfo_processors = processors; ++ cpuinfo_cores = cores; ++ cpuinfo_clusters = clusters; ++ cpuinfo_packages = packages; + cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; + cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; + cpuinfo_cache[cpuinfo_cache_level_2] = l2; + cpuinfo_cache[cpuinfo_cache_level_3] = l3; + cpuinfo_cache[cpuinfo_cache_level_4] = l4; + +- cpuinfo_processors = processors; +- cpuinfo_cores = cores; +- cpuinfo_clusters = clusters; +- cpuinfo_packages = packages; +- ++ cpuinfo_processors_count = processors_count; ++ cpuinfo_cores_count = cores_count; ++ cpuinfo_clusters_count = packages_count; ++ cpuinfo_packages_count = packages_count; + cpuinfo_cache_count[cpuinfo_cache_level_1i] = l1i_count; + cpuinfo_cache_count[cpuinfo_cache_level_1d] = l1d_count; + cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; + cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; + cpuinfo_cache_count[cpuinfo_cache_level_4] = l4_count; +- +- cpuinfo_processors_count = processors_count; +- cpuinfo_cores_count 
= cores_count; +- cpuinfo_clusters_count = packages_count; +- cpuinfo_packages_count = packages_count; +- + cpuinfo_max_cache_size = cpuinfo_compute_max_cache_size(&processors[0]); + ++ cpuinfo_global_uarch = (struct cpuinfo_uarch_info) { ++ .uarch = x86_processor.uarch, ++ .cpuid = x86_processor.cpuid, ++ .processor_count = processors_count, ++ .core_count = cores_count, ++ }; ++ + MemoryBarrier(); + + cpuinfo_is_initialized = true; +diff --git a/test/arm-cache.cc b/test/arm-cache.cc +index 8373f7c..7d2e4a4 100644 +--- a/test/arm-cache.cc ++++ b/test/arm-cache.cc +@@ -766,7 +766,7 @@ TEST(QUALCOMM, snapdragon_845) { + struct cpuinfo_cache little_l2 = { 0 }; + struct cpuinfo_cache little_l3 = { 0 }; + cpuinfo_arm_decode_cache( +- cpuinfo_uarch_cortex_a55, 4, UINT32_C(0x518F803C), ++ cpuinfo_uarch_cortex_a55r0, 4, UINT32_C(0x518F803C), + &chipset, 1, 8, + &little_l1i, &little_l1d, &little_l2, &little_l3); + +@@ -910,7 +910,7 @@ TEST(SAMSUNG, exynos_9810) { + struct cpuinfo_cache little_l2 = { 0 }; + struct cpuinfo_cache little_l3 = { 0 }; + cpuinfo_arm_decode_cache( +- cpuinfo_uarch_cortex_a55, 4, UINT32_C(0x410FD051), ++ cpuinfo_uarch_cortex_a55r0, 4, UINT32_C(0x410FD051), + &chipset, 1, 8, + &little_l1i, &little_l1d, &little_l2, &little_l3); + +diff --git a/test/get-current.cc b/test/get-current.cc +index 4a80cab..f410b12 100644 +--- a/test/get-current.cc ++++ b/test/get-current.cc +@@ -3,34 +3,36 @@ + #include + + +-TEST(CURRENT_PROCESSOR, not_null) { +- ASSERT_TRUE(cpuinfo_initialize()); +- +- ASSERT_TRUE(cpuinfo_get_current_processor()); +-} +- + TEST(CURRENT_PROCESSOR, within_bounds) { + ASSERT_TRUE(cpuinfo_initialize()); + + const struct cpuinfo_processor* current_processor = cpuinfo_get_current_processor(); ++ if (current_processor == nullptr) { ++ GTEST_SKIP(); ++ } ++ + const struct cpuinfo_processor* processors_begin = cpuinfo_get_processors(); + const struct cpuinfo_processor* processors_end = processors_begin + cpuinfo_get_processors_count(); + ASSERT_GE(current_processor, processors_begin); + ASSERT_LT(current_processor, processors_end); + } + +-TEST(CURRENT_CORE, not_null) { +- ASSERT_TRUE(cpuinfo_initialize()); +- +- ASSERT_TRUE(cpuinfo_get_current_core()); +-} +- + TEST(CURRENT_CORE, within_bounds) { + ASSERT_TRUE(cpuinfo_initialize()); + + const struct cpuinfo_core* current_core = cpuinfo_get_current_core(); ++ if (current_core == nullptr) { ++ GTEST_SKIP(); ++ } ++ + const struct cpuinfo_core* cores_begin = cpuinfo_get_cores(); + const struct cpuinfo_core* cores_end = cores_begin + cpuinfo_get_cores_count(); + ASSERT_GE(current_core, cores_begin); + ASSERT_LT(current_core, cores_end); + } ++ ++TEST(CURRENT_UARCH_INDEX, within_bounds) { ++ ASSERT_TRUE(cpuinfo_initialize()); ++ ++ ASSERT_LT(cpuinfo_get_current_uarch_index(), cpuinfo_get_uarchs_count()); ++} +diff --git a/test/init.cc b/test/init.cc +index 941cb97..718eb96 100644 +--- a/test/init.cc ++++ b/test/init.cc +@@ -678,6 +678,72 @@ TEST(PACKAGE, consistent_cluster) { + cpuinfo_deinitialize(); + } + ++TEST(UARCHS_COUNT, within_bounds) { ++ ASSERT_TRUE(cpuinfo_initialize()); ++ EXPECT_NE(0, cpuinfo_get_uarchs_count()); ++ EXPECT_LE(cpuinfo_get_packages_count(), cpuinfo_get_cores_count()); ++ EXPECT_LE(cpuinfo_get_packages_count(), cpuinfo_get_processors_count()); ++ cpuinfo_deinitialize(); ++} ++ ++TEST(UARCHS, non_null) { ++ ASSERT_TRUE(cpuinfo_initialize()); ++ EXPECT_TRUE(cpuinfo_get_uarchs()); ++ cpuinfo_deinitialize(); ++} ++ ++TEST(UARCH, non_null) { ++ ASSERT_TRUE(cpuinfo_initialize()); ++ for (uint32_t i = 0; 
i < cpuinfo_get_uarchs_count(); i++) { ++ EXPECT_TRUE(cpuinfo_get_uarch(i)); ++ } ++ cpuinfo_deinitialize(); ++} ++ ++TEST(UARCH, non_zero_processors) { ++ ASSERT_TRUE(cpuinfo_initialize()); ++ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { ++ const cpuinfo_uarch_info* uarch = cpuinfo_get_uarch(i); ++ ASSERT_TRUE(uarch); ++ ++ EXPECT_NE(0, uarch->processor_count); ++ } ++ cpuinfo_deinitialize(); ++} ++ ++TEST(UARCH, valid_processors) { ++ ASSERT_TRUE(cpuinfo_initialize()); ++ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { ++ const cpuinfo_uarch_info* uarch = cpuinfo_get_uarch(i); ++ ASSERT_TRUE(uarch); ++ ++ EXPECT_LE(uarch->processor_count, cpuinfo_get_processors_count()); ++ } ++ cpuinfo_deinitialize(); ++} ++ ++TEST(UARCH, non_zero_cores) { ++ ASSERT_TRUE(cpuinfo_initialize()); ++ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { ++ const cpuinfo_uarch_info* uarch = cpuinfo_get_uarch(i); ++ ASSERT_TRUE(uarch); ++ ++ EXPECT_NE(0, uarch->core_count); ++ } ++ cpuinfo_deinitialize(); ++} ++ ++TEST(UARCH, valid_cores) { ++ ASSERT_TRUE(cpuinfo_initialize()); ++ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { ++ const cpuinfo_uarch_info* uarch = cpuinfo_get_uarch(i); ++ ASSERT_TRUE(uarch); ++ ++ EXPECT_LE(uarch->core_count, cpuinfo_get_cores_count()); ++ } ++ cpuinfo_deinitialize(); ++} ++ + TEST(L1I_CACHES_COUNT, within_bounds) { + ASSERT_TRUE(cpuinfo_initialize()); + EXPECT_NE(0, cpuinfo_get_l1i_caches_count()); +diff --git a/test/mock/galaxy-s9-global.cc b/test/mock/galaxy-s9-global.cc +index 7a67129..6c72513 100644 +--- a/test/mock/galaxy-s9-global.cc ++++ b/test/mock/galaxy-s9-global.cc +@@ -207,7 +207,7 @@ TEST(CORES, uarch) { + case 5: + case 6: + case 7: +- ASSERT_EQ(cpuinfo_uarch_cortex_a55, cpuinfo_get_core(i)->uarch); ++ ASSERT_EQ(cpuinfo_uarch_cortex_a55r0, cpuinfo_get_core(i)->uarch); + break; + } + } +@@ -329,7 +329,7 @@ TEST(CLUSTERS, uarch) { + ASSERT_EQ(cpuinfo_uarch_exynos_m3, cpuinfo_get_cluster(i)->uarch); + break; + case 1: +- ASSERT_EQ(cpuinfo_uarch_cortex_a55, cpuinfo_get_cluster(i)->uarch); ++ ASSERT_EQ(cpuinfo_uarch_cortex_a55r0, cpuinfo_get_cluster(i)->uarch); + break; + } + } +diff --git a/test/mock/galaxy-s9-us.cc b/test/mock/galaxy-s9-us.cc +index 6df7f3c..ceea969 100644 +--- a/test/mock/galaxy-s9-us.cc ++++ b/test/mock/galaxy-s9-us.cc +@@ -168,7 +168,7 @@ TEST(CORES, uarch) { + case 5: + case 6: + case 7: +- ASSERT_EQ(cpuinfo_uarch_cortex_a55, cpuinfo_get_core(i)->uarch); ++ ASSERT_EQ(cpuinfo_uarch_cortex_a55r0, cpuinfo_get_core(i)->uarch); + break; + } + } +@@ -283,7 +283,7 @@ TEST(CLUSTERS, uarch) { + ASSERT_EQ(cpuinfo_uarch_cortex_a75, cpuinfo_get_cluster(i)->uarch); + break; + case 1: +- ASSERT_EQ(cpuinfo_uarch_cortex_a55, cpuinfo_get_cluster(i)->uarch); ++ ASSERT_EQ(cpuinfo_uarch_cortex_a55r0, cpuinfo_get_cluster(i)->uarch); + break; + } + } +@@ -817,4 +817,4 @@ int main(int argc, char* argv[]) { + cpuinfo_initialize(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +-} +\ No newline at end of file ++} +diff --git a/tools/cpu-info.c b/tools/cpu-info.c +index 7fa5187..7963c00 100644 +--- a/tools/cpu-info.c ++++ b/tools/cpu-info.c +@@ -14,6 +14,8 @@ static const char* vendor_to_string(enum cpuinfo_vendor vendor) { + return "Intel"; + case cpuinfo_vendor_amd: + return "AMD"; ++ case cpuinfo_vendor_hygon: ++ return "Hygon"; + case cpuinfo_vendor_arm: + return "ARM"; + case cpuinfo_vendor_qualcomm: +@@ -161,6 +163,8 @@ static const char* uarch_to_string(enum cpuinfo_uarch uarch) { + return 
"Cortex-A35"; + case cpuinfo_uarch_cortex_a53: + return "Cortex-A53"; ++ case cpuinfo_uarch_cortex_a55r0: ++ return "Cortex-A55r0"; + case cpuinfo_uarch_cortex_a55: + return "Cortex-A55"; + case cpuinfo_uarch_cortex_a57: +@@ -223,6 +227,10 @@ static const char* uarch_to_string(enum cpuinfo_uarch uarch) { + return "Vortex"; + case cpuinfo_uarch_tempest: + return "Tempest"; ++ case cpuinfo_uarch_lightning: ++ return "Lightning"; ++ case cpuinfo_uarch_thunder: ++ return "Thunder"; + case cpuinfo_uarch_thunderx: + return "ThunderX"; + case cpuinfo_uarch_thunderx2: +@@ -253,6 +261,17 @@ int main(int argc, char** argv) { + printf("\t%"PRIu32": %s\n", i, cpuinfo_get_package(i)->name); + } + #endif ++ printf("Microarchitectures:\n"); ++ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { ++ const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(i); ++ const char* uarch_string = uarch_to_string(uarch_info->uarch); ++ if (uarch_string == NULL) { ++ printf("\t%"PRIu32"x Unknown (0x%08"PRIx32"\n", ++ uarch_info->core_count, (uint32_t) uarch_info->uarch); ++ } else { ++ printf("\t%"PRIu32"x %s\n", uarch_info->core_count, uarch_string); ++ } ++ } + printf("Cores:\n"); + for (uint32_t i = 0; i < cpuinfo_get_cores_count(); i++) { + const struct cpuinfo_core* core = cpuinfo_get_core(i); +@@ -277,17 +296,17 @@ int main(int argc, char** argv) { + } + } + printf("Logical processors"); +- #if defined(__linux__) +- printf(" (System ID)"); +- #endif +- printf(":\n"); ++ #if defined(__linux__) ++ printf(" (System ID)"); ++ #endif ++ printf(":\n"); + for (uint32_t i = 0; i < cpuinfo_get_processors_count(); i++) { + const struct cpuinfo_processor* processor = cpuinfo_get_processor(i); +- printf("\t%"PRIu32"", i); ++ printf("\t%"PRIu32"", i); + +- #if defined(__linux__) +- printf(" (%"PRId32")", processor->linux_id); +- #endif ++ #if defined(__linux__) ++ printf(" (%"PRId32")", processor->linux_id); ++ #endif + + #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + printf(": APIC ID 0x%08"PRIx32"\n", processor->apic_id); diff --git a/third_party/cpuinfo/workspace.bzl b/third_party/cpuinfo/workspace.bzl index c2eeede8a0d..77aecf5a9a9 100644 --- a/third_party/cpuinfo/workspace.bzl +++ b/third_party/cpuinfo/workspace.bzl @@ -2,14 +2,20 @@ load("//third_party:repo.bzl", "third_party_http_archive") +# Sanitize a dependency so that it works correctly from code that includes +# TensorFlow as a submodule. 
+def clean_dep(dep): + return str(Label(dep)) + def repo(): third_party_http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-e39a5790059b6b8274ed91f7b5b5b13641dff267", - sha256 = "e5caa8b7c58f1623eed88f4d5147e3753ff19cde821526bc9aa551b004f751fe", + strip_prefix = "cpuinfo-d6c0f915ee737f961915c9d17f1679b6777af207", + sha256 = "146fc61c3cf63d7d88db963876929a4d373f621fb65568b895efa0857f467770", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pytorch/cpuinfo/archive/e39a5790059b6b8274ed91f7b5b5b13641dff267.tar.gz", - "https://github.com/pytorch/cpuinfo/archive/e39a5790059b6b8274ed91f7b5b5b13641dff267.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pytorch/cpuinfo/archive/d6c0f915ee737f961915c9d17f1679b6777af207.tar.gz", + "https://github.com/pytorch/cpuinfo/archive/d6c0f915ee737f961915c9d17f1679b6777af207.tar.gz", ], build_file = "//third_party/cpuinfo:BUILD.bazel", + patch_file = clean_dep("//third_party/cpuinfo:cpuinfo.patch"), ) From 8eee8630315ff6ce0e054eae2e80576303c66052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cjaketae=E2=80=9D?= Date: Wed, 18 Mar 2020 10:18:50 +0900 Subject: [PATCH 115/492] Fix tf.keras export --- tensorflow/python/keras/preprocessing/text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/keras/preprocessing/text.py b/tensorflow/python/keras/preprocessing/text.py index e9cfe677b5c..96f4b19660e 100644 --- a/tensorflow/python/keras/preprocessing/text.py +++ b/tensorflow/python/keras/preprocessing/text.py @@ -55,7 +55,7 @@ def text_to_word_sequence(text, text, filters=filters, lower=lower, split=split) -@keras_export('tf.keras.preprocessing.text.one_hot') +@keras_export('keras.preprocessing.text.one_hot') def one_hot(text, n, filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', lower=True, From 3c61d997d0bb9a3634b64e3d09d637793db300ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BF=97=E8=B1=AA?= Date: Wed, 18 Mar 2020 09:45:30 +0800 Subject: [PATCH 116/492] Add missing dependencies download command --- tensorflow/lite/tools/pip_package/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/tools/pip_package/README.md b/tensorflow/lite/tools/pip_package/README.md index 88906ee2bb0..484b81546ad 100644 --- a/tensorflow/lite/tools/pip_package/README.md +++ b/tensorflow/lite/tools/pip_package/README.md @@ -8,6 +8,7 @@ Python without requiring the rest of TensorFlow. To build a binary wheel run this script: ``` sudo apt install swig libjpeg-dev zlib1g-dev python3-dev python3-numpy +sh tensorflow/lite/tools/make/download_dependencies.sh sh tensorflow/lite/tools/pip_package/build_pip_package.sh ``` That will print out some output and a .whl file. 
You can then install that From 75cf60d9fe3528b03dc490313ad1a48905fa4c51 Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Tue, 17 Mar 2020 18:51:00 -0700 Subject: [PATCH 117/492] Refactoring: Extract base test class to lite_v2_test_util module (NFC) PiperOrigin-RevId: 301499331 Change-Id: Ib99e2fd5f94def257a8c0de8c85e1cebfbc9ce53 --- tensorflow/lite/python/BUILD | 17 ++++ tensorflow/lite/python/lite_v2_test.py | 86 ++-------------- tensorflow/lite/python/lite_v2_test_util.py | 105 ++++++++++++++++++++ 3 files changed, 129 insertions(+), 79 deletions(-) create mode 100644 tensorflow/lite/python/lite_v2_test_util.py diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 5903a96fb52..86c1b2995f1 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -161,6 +161,23 @@ py_test( tags = [ "no_windows", ], + deps = [ + ":lite", + ":lite_v2_test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "@six_archive//:six", + ], +) + +py_library( + name = "lite_v2_test_util", + testonly = 1, + srcs = ["lite_v2_test_util.py"], + srcs_version = "PY2AND3", + tags = [ + "no_windows", + ], deps = [ ":lite", "//tensorflow/python:client_testlib", diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 658853ef4e8..8cc05eb5f36 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -27,6 +27,7 @@ from six.moves import range from six.moves import zip from tensorflow.lite.python import lite +from tensorflow.lite.python import lite_v2_test_util from tensorflow.lite.python.interpreter import Interpreter from tensorflow.python import keras from tensorflow.python.client import session @@ -55,80 +56,7 @@ from tensorflow.python.saved_model.save import save from tensorflow.python.training.tracking import tracking -class TestModels(test_util.TensorFlowTestCase, parameterized.TestCase): - - def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None): - """Evaluates the model on the `input_data`. - - Args: - tflite_model: TensorFlow Lite model. - input_data: List of EagerTensor const ops containing the input data for - each input tensor. - input_shapes: List of tuples representing the `shape_signature` and the - new shape of each input tensor that has unknown dimensions. - - Returns: - [np.ndarray] - """ - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - if input_shapes: - for idx, (shape_signature, final_shape) in enumerate(input_shapes): - self.assertTrue( - (input_details[idx]['shape_signature'] == shape_signature).all()) - interpreter.resize_tensor_input(idx, final_shape) - interpreter.allocate_tensors() - - output_details = interpreter.get_output_details() - - for input_tensor, tensor_data in zip(input_details, input_data): - interpreter.set_tensor(input_tensor['index'], tensor_data.numpy()) - interpreter.invoke() - return [ - interpreter.get_tensor(details['index']) for details in output_details - ] - - def _getSimpleVariableModel(self): - root = tracking.AutoTrackable() - root.v1 = variables.Variable(3.) - root.v2 = variables.Variable(2.) - root.f = def_function.function(lambda x: root.v1 * root.v2 * x) - return root - - def _getMultiFunctionModel(self): - - class BasicModel(tracking.AutoTrackable): - - def __init__(self): - self.y = None - self.z = None - - @def_function.function - def add(self, x): - if self.y is None: - self.y = variables.Variable(2.) 
- return x + self.y - - @def_function.function - def sub(self, x): - if self.z is None: - self.z = variables.Variable(3.) - return x - self.z - - return BasicModel() - - def _assertValidDebugInfo(self, debug_info): - """Verify the DebugInfo is valid.""" - file_names = set() - for file_path in debug_info.files: - file_names.add(os.path.basename(file_path)) - # To make the test independent on how the nodes are created, we only assert - # the name of this test file. - self.assertIn('lite_v2_test.py', file_names) - self.assertNotIn('lite_test.py', file_names) - - -class FromConcreteFunctionTest(TestModels): +class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testTypeInvalid(self): @@ -422,7 +350,7 @@ class FromConcreteFunctionTest(TestModels): self._assertValidDebugInfo(converter._debug_info) -class FromSavedModelTest(TestModels): +class FromSavedModelTest(lite_v2_test_util.ModelTest): def _createV1SavedModel(self, shape): """Create a simple SavedModel.""" @@ -604,7 +532,7 @@ class FromSavedModelTest(TestModels): self._assertValidDebugInfo(converter._debug_info) -class FromKerasModelTest(TestModels): +class FromKerasModelTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testSequentialModel(self): @@ -689,7 +617,7 @@ class FromKerasModelTest(TestModels): self._assertValidDebugInfo(converter._debug_info) -class ControlFlowTest(TestModels): +class ControlFlowTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testCond(self): @@ -883,7 +811,7 @@ class ControlFlowTest(TestModels): np.testing.assert_almost_equal(expected_value, actual_value, decimal=5) -class GrapplerTest(TestModels): +class GrapplerTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testConstantFolding(self): @@ -919,7 +847,7 @@ class GrapplerTest(TestModels): np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0]) -class UnknownShapes(TestModels): +class UnknownShapes(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testMatMul(self): diff --git a/tensorflow/lite/python/lite_v2_test_util.py b/tensorflow/lite/python/lite_v2_test_util.py new file mode 100644 index 00000000000..5ea239f22a2 --- /dev/null +++ b/tensorflow/lite/python/lite_v2_test_util.py @@ -0,0 +1,105 @@ +# Lint as: python2, python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for lite.py functionality related to TensorFlow 2.0.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl.testing import parameterized +from six.moves import zip + +from tensorflow.lite.python.interpreter import Interpreter +from tensorflow.python.eager import def_function +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.training.tracking import tracking + + +class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase): + """Base test class for TensorFlow Lite 2.x model tests.""" + + def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None): + """Evaluates the model on the `input_data`. + + Args: + tflite_model: TensorFlow Lite model. + input_data: List of EagerTensor const ops containing the input data for + each input tensor. + input_shapes: List of tuples representing the `shape_signature` and the + new shape of each input tensor that has unknown dimensions. + + Returns: + [np.ndarray] + """ + interpreter = Interpreter(model_content=tflite_model) + input_details = interpreter.get_input_details() + if input_shapes: + for idx, (shape_signature, final_shape) in enumerate(input_shapes): + self.assertTrue( + (input_details[idx]['shape_signature'] == shape_signature).all()) + interpreter.resize_tensor_input(idx, final_shape) + interpreter.allocate_tensors() + + output_details = interpreter.get_output_details() + + for input_tensor, tensor_data in zip(input_details, input_data): + interpreter.set_tensor(input_tensor['index'], tensor_data.numpy()) + interpreter.invoke() + return [ + interpreter.get_tensor(details['index']) for details in output_details + ] + + def _getSimpleVariableModel(self): + root = tracking.AutoTrackable() + root.v1 = variables.Variable(3.) + root.v2 = variables.Variable(2.) + root.f = def_function.function(lambda x: root.v1 * root.v2 * x) + return root + + def _getMultiFunctionModel(self): + + class BasicModel(tracking.AutoTrackable): + + def __init__(self): + self.y = None + self.z = None + + @def_function.function + def add(self, x): + if self.y is None: + self.y = variables.Variable(2.) + return x + self.y + + @def_function.function + def sub(self, x): + if self.z is None: + self.z = variables.Variable(3.) + return x - self.z + + return BasicModel() + + def _assertValidDebugInfo(self, debug_info): + """Verify the DebugInfo is valid.""" + file_names = set() + for file_path in debug_info.files: + file_names.add(os.path.basename(file_path)) + # To make the test independent on how the nodes are created, we only assert + # the name of this test file. 
+ self.assertIn('lite_v2_test.py', file_names) + self.assertNotIn('lite_test.py', file_names) From 15f6b288fd9d20d2ad5f417cc88bd86856c202d0 Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Tue, 17 Mar 2020 18:55:24 -0700 Subject: [PATCH 118/492] Simplify lite_v2_test.py code (NFC) PiperOrigin-RevId: 301499818 Change-Id: Ief73ffdf0421c8a55c99a4e4d307a3eeaeae0255 --- tensorflow/lite/python/BUILD | 1 + tensorflow/lite/python/lite_v2_test.py | 267 +++++++++++-------------- 2 files changed, 122 insertions(+), 146 deletions(-) diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 86c1b2995f1..7248792523e 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -164,6 +164,7 @@ py_test( deps = [ ":lite", ":lite_v2_test_util", + "//tensorflow:tensorflow_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "@six_archive//:six", diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 8cc05eb5f36..e0595893531 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -25,30 +25,15 @@ from absl.testing import parameterized import numpy as np from six.moves import range from six.moves import zip +import tensorflow as tf from tensorflow.lite.python import lite from tensorflow.lite.python import lite_v2_test_util from tensorflow.lite.python.interpreter import Interpreter -from tensorflow.python import keras -from tensorflow.python.client import session -from tensorflow.python.eager import context -from tensorflow.python.eager import def_function -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.keras.layers import recurrent from tensorflow.python.keras.layers import recurrent_v2 -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import rnn -from tensorflow.python.ops import rnn_cell_impl -from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import save_options from tensorflow.python.saved_model import saved_model @@ -71,7 +56,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testFloat(self, enable_mlir): root = self._getSimpleVariableModel() - input_data = constant_op.constant(1., shape=[1]) + input_data = tf.constant(1., shape=[1]) concrete_func = root.f.get_concrete_function(input_data) # Convert model. @@ -87,7 +72,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testScalarInput(self): root = self._getSimpleVariableModel() - input_data = constant_op.constant(1., shape=[]) + input_data = tf.constant(1., shape=[]) concrete_func = root.f.get_concrete_function(input_data) # Convert model. 
@@ -103,7 +88,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): def testMultiFunctionModel(self): """Convert a single model in a multi-functional model.""" root = self._getMultiFunctionModel() - input_data = constant_op.constant(1., shape=[1]) + input_data = tf.constant(1., shape=[1]) concrete_func = root.add.get_concrete_function(input_data) # Convert model and ensure model is not None. @@ -119,7 +104,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): def testConvertMultipleFunctions(self): """Convert multiple functions in a multi-functional model.""" root = self._getMultiFunctionModel() - input_data = constant_op.constant(1., shape=[1]) + input_data = tf.constant(1., shape=[1]) add_func = root.add.get_concrete_function(input_data) sub_func = root.sub.get_concrete_function(input_data) @@ -136,16 +121,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): root = tracking.AutoTrackable() - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[1, 5, 5, 3], dtype=dtypes.float32) - ]) + @tf.function( + input_signature=[tf.TensorSpec(shape=[1, 5, 5, 3], dtype=tf.float32)]) def func(inp): - conv = nn_ops.conv2d( - inp, - filter=array_ops.ones([3, 3, 3, 16]), - strides=[1, 1, 1, 1], - padding='SAME') - output = nn_ops.relu(conv, name='output') + conv = tf.nn.conv2d( + inp, tf.ones([3, 3, 3, 16]), strides=[1, 1, 1, 1], padding='SAME') + output = tf.nn.relu(conv, name='output') return output def calibration_gen(): @@ -216,7 +197,8 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): self.assertLess(len(quantized_tflite), len(float_tflite)) def _getTrainingTimeQuantizedModel(self): - class QLinear(keras.layers.Layer): + + class QLinear(tf.keras.layers.Layer): def __init__(self, units=3, **kwargs): super(QLinear, self).__init__(**kwargs) @@ -228,27 +210,27 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): trainable=True) self.min_var = self.add_weight( 'min', - initializer=keras.initializers.Constant(-6.0), + initializer=tf.keras.initializers.Constant(-6.0), trainable=False) self.max_var = self.add_weight( 'max', - initializer=keras.initializers.Constant(6.0), + initializer=tf.keras.initializers.Constant(6.0), trainable=False) def call(self, inputs): - x = array_ops.fake_quant_with_min_max_vars( + x = tf.quantization.fake_quant_with_min_max_vars( inputs, self.min_var, self.max_var) - w_fq = array_ops.fake_quant_with_min_max_vars( + w_fq = tf.quantization.fake_quant_with_min_max_vars( self.w, self.min_var, self.max_var) - x = math_ops.matmul(x, w_fq) + x = tf.matmul(x, w_fq) - x = array_ops.fake_quant_with_min_max_vars( + x = tf.quantization.fake_quant_with_min_max_vars( x, self.min_var, self.max_var) return x - return keras.Sequential(QLinear(3, input_shape=(2,))) + return tf.keras.Sequential(QLinear(3, input_shape=(2,))) @test_util.run_v2_only def testTrainingTimeQuantizeConversion(self): @@ -289,7 +271,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): new_tflite = quantized_converter.convert() for _ in range(5): - input_data = constant_op.constant( + input_data = tf.constant( np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)) old_value = self._evaluateTFLiteModel(old_tflite, [input_data]) new_value = self._evaluateTFLiteModel(new_tflite, [input_data]) @@ -301,25 +283,23 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testEmbeddings(self, enable_mlir): """Test model with embeddings.""" - input_data = constant_op.constant( + input_data = 
tf.constant( np.array(np.random.random_sample((20)), dtype=np.int32)) - class EmbeddingModel(keras.Model): + class EmbeddingModel(tf.keras.Model): def __init__(self): super(EmbeddingModel, self).__init__() self.shared_weights = self.add_weight( 'weights', shape=(2000, 300), - dtype=dtypes.float32, - initializer=init_ops.random_normal_initializer( + dtype=tf.float32, + initializer=tf.random_normal_initializer( mean=0.0, stddev=300**(-0.5))) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(20), dtype=dtypes.int32) - ]) + @tf.function(input_signature=[tf.TensorSpec(shape=(20), dtype=tf.int32)]) def func(self, x): - return array_ops.gather(self.shared_weights, x) + return tf.gather(self.shared_weights, x) # Building the model. root = EmbeddingModel() @@ -339,9 +319,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): def testGraphDebugInfo(self): """Test a concrete function has debug info captured.""" root = tracking.AutoTrackable() - root.v1 = variables.Variable(3.) - root.f = def_function.function(lambda x: root.v1 * x) - input_data = constant_op.constant(1., shape=[1]) + root.v1 = tf.Variable(3.) + root.f = tf.function(lambda x: root.v1 * x) + input_data = tf.constant(1., shape=[1]) concrete_func = root.f.get_concrete_function(input_data) # Convert model. @@ -355,24 +335,24 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): def _createV1SavedModel(self, shape): """Create a simple SavedModel.""" saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel') - with ops.Graph().as_default(): - with session.Session() as sess: - in_tensor_1 = array_ops.placeholder( - shape=shape, dtype=dtypes.float32, name='inputB') - in_tensor_2 = array_ops.placeholder( - shape=shape, dtype=dtypes.float32, name='inputA') - variable_node = variables.Variable(1.0, name='variable_node') + with tf.Graph().as_default(): + with tf.compat.v1.Session() as sess: + in_tensor_1 = tf.compat.v1.placeholder( + shape=shape, dtype=tf.float32, name='inputB') + in_tensor_2 = tf.compat.v1.placeholder( + shape=shape, dtype=tf.float32, name='inputA') + variable_node = tf.Variable(1.0, name='variable_node') out_tensor = in_tensor_1 + in_tensor_2 * variable_node inputs = {'x': in_tensor_1, 'y': in_tensor_2} outputs = {'z': out_tensor} - sess.run(variables.variables_initializer([variable_node])) + sess.run(tf.compat.v1.variables_initializer([variable_node])) saved_model.simple_save(sess, saved_model_dir, inputs, outputs) return saved_model_dir @test_util.run_v2_only def testV1SimpleModel(self): """Test a SavedModel.""" - with context.graph_mode(): + with tf.Graph().as_default(): saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3]) # Convert model and ensure model is not None. @@ -405,9 +385,9 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testConstModel(self): """Test a basic model with functions to make sure functions are inlined.""" - input_data = constant_op.constant(1., shape=[1]) + input_data = tf.constant(1., shape=[1]) root = tracking.AutoTrackable() - root.f = def_function.function(lambda x: 2. * x) + root.f = tf.function(lambda x: 2. 
* x) to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') @@ -426,7 +406,7 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): def testVariableModel(self): """Test a basic model with Variables with saving/loading the SavedModel.""" root = self._getSimpleVariableModel() - input_data = constant_op.constant(1., shape=[1]) + input_data = tf.constant(1., shape=[1]) to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') @@ -445,7 +425,7 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): def testSignatures(self): """Test values for `signature_keys` argument.""" root = self._getSimpleVariableModel() - input_data = constant_op.constant(1., shape=[1]) + input_data = tf.constant(1., shape=[1]) to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') @@ -471,7 +451,7 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): def testMultipleFunctionModel(self): """Convert multiple functions in a multi-functional model.""" root = self._getMultiFunctionModel() - input_data = constant_op.constant(1., shape=[1]) + input_data = tf.constant(1., shape=[1]) add_func = root.add.get_concrete_function(input_data) sub_func = root.sub.get_concrete_function(input_data) @@ -491,14 +471,14 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testKerasSequentialModel(self): """Test a simple sequential tf.Keras model.""" - input_data = constant_op.constant(1., shape=[1, 1]) + input_data = tf.constant(1., shape=[1, 1]) x = np.array([[1.], [2.]]) y = np.array([[2.], [4.]]) - model = keras.models.Sequential([ - keras.layers.Dropout(0.2), - keras.layers.Dense(1), + model = tf.keras.models.Sequential([ + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(1), ]) model.compile(optimizer='sgd', loss='mean_squared_error') model.fit(x, y, epochs=1) @@ -518,9 +498,9 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testGraphDebugInfo(self): """Test a SavedModel has debug info captured.""" - input_data = constant_op.constant(1., shape=[1]) + input_data = tf.constant(1., shape=[1]) root = tracking.AutoTrackable() - root.f = def_function.function(lambda x: 2. * x) + root.f = tf.function(lambda x: 2. * x) to_save = root.f.get_concrete_function(input_data) options = save_options.SaveOptions(save_debug_info=True) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') @@ -537,15 +517,15 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testSequentialModel(self): """Test a simple sequential tf.Keras model.""" - input_data = constant_op.constant(1., shape=[1, 1]) + input_data = tf.constant(1., shape=[1, 1]) # Create a simple Keras model. 
x = np.array([[1.], [2.]]) y = np.array([[2.], [4.]]) - model = keras.models.Sequential([ - keras.layers.Dropout(0.2), - keras.layers.Dense(units=1, input_shape=[1]) + model = tf.keras.models.Sequential([ + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(units=1, input_shape=[1]) ]) model.compile(optimizer='sgd', loss='mean_squared_error') model.fit(x, y, epochs=1) @@ -562,8 +542,8 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testSequentialMultiInputOutputModel(self): """Test a tf.Keras model with multiple inputs and outputs.""" - left_input_data = constant_op.constant(1., shape=[1, 3]) - right_input_data = constant_op.constant(1., shape=[1, 3]) + left_input_data = tf.constant(1., shape=[1, 3]) + right_input_data = tf.constant(1., shape=[1, 3]) # Create a simple Keras model. input_a_np = np.random.random((10, 3)) @@ -571,22 +551,22 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest): output_c_np = np.random.random((10, 3)) output_d_np = np.random.random((10, 2)) - input_a = keras.layers.Input(shape=(3,), name='input_a') - input_b = keras.layers.Input(shape=(3,), name='input_b') + input_a = tf.keras.layers.Input(shape=(3,), name='input_a') + input_b = tf.keras.layers.Input(shape=(3,), name='input_b') - dense = keras.layers.Dense(8, name='dense_1') + dense = tf.keras.layers.Dense(8, name='dense_1') interm_a = dense(input_a) interm_b = dense(input_b) - merged = keras.layers.concatenate([interm_a, interm_b], name='merge') + merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge') - output_c = keras.layers.Dense( + output_c = tf.keras.layers.Dense( 3, activation='softmax', name='dense_2')( merged) - output_d = keras.layers.Dense( + output_d = tf.keras.layers.Dense( 2, activation='softmax', name='dense_3')( merged) - model = keras.models.Model( + model = tf.keras.models.Model( inputs=[input_a, input_b], outputs=[output_c, output_d]) model.compile(optimizer='sgd', loss='mean_squared_error') model.fit([input_a_np, input_b_np], [output_c_np, output_d_np], epochs=1) @@ -608,8 +588,8 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest): # Create a simple Keras model. 
x = [-1, 0, 1, 2, 3, 4] y = [-3, -1, 1, 3, 5, 7] - model = keras.models.Sequential( - [keras.layers.Dense(units=1, input_shape=[1])]) + model = tf.keras.models.Sequential( + [tf.keras.layers.Dense(units=1, input_shape=[1])]) model.compile(optimizer='sgd', loss='mean_squared_error') model.fit(x, y, epochs=1) converter = lite.TFLiteConverterV2.from_keras_model(model) @@ -622,24 +602,24 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testCond(self): input_data = { - 'x': constant_op.constant([1., 2.], shape=[1, 2]), - 'b': constant_op.constant(True) + 'x': tf.constant([1., 2.], shape=[1, 2]), + 'b': tf.constant(True) } - weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32) + weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32) def true_fn(x): - return math_ops.matmul(x, weights) + return tf.matmul(x, weights) def false_fn(x): - return math_ops.add(x, weights) + return tf.add(x, weights) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[1, 2], dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool) + @tf.function(input_signature=[ + tf.TensorSpec(shape=[1, 2], dtype=tf.float32), + tf.TensorSpec(shape=(), dtype=tf.bool) ]) def model(x, b): - return control_flow_ops.cond( + return tf.cond( b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x)) concrete_func = model.get_concrete_function() @@ -657,18 +637,17 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testStaticRnn(self): - input_data = constant_op.constant( + input_data = tf.constant( np.array(np.random.random_sample((3, 10)), dtype=np.float32)) - cell = rnn_cell_impl.LSTMCell(10) + cell = tf.compat.v1.nn.rnn_cell.LSTMCell(10) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[3, 10], dtype=dtypes.float32) - ]) + @tf.function( + input_signature=[tf.TensorSpec(shape=[3, 10], dtype=tf.float32)]) def model(x): - seq = array_ops.split(x, 3, 0) - return rnn.static_rnn( - cell, seq, dtype=dtypes.float32, sequence_length=[1]) + seq = tf.split(x, 3, 0) + return tf.compat.v1.nn.static_rnn( + cell, seq, dtype=tf.float32, sequence_length=[1]) concrete_func = model.get_concrete_function() @@ -685,21 +664,20 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testWhileLoop(self): - input_data = constant_op.constant([1., 2., 3., 4.], shape=[2, 2]) + input_data = tf.constant([1., 2., 3., 4.], shape=[2, 2]) - weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32) + weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32) def condition(x): - return math_ops.reduce_sum(x) < 100 + return tf.reduce_sum(x) < 100 def body(x): - return math_ops.add(x, weights) + return tf.add(x, weights) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[2, 2], dtype=dtypes.float32) - ]) + @tf.function( + input_signature=[tf.TensorSpec(shape=[2, 2], dtype=tf.float32)]) def model(x): - return control_flow_ops.while_loop(condition, body, [x]) + return tf.while_loop(condition, body, [x]) concrete_func = model.get_concrete_function() @@ -709,22 +687,21 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): tflite_model = converter.convert() # Check values from converted model. 
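# Note on the `[0]` added in the next hunk: `tf.while_loop` returns its result
# with the same structure as `loop_vars` (here a one-element list), so the
# concrete function's output has to be indexed to obtain the single tensor that
# the converted TFLite model emits.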
- expected_value = concrete_func(input_data) + expected_value = concrete_func(input_data)[0] actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0] np.testing.assert_almost_equal(expected_value.numpy(), actual_value) @test_util.run_v2_only def testDynamicRnn(self): - input_data = constant_op.constant( + input_data = tf.constant( np.array(np.random.random_sample((3, 10, 10)), dtype=np.float32)) - cell = rnn_cell_impl.LSTMCell(10) + cell = tf.compat.v1.nn.rnn_cell.LSTMCell(10) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[3, 10, 10], dtype=dtypes.float32) - ]) + @tf.function( + input_signature=[tf.TensorSpec(shape=[3, 10, 10], dtype=tf.float32)]) def model(x): - return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32) + return tf.compat.v1.nn.dynamic_rnn(cell, x, dtype=tf.float32) concrete_func = model.get_concrete_function() @@ -750,10 +727,10 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): def testKerasRNN(self, rnn_layer): # This relies on TFLiteConverter to rewrite unknown batch size to 1. The # model will fail if resizing the input to non-1 batch size. - input_data = constant_op.constant( + input_data = tf.constant( np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32)) rnn_obj = rnn_layer(units=10, input_shape=(10, 10)) - model = keras.models.Sequential([rnn_obj]) + model = tf.keras.models.Sequential([rnn_obj]) # Convert model. converter = lite.TFLiteConverterV2.from_keras_model(model) @@ -770,12 +747,12 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): ('GRU', recurrent_v2.GRU)) @test_util.run_v2_only def testKerasRNNMultiBatches(self, rnn_layer): - input_data = constant_op.constant( + input_data = tf.constant( np.array(np.random.random_sample((4, 10, 10)), dtype=np.float32)) # Specify a fixed batch size(4) for the test model. - x = keras.layers.Input(batch_shape=(4, 10, 10)) + x = tf.keras.layers.Input(batch_shape=(4, 10, 10)) y = rnn_layer(units=10, input_shape=(10, 10))(x) - model = keras.Model(inputs=[x], outputs=[y]) + model = tf.keras.Model(inputs=[x], outputs=[y]) # Convert model. converter = lite.TFLiteConverterV2.from_keras_model(model) @@ -789,16 +766,16 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testKerasBidirectionalRNN(self): - input_data = constant_op.constant( + input_data = tf.constant( np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32)) - model = keras.models.Sequential() + model = tf.keras.models.Sequential() model.add( - keras.layers.Bidirectional( + tf.keras.layers.Bidirectional( recurrent_v2.LSTM(units=10, return_sequences=True), input_shape=(10, 10))) - model.add(keras.layers.Bidirectional(recurrent_v2.LSTM(units=10))) - model.add(keras.layers.Dense(5)) - model.add(keras.layers.Activation('softmax')) + model.add(tf.keras.layers.Bidirectional(recurrent_v2.LSTM(units=10))) + model.add(tf.keras.layers.Dense(5)) + model.add(tf.keras.layers.Activation('softmax')) # Convert model. converter = lite.TFLiteConverterV2.from_keras_model(model) @@ -817,14 +794,13 @@ class GrapplerTest(lite_v2_test_util.ModelTest): def testConstantFolding(self): # Constant folding handles the tf.broadcast_to operation which was not # supported by the TFLite at the time this test was added. 
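For reference, the convert-and-compare flow that `_evaluateTFLiteModel` and the
other ModelTest helpers wrap can be sketched with the public TF 2.x APIs of the
time; this is an illustrative sketch mirroring the constant-folding test below,
not the helpers' actual implementation.

import numpy as np
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
def func(x):
  # Grappler is expected to fold the broadcast of a constant before conversion.
  y_broadcast = tf.broadcast_to(tf.constant([1., 2., 3.]), [3, 3])
  return tf.matmul(x, y_broadcast)

concrete_func = func.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]['index']
interpreter.set_tensor(input_index, np.ones((3, 3), dtype=np.float32))
interpreter.invoke()
actual = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])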
- input_data = constant_op.constant([1., 2., 3., 4., 5., 6., 7., 8., 9.], - shape=[3, 3]) + input_data = tf.constant([1., 2., 3., 4., 5., 6., 7., 8., 9.], shape=[3, 3]) - @def_function.function + @tf.function def func(x): - y_const = constant_op.constant([1., 2., 3.]) - y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3]) - return math_ops.matmul(x, y_broadcast) + y_const = tf.constant([1., 2., 3.]) + y_broadcast = tf.broadcast_to(y_const, [3, 3]) + return tf.matmul(x, y_broadcast) root = tracking.AutoTrackable() root.f = func @@ -851,16 +827,15 @@ class UnknownShapes(lite_v2_test_util.ModelTest): @test_util.run_v2_only def testMatMul(self): - input_data = constant_op.constant( + input_data = tf.constant( np.array(np.random.random_sample((10, 4)), dtype=np.float32)) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[None, 4], dtype=dtypes.float32) - ]) + @tf.function( + input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)]) def model(in_tensor): - shape = array_ops.shape_v2(in_tensor) - fill = array_ops.transpose_v2(array_ops.fill(shape, 1.)) - return math_ops.matmul(fill, in_tensor) + shape = tf.shape(in_tensor) + fill = tf.transpose(tf.fill(shape, 1.)) + return tf.matmul(fill, in_tensor) concrete_func = model.get_concrete_function() @@ -877,17 +852,17 @@ class UnknownShapes(lite_v2_test_util.ModelTest): def testBatchMatMul(self): self.skipTest('BatchMatMulV2 does not support unknown batch size.') - input_data_1 = constant_op.constant( + input_data_1 = tf.constant( np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32)) - input_data_2 = constant_op.constant( + input_data_2 = tf.constant( np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32)) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[None, 256, 256], dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=[None, 256, 256], dtype=dtypes.float32) + @tf.function(input_signature=[ + tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32), + tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32) ]) def model(in_tensor_1, in_tensor_2): - return math_ops.matmul(in_tensor_1, in_tensor_2) + return tf.matmul(in_tensor_1, in_tensor_2) concrete_func = model.get_concrete_function() From 84d930ffe915734b154558d22b07f4215a5225a5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 19:25:56 -0700 Subject: [PATCH 119/492] Creates `framework_lib` target. PiperOrigin-RevId: 301503499 Change-Id: I81b5bda1df702eaefbceaddcbed26a71c345a313 --- tensorflow/lite/BUILD | 66 +++++++++-------------------------- tensorflow/lite/kernels/BUILD | 4 +-- 2 files changed, 18 insertions(+), 52 deletions(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 9c4740b8c0a..5e22b1fed5c 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -62,22 +62,6 @@ TFLITE_DEFAULT_COPTS = if_not_windows([ "-Wno-extern-c-compat", ]) -FRAMEWORK_LIB_HDRS = [ - "allocation.h", - "context.h", - "context_util.h", - "core/macros.h", - "core/subgraph.h", - "error_reporter.h", - "graph_info.h", - "interpreter.h", - "model.h", - "mutable_op_resolver.h", - "op_resolver.h", - "optional_debug_tools.h", - "stderr_reporter.h", -] - cc_library( name = "version", hdrs = ["version.h"], @@ -216,8 +200,9 @@ cc_library( ], ) +# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. 
cc_library( - name = "framework_lib", + name = "framework", srcs = [ "core/subgraph.cc", "graph_info.cc", @@ -227,42 +212,23 @@ cc_library( "optional_debug_tools.cc", "stderr_reporter.cc", ], - hdrs = FRAMEWORK_LIB_HDRS, - copts = tflite_copts() + TFLITE_DEFAULT_COPTS, - visibility = [ - "//tensorflow/lite:__subpackages__", + hdrs = [ + "allocation.h", + "context.h", + "context_util.h", + "core/macros.h", + "core/subgraph.h", + "error_reporter.h", + "graph_info.h", + "interpreter.h", + "model.h", + "mutable_op_resolver.h", + "op_resolver.h", + "optional_debug_tools.h", + "stderr_reporter.h", ], - deps = [ - ":allocation", - ":arena_planner", - ":external_cpu_backend_context", - ":graph_info", - ":memory_planner", - ":minimal_logging", - ":simple_memory_arena", - ":string", - ":type_to_tflitetype", - ":util", - ":version", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core/api", - "//tensorflow/lite/delegates/nnapi:nnapi_delegate", - "//tensorflow/lite/experimental/resource", - "//tensorflow/lite/nnapi:nnapi_implementation", - "//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - -# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. -cc_library( - name = "framework", - srcs = [ - ], - hdrs = FRAMEWORK_LIB_HDRS, copts = tflite_copts() + TFLITE_DEFAULT_COPTS, deps = [ - ":framework_lib", ":allocation", ":arena_planner", ":external_cpu_backend_context", diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 1f04cc3ee47..57e9b876ec1 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -526,7 +526,7 @@ cc_library( ":lstm_shared", ":op_macros", ":padding", - "//tensorflow/lite:framework_lib", + "//tensorflow/lite:framework", "//tensorflow/lite:minimal_logging", "//tensorflow/lite:string_util", "//tensorflow/lite/c:common", @@ -660,7 +660,7 @@ cc_library( ], deps = [ ":builtin_op_kernels", - "//tensorflow/lite:framework_lib", + "//tensorflow/lite:framework", "//tensorflow/lite/c:common", ], ) From 62b171eb9a60769b57ee9768e7ac35d760df945d Mon Sep 17 00:00:00 2001 From: Yi Situ Date: Tue, 17 Mar 2020 20:07:26 -0700 Subject: [PATCH 120/492] Update list of TensorCore eligible operations. PiperOrigin-RevId: 301507891 Change-Id: If5e27d3bddf2473d1cc341d3b150ebd1384f1537 --- .../core/profiler/utils/kernel_stats_utils.cc | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc index 665e802229d..3921c7d6aab 100644 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc +++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc @@ -100,8 +100,32 @@ bool IsKernelUsingTensorCore(absl::string_view kernel_name) { // This list is not exhaustive. bool IsOpTensorCoreEligible(absl::string_view tf_op_name) { - return (absl::StrContains(tf_op_name, "Conv") || - absl::StrContains(tf_op_name, "Einsum")); + // Disable formatting to keep inline comments vertically aligned. + // clang-format off + return false + // Using EndsWith to match Fused operations. 
+ || absl::EndsWith(tf_op_name, "Conv2D") + || absl::EndsWith(tf_op_name, "Conv2DBackpropFilter") + || absl::EndsWith(tf_op_name, "Conv2DBackpropInput") + || absl::EndsWith(tf_op_name, "Conv3D") + || absl::EndsWith(tf_op_name, "DepthwiseConv2dNative") + || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropFilter") + || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropInput") + // Using Contains because of numeric suffix and possible Xla prefix. + || absl::StrContains(tf_op_name, "Einsum") + // Using Contains to match V2/V3 suffixes. + || absl::StrContains(tf_op_name, "BatchMatMul") + // MatMul requires exact matching. + || absl::EndsWith(tf_op_name, "/MatMul") + || absl::EndsWith(tf_op_name, "FusedMatMul") + // cuDNN operations. + || absl::EndsWith(tf_op_name, "/CudnnRNN") + || absl::StrContains(tf_op_name, "CudnnRNNV") + || absl::StrContains(tf_op_name, "CudnnRNNForward") + || absl::StrContains(tf_op_name, "CudnnRNNBackprop") + // Special cases. + || absl::EndsWith(tf_op_name, "XlaDot"); + // clang-format on } bool KernelReportLessThanComparator::operator()(const KernelReport& lhs, From 34af8d45d0cece0dce479acb3b0e0281228c1ac0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 21:09:10 -0700 Subject: [PATCH 121/492] Re-enable a test that should be working now PiperOrigin-RevId: 301515524 Change-Id: If70f66ed1e271323fe92c17ef870f44c92b5a964 --- tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py index ede1dc21618..b6e26f44ffa 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py @@ -803,7 +803,6 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase): self.assertAllClose(hist_k_v1.history['loss'], hist_k_v2.history['loss']) def testNumericEquivalenceForAmsgrad(self): - self.skipTest('b/150382655') if context.executing_eagerly(): self.skipTest( 'v1 optimizer does not run in eager mode') From 98cfe329722976c41191212fddc894d9f664b7f1 Mon Sep 17 00:00:00 2001 From: Fabio Di Domenico Date: Wed, 18 Mar 2020 06:24:49 +0200 Subject: [PATCH 122/492] Minor naming change --- tensorflow/lite/c/c_api.cc | 2 +- tensorflow/lite/c/c_api_experimental.cc | 2 +- tensorflow/lite/c/c_api_internal.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/c/c_api.cc b/tensorflow/lite/c/c_api.cc index 831dcc10286..aa93a10302c 100644 --- a/tensorflow/lite/c/c_api.cc +++ b/tensorflow/lite/c/c_api.cc @@ -123,7 +123,7 @@ TfLiteInterpreter* TfLiteInterpreterCreate( } if (optional_options) { - interpreter->UseNNAPI(optional_options->useNNAPI); + interpreter->UseNNAPI(optional_options->use_nnapi); if (optional_options->num_threads != TfLiteInterpreterOptions::kDefaultNumThreads) { diff --git a/tensorflow/lite/c/c_api_experimental.cc b/tensorflow/lite/c/c_api_experimental.cc index e934d7fede9..cff1b3d1530 100644 --- a/tensorflow/lite/c/c_api_experimental.cc +++ b/tensorflow/lite/c/c_api_experimental.cc @@ -52,7 +52,7 @@ void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options, void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options, bool enable) { - options->useNNAPI = enable; + options->use_nnapi = enable; } #ifdef __cplusplus diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h index ce07f16c33d..f13712362a6 100644 --- 
a/tensorflow/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -50,7 +50,7 @@ struct TfLiteInterpreterOptions { std::vector delegates; - bool useNNAPI = false; + bool use_nnapi = false; }; struct TfLiteInterpreter { From 9b6fd77bd6705753f7927eb4ccd95e71733b1eed Mon Sep 17 00:00:00 2001 From: feihugis Date: Mon, 16 Mar 2020 14:31:44 -0500 Subject: [PATCH 123/492] Refactor DirectedInterleaveDatasetOp --- .../core/kernels/data/experimental/BUILD | 17 + .../directed_interleave_dataset_op.cc | 482 +++++++++--------- .../directed_interleave_dataset_op.h | 47 ++ 3 files changed, 311 insertions(+), 235 deletions(-) create mode 100644 tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 298982eb356..0359899eac1 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -135,14 +135,31 @@ tf_kernel_library( tf_kernel_library( name = "directed_interleave_dataset_op", srcs = ["directed_interleave_dataset_op.cc"], + hdrs = ["directed_interleave_dataset_op.h"], deps = [ "//tensorflow/core:experimental_dataset_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels/data:name_utils", "//third_party/eigen3", ], ) +tf_cc_test( + name = "directed_interleave_dataset_op_test", + size = "small", + srcs = ["directed_interleave_dataset_op_test.cc"], + deps = [ + ":directed_interleave_dataset_op", + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels/data:dataset_test_base", + "//tensorflow/core/kernels/data:range_dataset_op", + "//tensorflow/core/kernels/data:tensor_slice_dataset_op", + ], +) + tf_kernel_library( name = "group_by_reducer_dataset_op", srcs = ["group_by_reducer_dataset_op.cc"], diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index 48a446be42c..575b2e4ebeb 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -12,284 +12,296 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h" + #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/name_utils.h" #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { namespace data { namespace experimental { -namespace { -class DirectedInterleaveDatasetOp : public DatasetOpKernel { +/* static */ constexpr const char* const + DirectedInterleaveDatasetOp::kDatasetType; +/* static */ constexpr const char* const + DirectedInterleaveDatasetOp::kSelectorInputDataset; +/* static */ constexpr const char* const + DirectedInterleaveDatasetOp::kDataInputDatasets; +/* static */ constexpr const char* const + DirectedInterleaveDatasetOp::kOutputTypes; +/* static */ constexpr const char* const + DirectedInterleaveDatasetOp::kOutputShapes; +/* static */ constexpr const char* const + DirectedInterleaveDatasetOp::kNumDatasets; + +class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { public: - explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx) - : DatasetOpKernel(ctx) {} + Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, + std::vector data_inputs) + : DatasetBase(DatasetContext(ctx)), + selector_input_(selector_input), + data_inputs_(std::move(data_inputs)) { + selector_input_->Ref(); - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - DatasetBase* selector_input; - OP_REQUIRES_OK(ctx, - GetDatasetFromVariantTensor(ctx->input(0), &selector_input)); - - OP_REQUIRES( - ctx, - selector_input->output_dtypes().size() == 1 && - selector_input->output_dtypes()[0] == DT_INT64 && - selector_input->output_shapes().size() == 1 && - selector_input->output_shapes()[0].IsCompatibleWith( - PartialTensorShape({})), - errors::InvalidArgument( - "The selector input must be a dataset of scalar int64 elements.")); - - std::vector data_inputs; - for (size_t i = 1; i < ctx->num_inputs(); ++i) { - DatasetBase* input; - OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); - data_inputs.push_back(input); - - OP_REQUIRES( - ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(), - errors::InvalidArgument( - "All inputs must have the same output_dtypes. 
First input " - "has types ", - DataTypeVectorString(data_inputs[0]->output_dtypes()), - ", and input ", i - 1, " has types ", - DataTypeVectorString(input->output_dtypes()))); + output_shapes_ = data_inputs_[0]->output_shapes(); + data_inputs_[0]->Ref(); + for (size_t i = 1; i < data_inputs_.size(); ++i) { + const DatasetBase* data_input = data_inputs_[i]; + data_input->Ref(); + for (size_t j = 0; j < output_shapes_.size(); ++j) { + output_shapes_[j] = MostSpecificCompatibleShape( + output_shapes_[j], data_input->output_shapes()[j]); + } } - *output = new Dataset(ctx, selector_input, std::move(data_inputs)); + } + + ~Dataset() override { + selector_input_->Unref(); + for (DatasetBase* data_input : data_inputs_) { + data_input->Unref(); + } + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique(Iterator::Params{ + this, name_utils::IteratorPrefix(kDatasetType, prefix)}); + } + + const DataTypeVector& output_dtypes() const override { + return data_inputs_[0]->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return name_utils::DatasetDebugString(kDatasetType); + } + + Status CheckExternalState() const override { + for (const auto& input : data_inputs_) { + TF_RETURN_IF_ERROR(input->CheckExternalState()); + } + return selector_input_->CheckExternalState(); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* selector_input_node; + TF_RETURN_IF_ERROR( + b->AddInputDataset(ctx, selector_input_, &selector_input_node)); + std::vector data_input_nodes(data_inputs_.size()); + for (size_t i = 0; i < data_inputs_.size(); ++i) { + TF_RETURN_IF_ERROR( + b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i])); + } + TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, + {{1, data_input_nodes}}, {}, output)); + return Status::OK(); } private: - class Dataset : public DatasetBase { + class Iterator : public DatasetIterator { public: - Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, - std::vector data_inputs) - : DatasetBase(DatasetContext(ctx)), - selector_input_(selector_input), - data_inputs_(std::move(data_inputs)) { - selector_input_->Ref(); + explicit Iterator(const Params& params) + : DatasetIterator(params), + num_active_inputs_(params.dataset->data_inputs_.size()) {} - output_shapes_ = data_inputs_[0]->output_shapes(); - data_inputs_[0]->Ref(); - for (size_t i = 1; i < data_inputs_.size(); ++i) { - const DatasetBase* data_input = data_inputs_[i]; - data_input->Ref(); - for (size_t j = 0; j < output_shapes_.size(); ++j) { - output_shapes_[j] = MostSpecificCompatibleShape( - output_shapes_[j], data_input->output_shapes()[j]); - } + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, this, prefix(), &selector_input_impl_)); + data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, this, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); } - } - - ~Dataset() override { - selector_input_->Unref(); - for (DatasetBase* data_input : data_inputs_) { - data_input->Unref(); - } - } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) 
const override { - return absl::make_unique(Iterator::Params{ - this, strings::StrCat(prefix, "::DirectedInterleave")}); - } - - const DataTypeVector& output_dtypes() const override { - return data_inputs_[0]->output_dtypes(); - } - - const std::vector& output_shapes() const override { - return output_shapes_; - } - - string DebugString() const override { - return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); - } - - Status CheckExternalState() const override { - for (const auto& input : data_inputs_) { - TF_RETURN_IF_ERROR(input->CheckExternalState()); - } - return selector_input_->CheckExternalState(); - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* selector_input_node; - TF_RETURN_IF_ERROR( - b->AddInputDataset(ctx, selector_input_, &selector_input_node)); - std::vector data_input_nodes(data_inputs_.size()); - for (size_t i = 0; i < data_inputs_.size(); ++i) { - TF_RETURN_IF_ERROR( - b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i])); - } - TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, - {{1, data_input_nodes}}, {}, output)); return Status::OK(); } - private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params), - num_active_inputs_(params.dataset->data_inputs_.size()) {} - - Status Initialize(IteratorContext* ctx) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( - ctx, this, strings::StrCat(prefix()), &selector_input_impl_)); - data_input_impls_.resize(dataset()->data_inputs_.size()); - for (size_t i = 0; i < data_input_impls_.size(); ++i) { - const DatasetBase* data_input = dataset()->data_inputs_[i]; - TF_RETURN_IF_ERROR(data_input->MakeIterator( - ctx, this, strings::StrCat(prefix(), "[", i, "]"), - &data_input_impls_[i])); - } + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (!selector_input_impl_) { + *end_of_sequence = true; return Status::OK(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - if (!selector_input_impl_) { - *end_of_sequence = true; + while (true) { + std::vector selector_result; + *end_of_sequence = false; + TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(ctx, &selector_result, + end_of_sequence)); + if (*end_of_sequence) { + selector_input_impl_.reset(); + for (auto& data_input_impl : data_input_impls_) { + data_input_impl.reset(); + } return Status::OK(); } - while (true) { - std::vector selector_result; - *end_of_sequence = false; - TF_RETURN_IF_ERROR(selector_input_impl_->GetNext( - ctx, &selector_result, end_of_sequence)); - if (*end_of_sequence) { - selector_input_impl_.reset(); - for (auto& data_input_impl : data_input_impls_) { - data_input_impl.reset(); - } + int64 selected_input = selector_result[0].scalar()(); + if (selected_input < 0 || selected_input >= data_input_impls_.size()) { + return errors::InvalidArgument( + "Selector index out of range: ", selected_input, + " >= ", data_input_impls_.size()); + } + + if (data_input_impls_[selected_input]) { + bool end_of_selected_input = false; + TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext( + ctx, out_tensors, &end_of_selected_input)); + + if (!end_of_selected_input) { return Status::OK(); } - int64 selected_input = selector_result[0].scalar()(); - if 
(selected_input < 0 || - selected_input >= data_input_impls_.size()) { - return errors::InvalidArgument( - "Selector index out of range: ", selected_input, - " >= ", data_input_impls_.size()); - } + data_input_impls_[selected_input].reset(); + --num_active_inputs_; - if (data_input_impls_[selected_input]) { - bool end_of_selected_input = false; - TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext( - ctx, out_tensors, &end_of_selected_input)); - - if (!end_of_selected_input) { - return Status::OK(); - } - - data_input_impls_[selected_input].reset(); - --num_active_inputs_; - - if (num_active_inputs_ == 0) { - selector_input_impl_.reset(); - *end_of_sequence = true; - return Status::OK(); - } - } - - VLOG(2) << "DirectedInterleave selected an exhausted input: " - << selected_input; - } - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeInterleaveManyNode(std::move(args)); - } - - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { - mutex_lock l(mu_); - if (selector_input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_)); - } else { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("selector_input_impl_empty"), "")); - } - for (size_t i = 0; i < data_input_impls_.size(); ++i) { - const auto& data_input_impl = data_input_impls_[i]; - if (data_input_impl) { - TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl)); - } else { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("data_input_impl_empty[", i, "]")), - "")); + if (num_active_inputs_ == 0) { + selector_input_impl_.reset(); + *end_of_sequence = true; + return Status::OK(); } } - return Status::OK(); - } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (!reader->Contains(full_name("selector_input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); - } else { - selector_input_impl_.reset(); - } - for (size_t i = 0; i < data_input_impls_.size(); ++i) { - if (!reader->Contains(full_name( - strings::StrCat("data_input_impl_empty[", i, "]")))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); - } else { - data_input_impls_[i].reset(); - } - } - return Status::OK(); + VLOG(2) << "DirectedInterleave selected an exhausted input: " + << selected_input; } - - private: - mutex mu_; - std::unique_ptr selector_input_impl_ TF_GUARDED_BY(mu_); - std::vector> data_input_impls_ - TF_GUARDED_BY(mu_); - int64 num_active_inputs_ TF_GUARDED_BY(mu_); - }; - - static PartialTensorShape MostSpecificCompatibleShape( - const PartialTensorShape& ts1, const PartialTensorShape& ts2) { - PartialTensorShape output_tensorshape; - if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) - return output_tensorshape; - auto dims1 = ts1.dim_sizes(); - auto dims2 = ts2.dim_sizes(); - for (int d = 0; d < ts1.dims(); d++) { - if (dims1[d] == dims2[d]) - output_tensorshape.Concatenate(dims1[d]); - else - output_tensorshape.Concatenate(-1); - } - return output_tensorshape; } - const DatasetBase* const selector_input_; - const std::vector data_inputs_; - std::vector output_shapes_; - }; + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeInterleaveManyNode(std::move(args)); + } + + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + mutex_lock 
l(mu_); + if (selector_input_impl_) { + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_)); + } else { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("data_input_impl_empty[", i, "]")), "")); + } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const auto& data_input_impl = data_input_impls_[i]; + if (data_input_impl) { + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl)); + } else { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("data_input_impl_empty[", i, "]")), + "")); + } + } + return Status::OK(); + } + return Status::OK(); + } + + Status + RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("selector_input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); + } else { + selector_input_impl_.reset(); + } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + if (!reader->Contains( + full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); + } else { + data_input_impls_[i].reset(); + } + } + return Status::OK(); + } + + private: + mutex mu_; + std::unique_ptr selector_input_impl_ TF_GUARDED_BY(mu_); + std::vector> data_input_impls_ + TF_GUARDED_BY(mu_); + int64 num_active_inputs_ TF_GUARDED_BY(mu_); }; +static PartialTensorShape MostSpecificCompatibleShape( + const PartialTensorShape& ts1, const PartialTensorShape& ts2) { + PartialTensorShape output_tensorshape; + if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) + return output_tensorshape; + auto dims1 = ts1.dim_sizes(); + auto dims2 = ts2.dim_sizes(); + for (int d = 0; d < ts1.dims(); ++d) { + if (dims1[d] == dims2[d]) + output_tensorshape.Concatenate(dims1[d]); + else + output_tensorshape.Concatenate(-1); + } + return output_tensorshape; +} + +const DatasetBase* const selector_input_; +const std::vector data_inputs_; +std::vector output_shapes_; +}; // namespace experimental + +DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp( + OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) {} + +void DirectedInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx, + DatasetBase** output) { + DatasetBase* selector_input; + OP_REQUIRES_OK(ctx, + GetDatasetFromVariantTensor(ctx->input(0), &selector_input)); + + OP_REQUIRES( + ctx, + selector_input->output_dtypes().size() == 1 && + selector_input->output_dtypes()[0] == DT_INT64 && + selector_input->output_shapes().size() == 1 && + selector_input->output_shapes()[0].IsCompatibleWith( + PartialTensorShape({})), + errors::InvalidArgument( + "The selector input must be a dataset of scalar int64 elements.")); + + std::vector data_inputs; + for (size_t i = 1; i < ctx->num_inputs(); ++i) { + DatasetBase* input; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); + data_inputs.push_back(input); + + OP_REQUIRES(ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(), + errors::InvalidArgument( + "All inputs must have the same output_dtypes. 
First input " + "has types ", + DataTypeVectorString(data_inputs[0]->output_dtypes()), + ", and input ", i - 1, " has types ", + DataTypeVectorString(input->output_dtypes()))); + } + *output = new Dataset(ctx, selector_input, std::move(data_inputs)); +} + +namespace { REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU), DirectedInterleaveDatasetOp); REGISTER_KERNEL_BUILDER( Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU), DirectedInterleaveDatasetOp); - } // namespace -} // namespace experimental } // namespace data } // namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h new file mode 100644 index 00000000000..03ee8ed0c3f --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class DirectedInterleaveDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "DirectedInterleave"; + static constexpr const char* const kSelectorInputDataset = + "selector_input_dataset"; + static constexpr const char* const kDataInputDatasets = "data_input_datasets"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kNumDatasets = "N"; + + explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_ From 6a580fd9ac1747508bf174fdb920e21127648573 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 17 Mar 2020 21:46:55 -0700 Subject: [PATCH 124/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301519328 Change-Id: I95d5acbc2f5009f08a546e05cac3afea849830b3 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 6456f104ad3..52a9bf9551b 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From a8f9799bdaf71474b0d0627df3e9c4019767277b Mon Sep 17 00:00:00 2001 From: feihugis Date: Tue, 17 Mar 2020 23:59:12 -0500 Subject: [PATCH 125/492] Add tests for DirectedInterleaveDatasetOp --- .../directed_interleave_dataset_op.cc | 86 ++--- .../directed_interleave_dataset_op.h | 2 +- .../directed_interleave_dataset_op_test.cc | 364 ++++++++++++++++++ 3 files changed, 407 insertions(+), 45 deletions(-) create mode 100644 tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index 575b2e4ebeb..eea5ae6ea69 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -34,7 +34,7 @@ namespace experimental { /* static */ constexpr const char* const DirectedInterleaveDatasetOp::kOutputShapes; /* static */ constexpr const char* const - DirectedInterleaveDatasetOp::kNumDatasets; + DirectedInterleaveDatasetOp::kNumInputDatasets; class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { public: @@ -192,8 +192,8 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { if (selector_input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_)); } else { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("data_input_impl_empty[", i, "]")), "")); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("selector_input_impl_empty"), "")); } for (size_t i = 0; i < data_input_impls_.size(); ++i) { const auto& data_input_impl = data_input_impls_[i]; @@ -207,55 +207,53 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { } return Status::OK(); } - return Status::OK(); - } - Status - RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (!reader->Contains(full_name("selector_input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); - } else { - selector_input_impl_.reset(); - } - for (size_t i = 0; i < data_input_impls_.size(); ++i) { - if (!reader->Contains( - full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("selector_input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); } else { - data_input_impls_[i].reset(); + selector_input_impl_.reset(); } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + if (!reader->Contains( + full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); + } else { + data_input_impls_[i].reset(); + } + } + return Status::OK(); } - return Status::OK(); - } - private: - mutex mu_; - std::unique_ptr selector_input_impl_ TF_GUARDED_BY(mu_); - std::vector> data_input_impls_ - TF_GUARDED_BY(mu_); - int64 
num_active_inputs_ TF_GUARDED_BY(mu_); -}; + private: + mutex mu_; + std::unique_ptr selector_input_impl_ TF_GUARDED_BY(mu_); + std::vector> data_input_impls_ + TF_GUARDED_BY(mu_); + int64 num_active_inputs_ TF_GUARDED_BY(mu_); + }; -static PartialTensorShape MostSpecificCompatibleShape( - const PartialTensorShape& ts1, const PartialTensorShape& ts2) { - PartialTensorShape output_tensorshape; - if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) + static PartialTensorShape MostSpecificCompatibleShape( + const PartialTensorShape& ts1, const PartialTensorShape& ts2) { + PartialTensorShape output_tensorshape; + if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) + return output_tensorshape; + auto dims1 = ts1.dim_sizes(); + auto dims2 = ts2.dim_sizes(); + for (int d = 0; d < ts1.dims(); ++d) { + if (dims1[d] == dims2[d]) + output_tensorshape.Concatenate(dims1[d]); + else + output_tensorshape.Concatenate(-1); + } return output_tensorshape; - auto dims1 = ts1.dim_sizes(); - auto dims2 = ts2.dim_sizes(); - for (int d = 0; d < ts1.dims(); ++d) { - if (dims1[d] == dims2[d]) - output_tensorshape.Concatenate(dims1[d]); - else - output_tensorshape.Concatenate(-1); } - return output_tensorshape; -} -const DatasetBase* const selector_input_; -const std::vector data_inputs_; -std::vector output_shapes_; + const DatasetBase* const selector_input_; + const std::vector data_inputs_; + std::vector output_shapes_; }; // namespace experimental DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp( @@ -302,6 +300,6 @@ REGISTER_KERNEL_BUILDER( Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU), DirectedInterleaveDatasetOp); } // namespace +} // namespace experimental } // namespace data } // namespace tensorflow -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h index 03ee8ed0c3f..3dc689ea63b 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h @@ -29,7 +29,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { static constexpr const char* const kDataInputDatasets = "data_input_datasets"; static constexpr const char* const kOutputTypes = "output_types"; static constexpr const char* const kOutputShapes = "output_shapes"; - static constexpr const char* const kNumDatasets = "N"; + static constexpr const char* const kNumInputDatasets = "N"; explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx); diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc new file mode 100644 index 00000000000..7aed1d7be2f --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc @@ -0,0 +1,364 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h" + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace experimental { +namespace { + +constexpr char kNodeName[] = "directed_interleave_dataset"; + +class DirectedInterleaveDatasetParams : public DatasetParams { + public: + template + DirectedInterleaveDatasetParams(S selector_input_dataset_params, + std::vector input_dataset_params_vec, + DataTypeVector output_dtypes, + std::vector output_shapes, + int num_input_datasets, string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + std::move(node_name)), + num_input_datasets_(num_input_datasets) { + input_dataset_params_.push_back( + absl::make_unique(selector_input_dataset_params)); + for (auto input_dataset_params : input_dataset_params_vec) { + input_dataset_params_.push_back( + absl::make_unique(input_dataset_params)); + } + + if (!input_dataset_params_vec.empty()) { + iterator_prefix_ = name_utils::IteratorPrefix( + input_dataset_params_vec[0].dataset_type(), + input_dataset_params_vec[0].iterator_prefix()); + } + } + + std::vector GetInputTensors() const override { return {}; } + + Status GetInputNames(std::vector* input_names) const override { + input_names->clear(); + input_names->emplace_back( + DirectedInterleaveDatasetOp::kSelectorInputDataset); + for (int i = 0; i < num_input_datasets_; ++i) { + input_names->emplace_back(absl::StrCat( + DirectedInterleaveDatasetOp::kDataInputDatasets, "_", i)); + } + return Status::OK(); + } + + Status GetAttributes(AttributeVector* attr_vector) const override { + attr_vector->clear(); + attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputTypes, + output_dtypes_); + attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputShapes, + output_shapes_); + attr_vector->emplace_back(DirectedInterleaveDatasetOp::kNumInputDatasets, + num_input_datasets_); + return Status::OK(); + } + + string dataset_type() const override { + return DirectedInterleaveDatasetOp::kDatasetType; + } + + private: + int32 num_input_datasets_; +}; + +class DirectedInterleaveDatasetOpTest : public DatasetOpsTestBase {}; + +DirectedInterleaveDatasetParams AlternateInputsParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams SelectExhaustedInputParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 2, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + 
/*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams OneInputDatasetParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 0, 0, 0, 0, 0})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 6, 1)}, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({})}, + /*num_input_datasets=*/1, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams ZeroInputDatasetParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 0, 0, 0, 0, 0})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/std::vector{}, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({})}, + /*num_input_datasets=*/0, + /*node_name=*/kNodeName); +} + +// Test case: `num_input_datasets` is larger than the size of +// `input_dataset_params_vec`. +DirectedInterleaveDatasetParams LargeNumInputDatasetsParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/5, + /*node_name=*/kNodeName); +} + +// Test case: `num_input_datasets` is smaller than the size of +// `input_dataset_params_vec`. 
+DirectedInterleaveDatasetParams SmallNumInputDatasetsParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/1, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams InvalidSelectorOuputDataType() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams InvalidSelectorOuputShape() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6, 1}, + {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams InvalidSelectorValues() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {2, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams InvalidInputDatasetsDataType() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{ + RangeDatasetParams(0, 3, 1, {DT_INT32}), + RangeDatasetParams(10, 13, 1, {DT_INT64})}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +std::vector> +GetNextTestCases() { + return {{/*dataset_params=*/AlternateInputsParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}, + {/*dataset_params=*/SelectExhaustedInputParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {10}, {1}, {11}, {12}})}}, + {/*dataset_params=*/OneInputDatasetParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}}, + {/*dataset_params=*/LargeNumInputDatasetsParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), 
{{0}, {10}, {1}, {11}, {2}, {12}})}}, + {/*dataset_params=*/SmallNumInputDatasetsParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}}; +} + +ITERATOR_GET_NEXT_TEST_P(DirectedInterleaveDatasetOpTest, + DirectedInterleaveDatasetParams, GetNextTestCases()) + +TEST_F(DirectedInterleaveDatasetOpTest, DatasetNodeName) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name())); +} + +TEST_F(DirectedInterleaveDatasetOpTest, DatasetTypeString) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetTypeString( + name_utils::OpName(DirectedInterleaveDatasetOp::kDatasetType))); +} + +TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputDtypes) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64})); +} + +TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputShapes) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})})); +} + +TEST_F(DirectedInterleaveDatasetOpTest, Cardinality) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetCardinality(kUnknownCardinality)); +} + +TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputDtypes) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64})); +} + +TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputShapes) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckIteratorOutputShapes({PartialTensorShape({})})); +} + +TEST_F(DirectedInterleaveDatasetOpTest, IteratorPrefix) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckIteratorPrefix( + name_utils::IteratorPrefix(DirectedInterleaveDatasetOp::kDatasetType, + dataset_params.iterator_prefix()))); +} + +std::vector> +IteratorSaveAndRestoreTestCases() { + return { + {/*dataset_params=*/AlternateInputsParams(), + /*breakpoints=*/{0, 5, 8}, + /*expected_outputs=*/ + CreateTensors(TensorShape{}, {{0}, {10}, {1}, {11}, {2}, {12}}), + /*compare_order=*/true}, + {/*dataset_params=*/SelectExhaustedInputParams(), + /*breakpoints=*/{0, 4, 8}, + /*expected_outputs=*/ + CreateTensors(TensorShape{}, {{0}, {10}, {1}, {11}, {12}}), + /*compare_order=*/true}, + {/*dataset_params=*/OneInputDatasetParams(), + /*breakpoints=*/{0, 5, 8}, + /*expected_outputs=*/ + {CreateTensors(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}}, + {/*dataset_params=*/LargeNumInputDatasetsParams(), + /*breakpoints=*/{0, 5, 8}, + /*expected_outputs=*/ + {CreateTensors(TensorShape({}), + {{0}, {10}, {1}, {11}, {2}, {12}})}}, + {/*dataset_params=*/SmallNumInputDatasetsParams(), + /*breakpoints=*/{0, 5, 8}, + /*expected_outputs=*/ + {CreateTensors(TensorShape({}), + {{0}, {10}, {1}, {11}, {2}, {12}})}}}; +} + +ITERATOR_SAVE_AND_RESTORE_TEST_P(DirectedInterleaveDatasetOpTest, + DirectedInterleaveDatasetParams, + IteratorSaveAndRestoreTestCases()) + +TEST_F(DirectedInterleaveDatasetOpTest, InvalidArguments) { + std::vector invalid_params_vec = { + InvalidSelectorOuputDataType(), InvalidSelectorOuputShape(), + InvalidInputDatasetsDataType(), 
ZeroInputDatasetParams()}; + for (auto& dataset_params : invalid_params_vec) { + EXPECT_EQ(Initialize(dataset_params).code(), + tensorflow::error::INVALID_ARGUMENT); + } +} + +TEST_F(DirectedInterleaveDatasetOpTest, InvalidSelectorValues) { + auto dataset_params = InvalidSelectorValues(); + TF_ASSERT_OK(Initialize(dataset_params)); + bool end_of_sequence = false; + std::vector next; + EXPECT_EQ( + iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence).code(), + tensorflow::error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace experimental +} // namespace data +} // namespace tensorflow From 4aad621976d7ea528d7e4f94bc18fc9c907796f5 Mon Sep 17 00:00:00 2001 From: feihugis Date: Wed, 18 Mar 2020 00:00:21 -0500 Subject: [PATCH 126/492] Throw a warning instead of error when CheckOpKernelInput() fails because the number of kernel input tensors is allowed to mismatch with input names --- tensorflow/core/kernels/data/dataset_test_base.cc | 15 +++++++++------ .../directed_interleave_dataset_op.cc | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 9ce29ddd0d5..67881827d71 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -25,7 +25,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -83,6 +82,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" namespace tensorflow { namespace data { @@ -321,7 +321,10 @@ Status DatasetOpsTestBase::CreateDatasetContext( gtl::InlinedVector* const inputs, std::unique_ptr* dataset_context_params, std::unique_ptr* dataset_context) { - TF_RETURN_IF_ERROR(CheckOpKernelInput(*dateset_kernel, *inputs)); + Status status = CheckOpKernelInput(*dateset_kernel, *inputs); + if (!status.ok()) { + VLOG(0) << "WARNING: " << status.ToString(); + } TF_RETURN_IF_ERROR(CreateOpKernelContext( dateset_kernel, inputs, dataset_context_params, dataset_context)); return Status::OK(); @@ -529,10 +532,10 @@ Status DatasetOpsTestBase::CreateSerializationContext( Status DatasetOpsTestBase::CheckOpKernelInput( const OpKernel& kernel, const gtl::InlinedVector& inputs) { - if (kernel.input_types().size() != inputs.size()) { - return errors::Internal("The number of input elements should be ", - kernel.input_types().size(), - ", but got: ", inputs.size()); + if (kernel.num_inputs() != inputs.size()) { + return errors::InvalidArgument("The number of input elements should be ", + kernel.num_inputs(), + ", but got: ", inputs.size()); } return Status::OK(); } diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index eea5ae6ea69..6e52f74a336 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -254,7 +254,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { 
const DatasetBase* const selector_input_; const std::vector data_inputs_; std::vector output_shapes_; -}; // namespace experimental +}; DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp( OpKernelConstruction* ctx) From 2d11ad74ca395bbc9949f81d84c77560692c484b Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Wed, 18 Mar 2020 00:00:58 -0700 Subject: [PATCH 127/492] Remove the redundant GetDelegates function as we could now get a delegate instance from registered delegate providers. PiperOrigin-RevId: 301533071 Change-Id: Ia2a5c80523e5f6d1246898aafaefa20228b2bdb6 --- .../tools/benchmark/benchmark_tflite_model.cc | 30 +++++++------------ .../tools/benchmark/benchmark_tflite_model.h | 7 ----- 2 files changed, 10 insertions(+), 27 deletions(-) diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 6483559a6f6..35a5f6f16ca 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -615,11 +615,14 @@ TfLiteStatus BenchmarkTfLiteModel::Init() { interpreter_->UseNNAPI(params_.Get("use_legacy_nnapi")); interpreter_->SetAllowFp16PrecisionForFp32(params_.Get("allow_fp16")); - delegates_ = GetDelegates(); - for (const auto& delegate : delegates_) { - if (interpreter_->ModifyGraphWithDelegate(delegate.second.get()) != - kTfLiteOk) { - TFLITE_LOG(ERROR) << "Failed to apply " << delegate.first << " delegate."; + for (const auto& delegate_provider : GetRegisteredDelegateProviders()) { + auto delegate = delegate_provider->CreateTfLiteDelegate(params_); + // It's possible that a delegate of certain type won't be created as + // user-specified benchmark params tells not to. + if (delegate == nullptr) continue; + if (interpreter_->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) { + TFLITE_LOG(ERROR) << "Failed to apply " << delegate_provider->GetName() + << " delegate."; return kTfLiteError; } else { bool fully_delegated = true; @@ -629,7 +632,7 @@ TfLiteStatus BenchmarkTfLiteModel::Init() { int first_node_id = interpreter_->execution_plan()[0]; const TfLiteNode first_node = interpreter_->node_and_registration(first_node_id)->first; - if (delegate.second.get() != first_node.delegate) { + if (delegate.get() != first_node.delegate) { fully_delegated = false; } } @@ -639,7 +642,7 @@ TfLiteStatus BenchmarkTfLiteModel::Init() { } const std::string delegate_status = fully_delegated ? 
"completely" : "partially"; - TFLITE_LOG(INFO) << "Applied " << delegate.first + TFLITE_LOG(INFO) << "Applied " << delegate_provider->GetName() << " delegate, and the model graph will be " << delegate_status << " executed w/ the delegate."; } @@ -698,19 +701,6 @@ TfLiteStatus BenchmarkTfLiteModel::LoadModel() { return kTfLiteOk; } -BenchmarkTfLiteModel::TfLiteDelegatePtrMap BenchmarkTfLiteModel::GetDelegates() - const { - TfLiteDelegatePtrMap delegates; - for (const auto& delegate_util : GetRegisteredDelegateProviders()) { - auto delegate = delegate_util->CreateTfLiteDelegate(params_); - if (delegate != nullptr) { - delegates.emplace(delegate_util->GetName(), std::move(delegate)); - } - } - - return delegates; -} - std::unique_ptr BenchmarkTfLiteModel::GetOpResolver() const { auto resolver = new tflite::ops::builtin::BuiltinOpResolver(); diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h index 73082c01be6..16d5c08ac44 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h @@ -69,11 +69,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel { int64_t MayGetModelFileSize() override; - // Allow subclasses to create custom delegates to be applied during init. - using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr; - using TfLiteDelegatePtrMap = std::map; - virtual TfLiteDelegatePtrMap GetDelegates() const; - virtual TfLiteStatus LoadModel(); // Allow subclasses to create a customized Op resolver during init. @@ -123,8 +118,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel { std::vector inputs_data_; std::unique_ptr profiling_listener_ = nullptr; std::unique_ptr ruy_profiling_listener_ = nullptr; - TfLiteDelegatePtrMap delegates_; - std::mt19937 random_engine_; }; From c2f7bee60493424147af553d965a26aada5bf426 Mon Sep 17 00:00:00 2001 From: sunchenggen Date: Wed, 18 Mar 2020 15:43:21 +0800 Subject: [PATCH 128/492] fix a bug in AddSymbolicGradients --- tensorflow/cc/framework/gradients.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index bd225c95f7c..8dfdd01318d 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -521,15 +521,15 @@ Status SymbolicGradientBuilder::AddGradients() { // gradient function to the src node/output to which it should be // backpropped. Maybe grad functions can return a vector of Output pairs to // make this association explicit. - size_t dx_index = 0; for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) continue; - if (dx_index == dx.size()) { + int dx_index = e->dst_input(); + if (dx_index >= dx.size()) { return errors::Internal( "Invalid gradient output index: ", dx_index, " size: ", dx.size()); } TF_RETURN_IF_ERROR( - BackpropAlongEdge(dx[dx_index++], {e->src(), e->src_output()})); + BackpropAlongEdge(dx[dx_index], {e->src(), e->src_output()})); } } From 56944a814851a7b4474bbe1f5b2866a75f2783c4 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 18 Mar 2020 02:03:43 -0700 Subject: [PATCH 129/492] compat: Update forward compatibility horizon to 2020-03-18 PiperOrigin-RevId: 301546941 Change-Id: Ibf494050d6159e0e33dccf7ce03a87c931e652cf --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 017404dba75..17b23b616d1 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 17) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 18) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From a602383a32a498648d932f8c232b84030645ffcd Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Wed, 18 Mar 2020 02:44:10 -0700 Subject: [PATCH 130/492] Support creating a XNNPACK delegate in model evaluation namespace, and removed unnecessary dependencies on TFLite's runtime. PiperOrigin-RevId: 301551749 Change-Id: I37dd5322093b33f269a9b2d2b6698f27b0c9be88 --- tensorflow/lite/tools/evaluation/BUILD | 26 +++++- .../evaluation_delegate_provider.cc | 80 +++++++++++++++++++ .../evaluation/evaluation_delegate_provider.h | 38 +++++++++ .../evaluation_delegate_provider_test.cc | 44 ++++++++++ .../evaluation/proto/evaluation_stages.proto | 1 + tensorflow/lite/tools/evaluation/utils.cc | 48 +++++++---- tensorflow/lite/tools/evaluation/utils.h | 26 +++--- 7 files changed, 237 insertions(+), 26 deletions(-) create mode 100644 tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc create mode 100644 tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h create mode 100644 tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc diff --git a/tensorflow/lite/tools/evaluation/BUILD b/tensorflow/lite/tools/evaluation/BUILD index 619ff0bd333..a028607482c 100644 --- a/tensorflow/lite/tools/evaluation/BUILD +++ b/tensorflow/lite/tools/evaluation/BUILD @@ -40,9 +40,9 @@ cc_library( hdrs = ["utils.h"], copts = tflite_copts(), deps = [ - "//tensorflow/lite:context", - "//tensorflow/lite:framework", + "//tensorflow/lite/c:common", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", ] + select({ "//tensorflow:android": [ "//tensorflow/lite/delegates/gpu:delegate", @@ -59,6 +59,17 @@ cc_library( }), ) +cc_library( + name = "evaluation_delegate_provider", + srcs = ["evaluation_delegate_provider.cc"], + hdrs = ["evaluation_delegate_provider.h"], + copts = tflite_copts(), + deps = [ + ":utils", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + ], +) + cc_test( name = "utils_test", srcs = ["utils_test.cc"], @@ -74,3 +85,14 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "evaluation_delegate_provider_test", + srcs = ["evaluation_delegate_provider_test.cc"], + linkopts = tflite_linkopts(), + deps = [ + ":evaluation_delegate_provider", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) diff --git 
a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc new file mode 100644 index 00000000000..925cae8d140 --- /dev/null +++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" + +namespace tflite { +namespace evaluation { +namespace { +constexpr char kNnapiDelegate[] = "nnapi"; +constexpr char kGpuDelegate[] = "gpu"; +constexpr char kHexagonDelegate[] = "hexagon"; +constexpr char kXnnpackDelegate[] = "xnnpack"; +} // namespace + +TfliteInferenceParams::Delegate ParseStringToDelegateType( + const std::string& val) { + if (val == kNnapiDelegate) return TfliteInferenceParams::NNAPI; + if (val == kGpuDelegate) return TfliteInferenceParams::GPU; + if (val == kHexagonDelegate) return TfliteInferenceParams::HEXAGON; + if (val == kXnnpackDelegate) return TfliteInferenceParams::XNNPACK; + return TfliteInferenceParams::NONE; +} + +TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params, + std::string* error_msg) { + const auto type = params.delegate(); + + switch (type) { + case TfliteInferenceParams::NNAPI: { + auto p = CreateNNAPIDelegate(); + if (!p && error_msg) *error_msg = "NNAPI not supported"; + return p; + } + case TfliteInferenceParams::GPU: { + auto p = CreateGPUDelegate(); + if (!p && error_msg) *error_msg = "GPU delegate not supported."; + return p; + } + case TfliteInferenceParams::HEXAGON: { + auto p = CreateHexagonDelegate(/*library_directory_path=*/"", + /*profiling=*/false); + if (!p && error_msg) { + *error_msg = + "Hexagon delegate is not supported on the platform or required " + "libraries are missing."; + } + return p; + } + case TfliteInferenceParams::XNNPACK: { + auto p = CreateXNNPACKDelegate(params.num_threads()); + if (!p && error_msg) *error_msg = "XNNPACK delegate not supported."; + return p; + } + case TfliteInferenceParams::NONE: + if (error_msg) *error_msg = "No delegate type is specified."; + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); + default: + if (error_msg) { + *error_msg = "Creation of delegate type: " + + TfliteInferenceParams::Delegate_Name(type) + + " not supported yet."; + } + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); + } +} + +} // namespace evaluation +} // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h new file mode 100644 index 00000000000..7f093295be2 --- /dev/null +++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_ +#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_ + +#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" +#include "tensorflow/lite/tools/evaluation/utils.h" + +namespace tflite { +namespace evaluation { + +// Parse a string 'val' to the corresponding delegate type defined by +// TfliteInferenceParams::Delegate. +TfliteInferenceParams::Delegate ParseStringToDelegateType( + const std::string& val); + +// Create a TfLite delegate based on the given TfliteInferenceParams 'params'. +// If there's an error during the creation, an error message will be recorded to +// 'error_msg' if provided. +TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params, + std::string* error_msg = nullptr); +} // namespace evaluation +} // namespace tflite + +#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_ diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc new file mode 100644 index 00000000000..1b984206eb6 --- /dev/null +++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" + +#include +#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" + +namespace tflite { +namespace evaluation { +namespace { +TEST(EvaluationDelegateProviderTest, ParseStringToDelegateType) { + EXPECT_EQ(TfliteInferenceParams::NNAPI, ParseStringToDelegateType("nnapi")); + EXPECT_EQ(TfliteInferenceParams::GPU, ParseStringToDelegateType("gpu")); + EXPECT_EQ(TfliteInferenceParams::HEXAGON, + ParseStringToDelegateType("hexagon")); + EXPECT_EQ(TfliteInferenceParams::XNNPACK, + ParseStringToDelegateType("xnnpack")); + + EXPECT_EQ(TfliteInferenceParams::NONE, ParseStringToDelegateType("Gpu")); + EXPECT_EQ(TfliteInferenceParams::NONE, ParseStringToDelegateType("Testing")); +} + +TEST(EvaluationDelegateProviderTest, CreateTfLiteDelegate) { + TfliteInferenceParams params; + params.set_delegate(TfliteInferenceParams::NONE); + // A NONE delegate type will return a nullptr TfLite delegate ptr. + EXPECT_TRUE(!CreateTfLiteDelegate(params)); +} + +} // namespace +} // namespace evaluation +} // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto b/tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto index 4b3da52c136..09765d71726 100644 --- a/tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto +++ b/tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto @@ -111,6 +111,7 @@ message TfliteInferenceParams { NNAPI = 1; GPU = 2; HEXAGON = 3; + XNNPACK = 4; } optional Delegate delegate = 2; // Number of threads available to the TFLite Interpreter. diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc index 39e93bee930..e7c477e359d 100644 --- a/tensorflow/lite/tools/evaluation/utils.cc +++ b/tensorflow/lite/tools/evaluation/utils.cc @@ -30,8 +30,8 @@ namespace evaluation { namespace { -Interpreter::TfLiteDelegatePtr CreateNullDelegate() { - return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); +TfLiteDelegatePtr CreateNullDelegate() { + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); } } // namespace @@ -94,21 +94,20 @@ TfLiteStatus GetSortedFileNames( #endif // TODO(b/138448769): Migrate delegate helper APIs to lite/testing. -Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() { +TfLiteDelegatePtr CreateNNAPIDelegate() { #if defined(__ANDROID__) - return Interpreter::TfLiteDelegatePtr( + return TfLiteDelegatePtr( NnApiDelegate(), // NnApiDelegate() returns a singleton, so provide a no-op deleter. 
[](TfLiteDelegate*) {}); #else - return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); #endif // defined(__ANDROID__) } -Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate( - StatefulNnApiDelegate::Options options) { +TfLiteDelegatePtr CreateNNAPIDelegate(StatefulNnApiDelegate::Options options) { #if defined(__ANDROID__) - return Interpreter::TfLiteDelegatePtr( + return TfLiteDelegatePtr( new StatefulNnApiDelegate(options), [](TfLiteDelegate* delegate) { delete reinterpret_cast(delegate); }); @@ -118,14 +117,13 @@ Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate( } #if defined(__ANDROID__) -Interpreter::TfLiteDelegatePtr CreateGPUDelegate( - TfLiteGpuDelegateOptionsV2* options) { - return Interpreter::TfLiteDelegatePtr(TfLiteGpuDelegateV2Create(options), - &TfLiteGpuDelegateV2Delete); +TfLiteDelegatePtr CreateGPUDelegate(TfLiteGpuDelegateOptionsV2* options) { + return TfLiteDelegatePtr(TfLiteGpuDelegateV2Create(options), + &TfLiteGpuDelegateV2Delete); } #endif // defined(__ANDROID__) -Interpreter::TfLiteDelegatePtr CreateGPUDelegate() { +TfLiteDelegatePtr CreateGPUDelegate() { #if defined(__ANDROID__) TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default(); options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY; @@ -138,7 +136,7 @@ Interpreter::TfLiteDelegatePtr CreateGPUDelegate() { #endif // defined(__ANDROID__) } -Interpreter::TfLiteDelegatePtr CreateHexagonDelegate( +TfLiteDelegatePtr CreateHexagonDelegate( const std::string& library_directory_path, bool profiling) { #if defined(__ANDROID__) && (defined(__arm__) || defined(__aarch64__)) if (library_directory_path.empty()) { @@ -155,7 +153,7 @@ Interpreter::TfLiteDelegatePtr CreateHexagonDelegate( TfLiteHexagonTearDown(); return CreateNullDelegate(); } - return Interpreter::TfLiteDelegatePtr(delegate, [](TfLiteDelegate* delegate) { + return TfLiteDelegatePtr(delegate, [](TfLiteDelegate* delegate) { TfLiteHexagonDelegateDelete(delegate); TfLiteHexagonTearDown(); }); @@ -164,5 +162,25 @@ Interpreter::TfLiteDelegatePtr CreateHexagonDelegate( #endif // defined(__ANDROID__) } +TfLiteDelegatePtr CreateXNNPACKDelegate() { + TfLiteXNNPackDelegateOptions xnnpack_options = + TfLiteXNNPackDelegateOptionsDefault(); + return CreateXNNPACKDelegate(&xnnpack_options); +} + +TfLiteDelegatePtr CreateXNNPACKDelegate( + const TfLiteXNNPackDelegateOptions* xnnpack_options) { + auto xnnpack_delegate = TfLiteXNNPackDelegateCreate(xnnpack_options); + return TfLiteDelegatePtr(xnnpack_delegate, [](TfLiteDelegate* delegate) { + TfLiteXNNPackDelegateDelete(delegate); + }); +} + +TfLiteDelegatePtr CreateXNNPACKDelegate(int num_threads) { + auto opts = TfLiteXNNPackDelegateOptionsDefault(); + // Note that we don't want to use the thread pool for num_threads == 1. + opts.num_threads = num_threads > 1 ? num_threads : 0; + return CreateXNNPACKDelegate(&opts); +} } // namespace evaluation } // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/utils.h b/tensorflow/lite/tools/evaluation/utils.h index a143daf637a..0602ed4259e 100644 --- a/tensorflow/lite/tools/evaluation/utils.h +++ b/tensorflow/lite/tools/evaluation/utils.h @@ -27,12 +27,18 @@ limitations under the License. 
#endif #endif -#include "tensorflow/lite/context.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" namespace tflite { namespace evaluation { + +// Same w/ Interpreter::TfLiteDelegatePtr to avoid pulling +// tensorflow/lite/interpreter.h dependency +using TfLiteDelegatePtr = + std::unique_ptr; + std::string StripTrailingSlashes(const std::string& path); bool ReadFileLines(const std::string& file_path, @@ -50,20 +56,22 @@ inline TfLiteStatus GetSortedFileNames(const std::string& directory, std::unordered_set()); } -Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(); +TfLiteDelegatePtr CreateNNAPIDelegate(); -Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate( - StatefulNnApiDelegate::Options options); +TfLiteDelegatePtr CreateNNAPIDelegate(StatefulNnApiDelegate::Options options); -Interpreter::TfLiteDelegatePtr CreateGPUDelegate(); +TfLiteDelegatePtr CreateGPUDelegate(); #if defined(__ANDROID__) -Interpreter::TfLiteDelegatePtr CreateGPUDelegate( - TfLiteGpuDelegateOptionsV2* options); +TfLiteDelegatePtr CreateGPUDelegate(TfLiteGpuDelegateOptionsV2* options); #endif -Interpreter::TfLiteDelegatePtr CreateHexagonDelegate( +TfLiteDelegatePtr CreateHexagonDelegate( const std::string& library_directory_path, bool profiling); +TfLiteDelegatePtr CreateXNNPACKDelegate(); +TfLiteDelegatePtr CreateXNNPACKDelegate( + const TfLiteXNNPackDelegateOptions* options); +TfLiteDelegatePtr CreateXNNPACKDelegate(int num_threads); } // namespace evaluation } // namespace tflite From 3e03279a370fcf374957f742112ea3a0049ef26c Mon Sep 17 00:00:00 2001 From: Stefano Galarraga Date: Wed, 18 Mar 2020 03:30:00 -0700 Subject: [PATCH 131/492] Add capability to disable NNAPI CPU and check NNAPI Errno. PiperOrigin-RevId: 301557307 Change-Id: Ie01923f55502a7472560a96a26b95c59cf597646 --- .../tensorflow/lite/nnapi/NnApiDelegate.java | 56 ++++++++++++++++++- .../src/main/native/nnapi_delegate_jni.cc | 16 +++++- .../lite/nnapi/NnApiDelegateTest.java | 31 ++++++++++ 3 files changed, 100 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java b/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java index 989cb2c1480..e4092ec5684 100644 --- a/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java +++ b/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java @@ -89,11 +89,23 @@ public class NnApiDelegate implements Delegate, AutoCloseable { return this; } + /** + * Enable or disable the NNAPI CPU Device "nnapi-reference". If unset it will use the NNAPI + * default settings. + * + *
Only effective on Android 10 and above. + */ + public Options setUseNnapiCpu(boolean enable) { + this.useNnapiCpu = !enable; + return this; + } + private int executionPreference = EXECUTION_PREFERENCE_UNDEFINED; private String acceleratorName = null; private String cacheDir = null; private String modelToken = null; private Integer maxDelegatedPartitions = null; + private Boolean useNnapiCpu = null; } public NnApiDelegate(Options options) { @@ -105,7 +117,11 @@ public class NnApiDelegate implements Delegate, AutoCloseable { options.acceleratorName, options.cacheDir, options.modelToken, - options.maxDelegatedPartitions != null ? options.maxDelegatedPartitions : -1); + options.maxDelegatedPartitions != null ? options.maxDelegatedPartitions : -1, + /*overrideDisallowCpu=*/ options.useNnapiCpu != null, + /*disallowCpuValue=*/ options.useNnapiCpu != null + ? !options.useNnapiCpu.booleanValue() + : false); } public NnApiDelegate() { @@ -130,13 +146,49 @@ public class NnApiDelegate implements Delegate, AutoCloseable { } } + /** + * Returns the latest error code returned by an NNAPI call or zero if NO calls to NNAPI failed. + * The error code is reset when the delegate is associated with an {@link + * #org.tensorflow.lite.Interpreter interpreter}). + * + *
For details on NNAPI error codes see the NNAPI + * documentation. + * + * @throws IllegalStateException if the method is called after {@link #close() close}. + */ + public int getNnapiErrno() { + checkNotClosed(); + return getNnapiErrno(delegateHandle); + } + + /** + * Returns true if any NNAPI call failed since this delegate was associated with an {@link + * #org.tensorflow.lite.Interpreter interpreter}). + * + * @throws IllegalStateException if the method is called after {@link #close() close}. + */ + public boolean hasErrors() { + return getNnapiErrno(delegateHandle) != 0 /*ANEURALNETWORKS_NO_ERROR*/; + } + + private void checkNotClosed() { + if (delegateHandle == INVALID_DELEGATE_HANDLE) { + throw new IllegalStateException("Should not access delegate after it has been closed."); + } + } + // private static native long createDelegate( int preference, String deviceName, String cacheDir, String modelToken, - int maxDelegatedPartitions); + int maxDelegatedPartitions, + boolean overrideDisallowCpu, + boolean disallowCpuValue); private static native void deleteDelegate(long delegateHandle); + + private static native int getNnapiErrno(long delegateHandle); } diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc index d256faedd11..6b5171ddfef 100644 --- a/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc +++ b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc @@ -26,7 +26,8 @@ using namespace tflite; JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate( JNIEnv* env, jclass clazz, jint preference, jstring accelerator_name, - jstring cache_dir, jstring model_token, jint max_delegated_partitions) { + jstring cache_dir, jstring model_token, jint max_delegated_partitions, + jboolean override_disallow_cpu, jboolean disallow_cpu_value) { StatefulNnApiDelegate::Options options = StatefulNnApiDelegate::Options(); options.execution_preference = (StatefulNnApiDelegate::Options::ExecutionPreference)preference; @@ -44,6 +45,10 @@ Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate( options.max_number_delegated_partitions = max_delegated_partitions; } + if (override_disallow_cpu) { + options.disallow_nnapi_cpu = disallow_cpu_value; + } + auto delegate = new StatefulNnApiDelegate(options); if (options.accelerator_name) { @@ -61,6 +66,15 @@ Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate( return reinterpret_cast(delegate); } +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_nnapi_NnApiDelegate_getNnapiErrno(JNIEnv* env, + jclass clazz, + jlong delegate) { + StatefulNnApiDelegate* nnapi_delegate = + reinterpret_cast(delegate); + return nnapi_delegate->GetNnApiErrno(); +} + JNIEXPORT void JNICALL Java_org_tensorflow_lite_nnapi_NnApiDelegate_deleteDelegate(JNIEnv* env, jclass clazz, diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java index 82d4da0cefb..e3742dab9a3 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java @@ -16,6 +16,7 @@ limitations under the License. 
package org.tensorflow.lite.nnapi; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; import java.nio.ByteBuffer; import org.junit.Test; @@ -54,4 +55,34 @@ public final class NnApiDelegateTest { assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); } } + + @Test + public void testGetNnApiErrnoReturnsZeroIfNoNnapiCallFailed() throws Exception { + Interpreter.Options options = new Interpreter.Options(); + try (NnApiDelegate delegate = new NnApiDelegate(); + Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) { + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + + assertThat(delegate.getNnapiErrno()).isEqualTo(0); + assertThat(delegate.hasErrors()).isFalse(); + } + } + + @Test + public void testGetNnApiErrnoThrowsExceptionAfterClosingDelegate() { + NnApiDelegate delegate = new NnApiDelegate(); + assertThat(delegate.getNnapiErrno()).isEqualTo(0); + + delegate.close(); + try { + delegate.getNnapiErrno(); + fail("Expected IllegalStateException to be thrown."); + } catch (IllegalStateException expected) { + } + } } From b4f47b79c81c915e81b9c37dd5b956b24b5d516d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 03:46:05 -0700 Subject: [PATCH 132/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301558945 Change-Id: Ic034bbf1ddca63deb4dcd85c0946907c982eaa11 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 52a9bf9551b..6456f104ad3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From b1e880c7ccae5985488059b11ca0b29c6c4f993b Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 18 Mar 2020 04:38:47 -0700 Subject: [PATCH 133/492] Fix gpu_backend_lib for https://review.llvm.org/D75579. Bump tensorflow open source LLVM revision to 98369178bc69. 
PiperOrigin-RevId: 301564784 Change-Id: I80e155730cb214fe445045d31f236174d0dfdba4 --- .../gpu/llvm_gpu_backend/gpu_backend_lib.cc | 38 +++++++++++-------- tensorflow/workspace.bzl | 4 +- third_party/mlir/BUILD | 14 +++++++ 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 85e5c2dedee..060a0375271 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -32,7 +32,7 @@ limitations under the License. #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/CodeGen/CommandFlags.inc" +#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -69,6 +69,8 @@ namespace xla { namespace gpu { namespace { +static llvm::codegen::RegisterCodeGenFlags CGF; + // Inline threshold value to use in LLVM AMDGPU backend. const int kAMDGPUInlineThreshold = 0x100000; @@ -129,38 +131,41 @@ std::unique_ptr GetTargetMachine( llvm::Triple triple, absl::string_view cpu_name, const HloModuleConfig& hlo_module_config, absl::string_view feature_str) { std::string error; - const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error); + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget("", triple, error); if (target == nullptr) { LOG(FATAL) << "Unable to find Target for triple '" << triple.str() << "'" << " -- " << error; return nullptr; } - TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); + llvm::TargetOptions target_options = + llvm::codegen::InitTargetOptionsFromCodeGenFlags(); // Set the verbose assembly options. target_options.MCOptions.AsmVerbose = false; // The selection of codegen optimization level is copied from function // GetCodeGenOptLevel in //third_party/llvm/llvm/tools/opt/opt.cpp. 
- CodeGenOpt::Level codegen_opt_level; + llvm::CodeGenOpt::Level codegen_opt_level; switch (hlo_module_config.debug_options().xla_backend_optimization_level()) { case 1: - codegen_opt_level = CodeGenOpt::Less; + codegen_opt_level = llvm::CodeGenOpt::Less; break; case 2: - codegen_opt_level = CodeGenOpt::Default; + codegen_opt_level = llvm::CodeGenOpt::Default; break; case 3: - codegen_opt_level = CodeGenOpt::Aggressive; + codegen_opt_level = llvm::CodeGenOpt::Aggressive; break; default: - codegen_opt_level = CodeGenOpt::None; + codegen_opt_level = llvm::CodeGenOpt::None; } return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), - llvm_ir::AsStringRef(feature_str), target_options, getRelocModel(), - getCodeModel(), codegen_opt_level)); + llvm_ir::AsStringRef(feature_str), target_options, + llvm::codegen::getExplicitRelocModel(), + llvm::codegen::getExplicitCodeModel(), codegen_opt_level)); } // Adds the standard LLVM optimization passes, based on the speed optimization @@ -172,7 +177,7 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level, llvm::legacy::PassManagerBase* module_passes, llvm::legacy::FunctionPassManager* function_passes, int inline_threshold) { - PassManagerBuilder builder; + llvm::PassManagerBuilder builder; builder.OptLevel = opt_level; builder.SizeLevel = size_level; @@ -195,7 +200,7 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level, } // Emits the given module to a bit code file. -void EmitBitcodeToFile(const Module& module, absl::string_view filename) { +void EmitBitcodeToFile(const llvm::Module& module, absl::string_view filename) { std::error_code error_code; llvm::ToolOutputFile outfile(string(filename).c_str(), error_code, llvm::sys::fs::F_None); @@ -209,7 +214,8 @@ void EmitBitcodeToFile(const Module& module, absl::string_view filename) { // Emits the given module to PTX. target_machine is an initialized TargetMachine // for the NVPTX target. -string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { +string EmitModuleToPTX(llvm::Module* module, + llvm::TargetMachine* target_machine) { std::string ptx; // need a std::string instead of a ::string. { llvm::raw_string_ostream stream(ptx); @@ -277,8 +283,8 @@ Status LinkWithBitcodeVector(llvm::Module* module, LoadIRModule(bitcode_path, &module->getContext()); if (linker.linkInModule( std::move(bitcode_module), llvm::Linker::Flags::LinkOnlyNeeded, - [](Module& M, const StringSet<>& GVS) { - internalizeModule(M, [&GVS](const GlobalValue& GV) { + [](llvm::Module& M, const llvm::StringSet<>& GVS) { + internalizeModule(M, [&GVS](const llvm::GlobalValue& GV) { return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { @@ -561,7 +567,7 @@ static std::vector GetROCDLPaths(int amdgpu_version, // Emits the given module to HSA Code Object. target_machine is an initialized // TargetMachine for the AMDGPU target. StatusOr> EmitModuleToHsaco( - Module* module, llvm::TargetMachine* target_machine) { + llvm::Module* module, llvm::TargetMachine* target_machine) { auto* env = tensorflow::Env::Default(); std::vector tempdir_vector; env->GetLocalTempDirectories(&tempdir_vector); diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 1277a72416f..f02e2eb1538 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -597,8 +597,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. 
- LLVM_COMMIT = "398b497cd0e20ca7245bf30c12c761b444581da5" - LLVM_SHA256 = "789fd647d166774dde233a13c30d53d8a6c9098d82c4cd12d203b6f37e2555e1" + LLVM_COMMIT = "98369178bc695ba5d64314beb62d5ba5c9f14e2e" + LLVM_SHA256 = "c30eb278889c64e5a57e31d9bad794c6019d5396ce58a6ba874b0e4763f21097" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 22b91ba36a8..0e4ac2c07b6 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -557,6 +557,18 @@ cc_library( ], ) +cc_library( + name = "LLVMIRTransforms", + srcs = glob(["lib/Dialect/LLVMIR/Transforms/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/LLVMIR/Transforms/*.h"]), + includes = ["include"], + deps = [ + ":IR", + ":LLVMDialect", + ":Pass", + ], +) + filegroup( name = "GPUOpsTdFiles", srcs = [ @@ -1746,6 +1758,7 @@ cc_library( ":IR", ":LLVMConversionIncGen", ":LLVMDialect", + ":LLVMIRTransforms", ":OpenMPDialect", ":Support", "@llvm-project//llvm:core", @@ -1975,6 +1988,7 @@ cc_library( ":GPUTransforms", ":IR", ":LLVMDialect", + ":LLVMIRTransforms", ":LinalgOps", ":LinalgToLLVM", ":LinalgToSPIRV", From 9230911f8c85679e899bf4435d1bc85b30f1e1d8 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Wed, 18 Mar 2020 05:08:38 -0700 Subject: [PATCH 134/492] Move the partition helper class in GPU delegate to lite/delegates/utils as it's general enough to be used in other delegate implementations. PiperOrigin-RevId: 301568189 Change-Id: I5f52e07387695275b09b1b1c603768d2aea41b6b --- tensorflow/lite/delegates/BUILD | 2 + tensorflow/lite/delegates/gpu/common/BUILD | 1 + .../delegates/gpu/common/model_builder.cc | 166 ++++-------------- tensorflow/lite/delegates/utils.cc | 94 ++++++++++ tensorflow/lite/delegates/utils.h | 72 ++++++++ tensorflow/lite/delegates/utils_test.cc | 151 ++++++++++++++++ 6 files changed, 351 insertions(+), 135 deletions(-) diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD index a238861aaed..df671675ec9 100644 --- a/tensorflow/lite/delegates/BUILD +++ b/tensorflow/lite/delegates/BUILD @@ -26,6 +26,8 @@ cc_library( hdrs = ["utils.h"], copts = tflite_copts(), deps = [ + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:util", "//tensorflow/lite/c:common", ], ) diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index d5d82877f0c..08945c70d0b 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -120,6 +120,7 @@ cc_library( "//tensorflow/lite:kernel_api", "//tensorflow/lite:util", "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates:utils", "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index a01014cf0f4..2a03ff9ff14 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -48,6 +48,7 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/util.h" @@ -2771,131 +2772,18 @@ Status GetNodeAndRegistration(TfLiteContext* context, int node_id, return OkStatus(); } -using IsNodeSupportedFn = - std::function; +using IsNodeSupportedFn = tflite::delegates::IsNodeSupportedFn; -// A utility class to help model graph parition and decide the partition to be -// offloaded to GPU. -// TODO(b/151152967): move the following to lite/delegates/utils -class GraphPartitionHelper { - public: - GraphPartitionHelper(TfLiteContext* context, - IsNodeSupportedFn is_node_supported_fn) - : is_node_supported_fn_(is_node_supported_fn), context_(context) {} - - virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); } - - // Partitions the graph into multiple subgraphs, each of which is in - // dependency order with others - virtual Status Partition(std::set* unsupported_nodes_info) { - RETURN_IF_ERROR(PrepareSupportedNodes(unsupported_nodes_info)); - - TfLiteDelegateParams* partition_params_array_ = nullptr; - int num_partitions_ = 0; - if (context_->PreviewDelegatePartitioning(context_, supported_nodes_, - &partition_params_array_, - &num_partitions_) != kTfLiteOk) { - return InvalidArgumentError("Unable to preview delegate partition."); - } - - for (int i = 0; i < num_partitions_; ++i) { - partitions_.push_back(partition_params_array_ + i); - } - - return OkStatus(); - } - - // Returns the first n largest partitions or all if #partitions is less than - // 'n'. Note that partitions are ranked according to the number of nodes that - // a partition has, and the returned TfLiteDelegateParams objects are *owned* - // by the TfLite runtime. - std::vector GetFirstNLargestPartitions(int n) { - const int total = num_partitions(); - // We only sort partitions according to their sizes if necessary. 
- if (n < total) { - partitions_.sort(CompareTwoPartitions); - } - std::vector results; - auto p_it = partitions_.begin(); - for (int i = 0; i < std::min(total, n); ++i, ++p_it) { - results.push_back(*p_it); - } - return results; - } - - int num_total_nodes() const { return num_total_nodes_; } - int num_partitions() const { return partitions_.size(); } - - private: - static bool CompareTwoPartitions(TfLiteDelegateParams* left, - TfLiteDelegateParams* right) { - // Reverse sort - return left->nodes_to_replace->size > right->nodes_to_replace->size; - } - - Status PrepareSupportedNodes( - std::set* unsupported_nodes_info = nullptr) { - TfLiteIntArray* execution_plan = nullptr; - if (context_->GetExecutionPlan(context_, &execution_plan) != kTfLiteOk) { - return InvalidArgumentError("Unable to get graph execution plan."); - } - - num_total_nodes_ = execution_plan->size; - supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_); - supported_nodes_->size = 0; - for (int node_id : TfLiteIntArrayView(execution_plan)) { - TfLiteNode* node; - TfLiteRegistration* registration; - auto status = - GetNodeAndRegistration(context_, node_id, &node, ®istration); - if (!status.ok()) { - supported_nodes_->size = 0; - return status; - } - - status = IsNodeSupported(context_, node, registration, node_id); - if (status.ok()) { - supported_nodes_->data[supported_nodes_->size++] = node_id; - } else if (unsupported_nodes_info) { - unsupported_nodes_info->insert( - absl::StrCat(GetOpNameByRegistration(*registration), ": ", - status.error_message())); - } - } - return OkStatus(); - } - - // The number of total nodes passed in for partition (i.e. the - // execution_plan size) - int num_total_nodes_ = 0; - - // Tells whether a node is replaceable. - const IsNodeSupportedFn is_node_supported_fn_; - TfLiteIntArray* supported_nodes_; // owns the memory - - protected: - virtual Status IsNodeSupported(TfLiteContext* context, TfLiteNode* node, - TfLiteRegistration* registration, - int node_id) { - return is_node_supported_fn_(context, node, registration); - } - - TfLiteContext* const context_ = nullptr; - - // Doesn't own the memory of each TfLiteDelegateParams object as it's - // managed by the TfLite runtime itself. See - // TfLiteContext::PreviewDelegatePartitioning for details. - std::list partitions_; -}; - -class GraphWithDequantPartitionHelper : public GraphPartitionHelper { +class GraphWithDequantPartitionHelper + : public tflite::delegates::GraphPartitionHelper { public: GraphWithDequantPartitionHelper(TfLiteContext* context, IsNodeSupportedFn is_node_supported_fn) : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} - Status Partition(std::set* unsupported_nodes_info) override { - auto status = GraphPartitionHelper::Partition(unsupported_nodes_info); + TfLiteStatus Partition( + std::set* unsupported_nodes_info) override { + const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info); // Clean up those partitions that have a single dequant op. NoteThose // removed dequant ops have to be reserved in the graph and should not be // delegated. 
@@ -2924,9 +2812,9 @@ class GraphWithDequantPartitionHelper : public GraphPartitionHelper { } protected: - Status IsNodeSupported(TfLiteContext* context, TfLiteNode* node, - TfLiteRegistration* registration, - int node_id) override { + bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, + TfLiteRegistration* registration, int node_id, + std::string* unsupported_details) override { // If we need to handle dequant nodes, we have to remap input tensors of // this node if some of them come from a dequant node before testing if // the node is supported. @@ -2937,10 +2825,10 @@ class GraphWithDequantPartitionHelper : public GraphPartitionHelper { // dequant node is first added as supported. Later, this dequant node // will be removed if it has to be preserved in the graph which happens // when its immediate downstream nodes cannot be supported. - return OkStatus(); + return true; } const auto status = GraphPartitionHelper::IsNodeSupported( - context, node, registration, node_id); + context, node, registration, node_id, unsupported_details); RestoreToOrigInputTensors(node, orig_inputs); return status; } @@ -3096,21 +2984,29 @@ Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { IsNodeSupportedFn node_supported_fn = [=](TfLiteContext* context, TfLiteNode* node, - TfLiteRegistration* registration) -> Status { - RETURN_IF_ERROR(IsSupported(context, node, registration)); - return (IsAllFloatTensors(context, node->inputs) && - IsAllFloatTensors(context, node->outputs)) - ? OkStatus() - : FailedPreconditionError( - "OP is supported, but tensor type isn't matched!"); + TfLiteRegistration* registration, + std::string* unsupported_details) -> bool { + const auto status = IsSupported(context, node, registration); + if (!status.ok()) { + if (unsupported_details) *unsupported_details = status.error_message(); + return false; + } + + if (!IsAllFloatTensors(context, node->inputs) || + !IsAllFloatTensors(context, node->outputs)) { + if (unsupported_details) { + *unsupported_details = + "OP is supported, but tensor type isn't matched!"; + } + return false; + } + return true; }; GraphWithDequantPartitionHelper partition_helper(context, node_supported_fn); std::set unsupported_nodes_info; - auto status = partition_helper.Partition(&unsupported_nodes_info); - if (!status.ok()) { - TF_LITE_KERNEL_LOG(context, status.error_message().c_str()); - return nullptr; + if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) { + return TfLiteIntArrayCreate(0); } // We simply get 1st largest partition, but we could later explore whether diff --git a/tensorflow/lite/delegates/utils.cc b/tensorflow/lite/delegates/utils.cc index c4e6d5fbec4..75839d53560 100644 --- a/tensorflow/lite/delegates/utils.cc +++ b/tensorflow/lite/delegates/utils.cc @@ -15,7 +15,12 @@ limitations under the License. 
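// (Editorial sketch, not part of this patch.) The GetOpsToReplace hunk above shows
// the new callback contract: IsNodeSupportedFn now returns a bool and reports the
// reason for rejection through an out-parameter instead of returning a Status.
// A minimal callback for some other delegate could look roughly like the
// following; the ADD-only restriction is purely illustrative:
//
//   tflite::delegates::IsNodeSupportedFn fn =
//       [](TfLiteContext* context, TfLiteNode* node,
//          TfLiteRegistration* registration,
//          std::string* unsupported_details) -> bool {
//         if (registration->builtin_code != kTfLiteBuiltinAdd) {
//           if (unsupported_details) *unsupported_details = "only ADD is supported";
//           return false;
//         }
//         return true;
//       };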
#include "tensorflow/lite/delegates/utils.h" +#include +#include + #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace delegates { @@ -42,5 +47,94 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context, return kTfLiteOk; } +TfLiteStatus GraphPartitionHelper::Partition( + std::set* unsupported_nodes_info) { + const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info); + if (prepare_status != kTfLiteOk) return prepare_status; + + TfLiteDelegateParams* partition_params_array_ = nullptr; + int num_partitions_ = 0; + if (context_->PreviewDelegatePartitioning(context_, supported_nodes_, + &partition_params_array_, + &num_partitions_) != kTfLiteOk) { + TF_LITE_KERNEL_LOG(context_, "Unable to preview delegate partition.\n"); + return kTfLiteError; + } + + for (int i = 0; i < num_partitions_; ++i) { + partitions_.push_back(partition_params_array_ + i); + } + + return kTfLiteOk; +} + +std::vector +GraphPartitionHelper::GetFirstNLargestPartitions( + int n, int min_nodes_per_partition) const { + // In general, the number of partitions in a delegate is never likely to be + // high enough to cause latency issues. Also considering this is generally a + // one-time work, we simply unconditionally sort partitions here according to + // the size. + std::vector sorted_partitions(partitions_); + std::sort(sorted_partitions.begin(), sorted_partitions.end(), + [](TfLiteDelegateParams* left, TfLiteDelegateParams* right) { + // Reverse sort + return left->nodes_to_replace->size > + right->nodes_to_replace->size; + }); + + std::vector results; + auto p_it = sorted_partitions.begin(); + const int total = sorted_partitions.size(); + for (int i = 0; i < std::min(total, n); ++i, ++p_it) { + auto* p = (*p_it); + if (p->nodes_to_replace->size < min_nodes_per_partition) { + break; + } + results.push_back(p); + } + return results; +} + +TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes( + std::set* unsupported_nodes_info) { + TfLiteIntArray* execution_plan = nullptr; + auto status = context_->GetExecutionPlan(context_, &execution_plan); + if (status != kTfLiteOk) { + TF_LITE_KERNEL_LOG(context_, "Unable to get graph execution plan.\n"); + return status; + } + + num_total_nodes_ = execution_plan->size; + supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_); + supported_nodes_->size = 0; + for (int node_id : TfLiteIntArrayView(execution_plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + + status = context_->GetNodeAndRegistration(context_, node_id, &node, + ®istration); + if (status != kTfLiteOk) { + TF_LITE_KERNEL_LOG(context_, + "Couldn't get node and registration info for op: %d\n", + node_id); + supported_nodes_->size = 0; + return status; + } + + std::string unsupported_details; + if (IsNodeSupported(context_, node, registration, node_id, + &unsupported_details)) { + supported_nodes_->data[supported_nodes_->size++] = node_id; + } else if (unsupported_nodes_info) { + std::string node_info = GetOpNameByRegistration(*registration); + node_info.append(": "); + node_info.append(unsupported_details); + unsupported_nodes_info->insert(node_info); + } + } + return kTfLiteOk; +} + } // namespace delegates } // namespace tflite diff --git a/tensorflow/lite/delegates/utils.h b/tensorflow/lite/delegates/utils.h index d2881dfde37..f894cae30fd 100644 --- a/tensorflow/lite/delegates/utils.h +++ b/tensorflow/lite/delegates/utils.h @@ -16,6 +16,10 @@ limitations under the License. 
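// (Editorial sketch, not part of this patch.) Given the implementation above, a
// delegate that only wants reasonably sized partitions could use the helper along
// these lines; the thresholds are illustrative and mirror the utils_test.cc cases
// further below:
//
//   tflite::delegates::GraphPartitionHelper helper(context, is_node_supported_fn);
//   if (helper.Partition(/*unsupported_nodes_info=*/nullptr) != kTfLiteOk) {
//     return kTfLiteError;
//   }
//   // At most 10 partitions, each containing at least 3 nodes, largest first.
//   auto partitions = helper.GetFirstNLargestPartitions(10, 3);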
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_ #define TENSORFLOW_LITE_DELEGATES_UTILS_H_ +#include +#include +#include +#include #include #include "tensorflow/lite/c/common.h" @@ -31,6 +35,74 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context, TfLiteTensor** new_tensor, int* new_tensor_index); +using IsNodeSupportedFn = + std::function; + +// A utility class to help model graph parition. +// Note the class *needs* to be used in TfLiteDelegate::Prepare. +class GraphPartitionHelper { + public: + // TODO(b/151152967): Support use-cases where a list of supported nodes are + // directly passed-in. + GraphPartitionHelper(TfLiteContext* context, + IsNodeSupportedFn is_node_supported_fn) + : context_(context), is_node_supported_fn_(is_node_supported_fn) {} + + virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); } + + // Partition the graph into node subsets such that each subset could be + // replaced with one delegate kernel (i.e. a kTfLiteBuiltinDelegate op). + // If 'unsupported_nodes_info' is provided, it will be populated with + // information about all different unsupported nodes. + virtual TfLiteStatus Partition(std::set* unsupported_nodes_info); + + // Returns the first n largest partitions or all if #partitions is less than + // 'n' and each parition has at least (>=) 'min_nodes_per_partition' nodes. + // Note that partitions are ranked according to the number of nodes that + // a partition has, and the returned TfLiteDelegateParams objects are *owned* + // by the TfLite runtime. + std::vector GetFirstNLargestPartitions( + int n = std::numeric_limits::max(), + int min_nodes_per_partition = 0) const; + + int num_total_nodes() const { return num_total_nodes_; } + int num_partitions() const { return partitions_.size(); } + + protected: + virtual bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, + TfLiteRegistration* registration, int node_id, + std::string* unsupported_details) { + return is_node_supported_fn_(context, node, registration, + unsupported_details); + } + + TfLiteContext* const context_ = nullptr; + + // Doesn't own the memory of each TfLiteDelegateParams object as it's + // managed by the TfLite runtime itself. See + // TfLiteContext::PreviewDelegatePartitioning for details. + std::vector partitions_; + + private: + // Generate a list of supported nodes (i.e. populating 'supported_nodes_') by + // iterating over all nodes (i,e. those listed in the execution_plan + // associated w/ 'context_'). + // If 'unsupported_nodes_info' is provided, it will be populated with + // information about all different unsupported nodes. + TfLiteStatus PrepareSupportedNodes( + std::set* unsupported_nodes_info = nullptr); + + // The number of total nodes passed in for partitioning (i.e. the + // execution_plan size associated w/ 'context_') + int num_total_nodes_ = 0; + + // Tells if a node is supported as it could be delegated. + const IsNodeSupportedFn is_node_supported_fn_; + + // Contains an array of supported node indices. + TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory +}; } // namespace delegates } // namespace tflite diff --git a/tensorflow/lite/delegates/utils_test.cc b/tensorflow/lite/delegates/utils_test.cc index 25b36753222..a67778fee1f 100644 --- a/tensorflow/lite/delegates/utils_test.cc +++ b/tensorflow/lite/delegates/utils_test.cc @@ -72,6 +72,157 @@ TEST(UtilsTest, CreateNewTensorWithDifferentTypeTest) { TfLiteIntArrayFree(tensors[1].dims); } +// A mock TfLiteContext to be used for GraphPartitionHelperTest. 
+class MockTfLiteContext : public TfLiteContext { + public: + MockTfLiteContext() : TfLiteContext({0}) { + // Simply create a 10-node execution plan. + exec_plan_ = TfLiteIntArrayCreate(10); + for (int i = 0; i < 10; ++i) exec_plan_->data[i] = i; + + // Create {1}, {0,3,7,8}, {2,4,9}, {5,6} 4 partitions. + TfLiteDelegateParams params1({nullptr}); + params1.nodes_to_replace = TfLiteIntArrayCreate(1); + params1.nodes_to_replace->data[0] = 1; + delegate_params_.emplace_back(params1); + + TfLiteDelegateParams params2({nullptr}); + params2.nodes_to_replace = TfLiteIntArrayCreate(4); + params2.nodes_to_replace->data[0] = 0; + params2.nodes_to_replace->data[1] = 3; + params2.nodes_to_replace->data[2] = 7; + params2.nodes_to_replace->data[3] = 8; + delegate_params_.emplace_back(params2); + + TfLiteDelegateParams params3({nullptr}); + params3.nodes_to_replace = TfLiteIntArrayCreate(3); + params3.nodes_to_replace->data[0] = 2; + params3.nodes_to_replace->data[1] = 4; + params3.nodes_to_replace->data[2] = 9; + delegate_params_.emplace_back(params3); + + TfLiteDelegateParams params4({nullptr}); + params4.nodes_to_replace = TfLiteIntArrayCreate(2); + params4.nodes_to_replace->data[0] = 5; + params4.nodes_to_replace->data[1] = 6; + delegate_params_.emplace_back(params4); + + // We need to mock the following 3 functions inside TfLiteContext object + // that are used by GraphPartitionHelper implementation. + this->GetExecutionPlan = MockGetExecutionPlan; + this->GetNodeAndRegistration = MockGetNodeAndRegistration; + this->PreviewDelegatePartitioning = MockPreviewDelegatePartitioning; + } + ~MockTfLiteContext() { + TfLiteIntArrayFree(exec_plan_); + for (auto params : delegate_params_) { + TfLiteIntArrayFree(params.nodes_to_replace); + TfLiteIntArrayFree(params.input_tensors); + TfLiteIntArrayFree(params.output_tensors); + } + } + + TfLiteIntArray* exec_plan() const { return exec_plan_; } + TfLiteNode* node() { return &node_; } + TfLiteRegistration* registration() { return ®istration_; } + TfLiteDelegateParams* delegate_params() { return &delegate_params_.front(); } + int num_delegate_params() { return delegate_params_.size(); } + + private: + static TfLiteStatus MockGetExecutionPlan(TfLiteContext* context, + TfLiteIntArray** execution_plan) { + MockTfLiteContext* mock = reinterpret_cast(context); + *execution_plan = mock->exec_plan(); + return kTfLiteOk; + } + + static TfLiteStatus MockGetNodeAndRegistration( + TfLiteContext* context, int node_index, TfLiteNode** node, + TfLiteRegistration** registration) { + MockTfLiteContext* mock = reinterpret_cast(context); + *node = mock->node(); + *registration = mock->registration(); + return kTfLiteOk; + } + + static TfLiteStatus MockPreviewDelegatePartitioning( + TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, + TfLiteDelegateParams** partition_params_array, int* num_partitions) { + MockTfLiteContext* mock = reinterpret_cast(context); + *partition_params_array = mock->delegate_params(); + *num_partitions = mock->num_delegate_params(); + return kTfLiteOk; + } + + // The execution plan of this mocked TfLiteContext object. + TfLiteIntArray* exec_plan_; + + // For simplicity, the mocked graph has only type of node and one + // registration. + TfLiteNode node_; + TfLiteRegistration registration_; + + // The TfLiteDelegateParams object that's manually populated inside the mocked + // TfLiteContext::PreviewDelegatePartitioning. 
+ std::vector delegate_params_; +}; + +bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, + TfLiteRegistration* registration, + std::string* unsupported_details) { + return true; +} + +std::vector GetNodesToReplaceFromPartitions( + const std::vector& partitions) { + std::vector nodes; + for (const auto p : partitions) { + nodes.insert(nodes.end(), p->nodes_to_replace->data, + p->nodes_to_replace->data + p->nodes_to_replace->size); + } + return nodes; +} + +TEST(GraphPartitionHelper, CheckPartitions) { + // The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}. + MockTfLiteContext mocked_context; + GraphPartitionHelper helper(&mocked_context, IsNodeSupported); + EXPECT_EQ(kTfLiteOk, helper.Partition(nullptr)); + EXPECT_EQ(10, helper.num_total_nodes()); + EXPECT_EQ(4, helper.num_partitions()); + + auto partitions = helper.GetFirstNLargestPartitions(1, 0); + EXPECT_EQ(1, partitions.size()); + auto nodes = GetNodesToReplaceFromPartitions(partitions); + EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8})); + + // Get the largest partition but requiring at least 5 nodes, so empty result. + partitions = helper.GetFirstNLargestPartitions(1, 5); + EXPECT_TRUE(partitions.empty()); + + partitions = helper.GetFirstNLargestPartitions(10, 3); + EXPECT_EQ(2, partitions.size()); + EXPECT_EQ(4, partitions[0]->nodes_to_replace->size); + EXPECT_EQ(3, partitions[1]->nodes_to_replace->size); + nodes = GetNodesToReplaceFromPartitions(partitions); + EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9})); +} + +TfLiteStatus ErrorGetExecutionPlan(TfLiteContext* context, + TfLiteIntArray** execution_plan) { + return kTfLiteError; +} + +void EmptyReportError(TfLiteContext* context, const char* format, ...) {} + +TEST(GraphPartitionHelper, CheckPrepareErrors) { + TfLiteContext error_context({0}); + error_context.GetExecutionPlan = ErrorGetExecutionPlan; + error_context.ReportError = EmptyReportError; + GraphPartitionHelper helper(&error_context, IsNodeSupported); + EXPECT_EQ(kTfLiteError, helper.Partition(nullptr)); +} + } // namespace } // namespace delegates } // namespace tflite From 9af80726f59fc334e265d75d01d50efb4b096bca Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Wed, 18 Mar 2020 05:29:17 -0700 Subject: [PATCH 135/492] XNNPACK does not build for Fuchsia, so exclude XNNPACK delegate creation helper functions. 
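(Editorial note, not part of the original commit message.) The exclusion is done in
two layers, both visible in the diff below: a Bazel select() keyed on
"//tensorflow:fuchsia" drops the xnnpack_delegate dependency, and the C++ sources
guard the corresponding declarations and definitions, roughly:

    #if !defined(__Fuchsia__)
    #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
    TfLiteDelegatePtr CreateXNNPACKDelegate();
    #endif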
PiperOrigin-RevId: 301570304 Change-Id: I0b448cd4b03f7b8fa69a6d3049b5e036af111427 --- tensorflow/lite/tools/evaluation/BUILD | 6 +++++- tensorflow/lite/tools/evaluation/utils.cc | 3 +++ tensorflow/lite/tools/evaluation/utils.h | 9 ++++++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/tools/evaluation/BUILD b/tensorflow/lite/tools/evaluation/BUILD index a028607482c..caa4a637766 100644 --- a/tensorflow/lite/tools/evaluation/BUILD +++ b/tensorflow/lite/tools/evaluation/BUILD @@ -42,7 +42,6 @@ cc_library( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", - "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", ] + select({ "//tensorflow:android": [ "//tensorflow/lite/delegates/gpu:delegate", @@ -56,6 +55,11 @@ cc_library( "//tensorflow/lite/experimental/delegates/hexagon:hexagon_delegate", ], "//conditions:default": [], + }) + select({ + "//tensorflow:fuchsia": [], + "//conditions:default": [ + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + ], }), ) diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc index e7c477e359d..f86a7316ecf 100644 --- a/tensorflow/lite/tools/evaluation/utils.cc +++ b/tensorflow/lite/tools/evaluation/utils.cc @@ -162,6 +162,8 @@ TfLiteDelegatePtr CreateHexagonDelegate( #endif // defined(__ANDROID__) } +// TODO(b/149248802): include XNNPACK delegate when the issue is resolved. +#if !defined(__Fuchsia__) TfLiteDelegatePtr CreateXNNPACKDelegate() { TfLiteXNNPackDelegateOptions xnnpack_options = TfLiteXNNPackDelegateOptionsDefault(); @@ -182,5 +184,6 @@ TfLiteDelegatePtr CreateXNNPACKDelegate(int num_threads) { opts.num_threads = num_threads > 1 ? num_threads : 0; return CreateXNNPACKDelegate(&opts); } +#endif } // namespace evaluation } // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/utils.h b/tensorflow/lite/tools/evaluation/utils.h index 0602ed4259e..d1717f92e5f 100644 --- a/tensorflow/lite/tools/evaluation/utils.h +++ b/tensorflow/lite/tools/evaluation/utils.h @@ -27,9 +27,13 @@ limitations under the License. #endif #endif +// TODO(b/149248802): include XNNPACK delegate when the issue is resolved. +#if !defined(__Fuchsia__) +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#endif + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" namespace tflite { namespace evaluation { @@ -68,10 +72,13 @@ TfLiteDelegatePtr CreateGPUDelegate(TfLiteGpuDelegateOptionsV2* options); TfLiteDelegatePtr CreateHexagonDelegate( const std::string& library_directory_path, bool profiling); +// TODO(b/149248802): include XNNPACK delegate when the issue is resolved. +#if !defined(__Fuchsia__) TfLiteDelegatePtr CreateXNNPACKDelegate(); TfLiteDelegatePtr CreateXNNPACKDelegate( const TfLiteXNNPackDelegateOptions* options); TfLiteDelegatePtr CreateXNNPACKDelegate(int num_threads); +#endif } // namespace evaluation } // namespace tflite From 6803beded9a98b3929cce4cf4e6031f328914d80 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 05:46:22 -0700 Subject: [PATCH 136/492] Go: Update generated wrapper functions for TensorFlow ops. 
PiperOrigin-RevId: 301572010 Change-Id: Ic3439b7f4ac45837b7e8cbbf221da56e304bc516 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 6456f104ad3..52a9bf9551b 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From 7f7a31cf99be04951e04ebfff95e684ed4240006 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Wed, 18 Mar 2020 05:48:25 -0700 Subject: [PATCH 137/492] [XLA] Change Gather clamping optimization to prevent incorrect type casts. PiperOrigin-RevId: 301572198 Change-Id: I2e6eb13135e14a4fba4b7fb7f1ff941980a3f74a --- .../xla/service/elemental_ir_emitter.cc | 30 +++++++--- tensorflow/compiler/xla/tests/BUILD | 3 +- .../xla/tests/gather_operation_test.cc | 60 +++++++++++++++++++ 3 files changed, 83 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index c4420932e45..1d18b2c65a8 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1854,8 +1854,17 @@ StatusOr ElementalIrEmitter::EmitElementalGather( } auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { - llvm::Value* gather_dim_component_extended = - SExtOrTrunc(index_component, index_type); + auto index_component_type = index_component->getType(); + auto extended_type = index_component_type->getScalarSizeInBits() >= + index_type->getScalarSizeInBits() + ? index_component_type + : index_type; + // Possibly extend the value at the beginning to ensure clamping logic stays + // in bounds. + auto maybe_extended_index = + index_component_type != extended_type + ? b_->CreateSExt(index_component, extended_type) + : index_component; int64 operand_dim = dim_numbers.start_index_map(dim); int64 output_dim = operand_to_output_dim[operand_dim]; // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. @@ -1868,18 +1877,21 @@ StatusOr ElementalIrEmitter::EmitElementalGather( CHECK_GE(largest_valid_start_index, 0); // Clamp the gather index so that the gather region fits in the operand. 
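// (Editorial note, not part of this patch.) The reordering below is the heart of
// the fix: the index component is sign-extended before the clamp and only
// truncated to the optimized index type afterwards. With the old order, an s64
// start index such as 21474836407 (used by the new OutOfBoundsIndex64Bit test
// further below) would first be truncated to s32, wrapping to -73; the clamp
// would then produce 0 rather than largest_valid_start_index, so different
// backends could read different in-bounds elements for the same out-of-bounds
// index. Extending first keeps the clamp exact, and the trailing truncation is
// safe because the clamped value always fits in the optimized index type.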
- // gather_dim_component_extended_inbound = + // clamped_index = // clamp(gather_dim_component_extended, 0, largest_valid_start_index); bool is_signed = ShapeUtil::ElementIsSigned(indices_shape); - auto gather_dim_component_extended_inbound = EmitIntegralMin( - index.GetConstantWithIndexType(largest_valid_start_index), - EmitIntegralMax(index.GetConstantWithIndexType(0), - gather_dim_component_extended, is_signed), + auto clamped_index = EmitIntegralMin( + llvm::ConstantInt::get(extended_type, largest_valid_start_index), + EmitIntegralMax(llvm::ConstantInt::get(extended_type, 0), + maybe_extended_index, is_signed), is_signed); + // Truncate at the end to the optimized index size + auto maybe_truncated_clamped_index = extended_type != index_type + ? Trunc(clamped_index, index_type) + : clamped_index; operand_multi_index[operand_dim] = - Add(operand_multi_index[operand_dim], - gather_dim_component_extended_inbound); + Add(operand_multi_index[operand_dim], maybe_truncated_clamped_index); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e42af57e19b..d4fba7d28ac 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -971,11 +971,12 @@ xla_test( ":client_library_test_base", ":hlo_test_base", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 9bba59787a1..9a19427a96a 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -283,6 +284,65 @@ ENTRY main { RunTest(hlo_text, &operand, &start_indices); } +// The next 2 tests uses data types that require extra steps on some backends so +// only run them on known good backends. +#if defined(XLA_TEST_BACKEND_GPU) || defined(XLA_TEST_BACKEND_CPU) || \ + defined(XLA_TEST_BACKEND_INTERPRETER) + +XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex64Bit) { + // Out of bounds indices must not crash, even when the value is of a type + // larger than needed to access all values in the input, and the indices + // produce the same values across all backends. 
+ + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s64[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, + index_vector_dim=1, + slice_sizes={1,1} + ROOT result = s32[6]{0} reshape(gather) +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal start_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {21474836407, 1}, {1, 2}}); + RunTest(hlo_text, &operand, &start_indices); +} + +XLA_TEST_F(GatherOperationTest, TooSmallIndex8Bit) { + // Indices of a type too small to index all locations in gather should not + // fail. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[512, 512]{1,0} parameter(0) + indices = u8[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, + index_vector_dim=1, + slice_sizes={1,1} + ROOT result = s32[6]{0} reshape(gather) +} +)"; + Literal operand = LiteralUtil::MakeIdentityR2(512); + Literal start_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {7, 1}, {1, 2}}); + RunTest(hlo_text, &operand, &start_indices); +} + +#endif + XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { // Out of bounds indices must not crash, and the indices in range should // produce the same values across all backends. From 8324ef7e348bcc71453e82ecc172c48a372510a3 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Wed, 18 Mar 2020 12:40:00 +0000 Subject: [PATCH 138/492] Enable 3D tensor support for softmax quantized in TFLu --- .../lite/micro/kernels/cmsis-nn/softmax.cc | 31 ++-- tensorflow/lite/micro/kernels/softmax.cc | 11 +- tensorflow/lite/micro/kernels/softmax_test.cc | 150 ++++++++++++++++++ 3 files changed, 174 insertions(+), 18 deletions(-) diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc index 90fe83b744b..13b33b3f2cb 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc @@ -138,23 +138,27 @@ void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, GetTensorShape(output), GetTensorData(output)); } -void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { +void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { SoftmaxParams op_params; op_params.input_multiplier = data->input_multiplier; op_params.input_left_shift = data->input_left_shift; op_params.diff_min = data->diff_min; + if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { - arm_softmax_s8( - GetTensorData(input), - input->dims->data[0] * input->dims->data[1] * input->dims->data[2], - input->dims->data[3], op_params.input_multiplier, - op_params.input_left_shift, op_params.diff_min, - GetTensorData(output)); + const unsigned int num_dims = NumDimensions(input); + + arm_softmax_s8(GetTensorData(input), + (num_dims == 4 ? 
input->dims->data[0] : 1) * + input->dims->data[num_dims - 3] * + input->dims->data[num_dims - 2], + input->dims->data[num_dims - 1], op_params.input_multiplier, + op_params.input_left_shift, op_params.diff_min, + GetTensorData(output)); } } @@ -198,13 +202,14 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax2DQuantized(input, output, params, data); return kTfLiteOk; } - if (NumDimensions(input) == 4) { - Softmax4DQuantized(input, output, params, data); + if (NumDimensions(input) == 3 || NumDimensions(input) == 4) { + SoftmaxQuantized(input, output, params, data); return kTfLiteOk; } - TF_LITE_KERNEL_LOG(context, - "Only 2D and 4D tensors supported currently, got %dD.", - NumDimensions(input)); + TF_LITE_KERNEL_LOG( + context, + "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); return kTfLiteError; } default: diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc index fe2bfce5c7a..184456f4b89 100644 --- a/tensorflow/lite/micro/kernels/softmax.cc +++ b/tensorflow/lite/micro/kernels/softmax.cc @@ -168,8 +168,8 @@ void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, GetTensorShape(output), GetTensorData(output)); } -void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { +void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { SoftmaxParams op_params; op_params.input_multiplier = data->input_multiplier; op_params.input_left_shift = data->input_left_shift; @@ -233,12 +233,13 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax2DQuantized(input, output, params, data); return kTfLiteOk; } - if (NumDimensions(input) == 4) { - Softmax4DQuantized(input, output, params, data); + if (NumDimensions(input) == 3 || NumDimensions(input) == 4) { + SoftmaxQuantized(input, output, params, data); return kTfLiteOk; } TF_LITE_KERNEL_LOG( - context, "Only 1D, 2D and 4D tensors supported currently, got %dD.", + context, + "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.", NumDimensions(input)); return kTfLiteError; } diff --git a/tensorflow/lite/micro/kernels/softmax_test.cc b/tensorflow/lite/micro/kernels/softmax_test.cc index 0e7715cccf2..f229ab021f9 100644 --- a/tensorflow/lite/micro/kernels/softmax_test.cc +++ b/tensorflow/lite/micro/kernels/softmax_test.cc @@ -394,6 +394,156 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedSigned2D) { output_data); } +TF_LITE_MICRO_TEST(SimpleTestQuantizedSigned3D) { + using tflite::testing::F2QS; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float output_min = 0.0f; + const float output_max = (255.0f / 256.0f); + const int output_dims_count = 60; + int8_t output_data[output_dims_count]; + tflite::testing::TestSoftmaxQuantizedSigned( // + {3, 3, 4, 5}, // Input shape. 
+ { // n = 0 + // c = 0 + // h = 0 + F2QS(3.00, input_min, input_max), F2QS(6.00, input_min, input_max), + F2QS(-5.00, input_min, input_max), F2QS(4.00, input_min, input_max), + F2QS(-9.00, input_min, input_max), + // h = 1 + F2QS(-10.00, input_min, input_max), F2QS(-10.00, input_min, input_max), + F2QS(-8.00, input_min, input_max), F2QS(2.00, input_min, input_max), + F2QS(2.00, input_min, input_max), + // h = 2 + F2QS(8.00, input_min, input_max), F2QS(-5.00, input_min, input_max), + F2QS(-8.00, input_min, input_max), F2QS(5.00, input_min, input_max), + F2QS(-6.00, input_min, input_max), + // h = 3 + F2QS(-8.00, input_min, input_max), F2QS(6.00, input_min, input_max), + F2QS(1.00, input_min, input_max), F2QS(-10.00, input_min, input_max), + F2QS(-8.00, input_min, input_max), + + // c = 1 + // h = 0 + F2QS(7.00, input_min, input_max), F2QS(6.00, input_min, input_max), + F2QS(-10.00, input_min, input_max), F2QS(-4.00, input_min, input_max), + F2QS(-5.00, input_min, input_max), + // h = 1 + F2QS(2.00, input_min, input_max), F2QS(7.00, input_min, input_max), + F2QS(9.00, input_min, input_max), F2QS(-9.00, input_min, input_max), + F2QS(7.00, input_min, input_max), + // h = 2 + F2QS(-4.00, input_min, input_max), F2QS(-2.00, input_min, input_max), + F2QS(8.00, input_min, input_max), F2QS(2.00, input_min, input_max), + F2QS(2.00, input_min, input_max), + // h = 3 + F2QS(3.00, input_min, input_max), F2QS(6.00, input_min, input_max), + F2QS(6.00, input_min, input_max), F2QS(2.00, input_min, input_max), + F2QS(4.00, input_min, input_max), + + // c = 2 + // h = 0 + F2QS(9.00, input_min, input_max), F2QS(7.00, input_min, input_max), + F2QS(-7.00, input_min, input_max), F2QS(0.00, input_min, input_max), + F2QS(4.00, input_min, input_max), + // h = 1 + F2QS(-3.00, input_min, input_max), F2QS(8.00, input_min, input_max), + F2QS(8.00, input_min, input_max), F2QS(-3.00, input_min, input_max), + F2QS(-4.00, input_min, input_max), + // h = 2 + F2QS(-9.00, input_min, input_max), F2QS(-9.00, input_min, input_max), + F2QS(4.00, input_min, input_max), F2QS(-8.00, input_min, input_max), + F2QS(-1.00, input_min, input_max), + // h = 3 + F2QS(-10.00, input_min, input_max), F2QS(-2.00, input_min, input_max), + F2QS(6.00, input_min, input_max), F2QS(-7.00, input_min, input_max), + F2QS(0.00, input_min, input_max)}, + input_min, input_max, // Input quantized range. + { // Expected results. 
+ // n = 0 + // c = 0 + // h = 0 + F2QS(0.042009463, output_min, output_max), + F2QS(0.843782625, output_min, output_max), + F2QS(0.000014093, output_min, output_max), + F2QS(0.114193561, output_min, output_max), + F2QS(0.000000258, output_min, output_max), + // h = 1 + F2QS(0.000003072, output_min, output_max), + F2QS(0.000003072, output_min, output_max), + F2QS(0.000022699, output_min, output_max), + F2QS(0.499985578, output_min, output_max), + F2QS(0.499985578, output_min, output_max), + // h = 2 + F2QS(0.952571219, output_min, output_max), + F2QS(0.000002153, output_min, output_max), + F2QS(0.000000107, output_min, output_max), + F2QS(0.047425728, output_min, output_max), + F2QS(0.000000792, output_min, output_max), + // h = 3 + F2QS(0.000000826, output_min, output_max), + F2QS(0.993305397, output_min, output_max), + F2QS(0.006692839, output_min, output_max), + F2QS(0.000000112, output_min, output_max), + F2QS(0.000000826, output_min, output_max), + + // c = 1 + // h = 0 + F2QS(0.731046347, output_min, output_max), + F2QS(0.268936922, output_min, output_max), + F2QS(0.000000030, output_min, output_max), + F2QS(0.000012210, output_min, output_max), + F2QS(0.000004492, output_min, output_max), + // h = 1 + F2QS(0.000717124, output_min, output_max), + F2QS(0.106430599, output_min, output_max), + F2QS(0.786421666, output_min, output_max), + F2QS(0.000000012, output_min, output_max), + F2QS(0.106430599, output_min, output_max), + // h = 2 + F2QS(0.000006114, output_min, output_max), + F2QS(0.000045174, output_min, output_max), + F2QS(0.995015917, output_min, output_max), + F2QS(0.002466398, output_min, output_max), + F2QS(0.002466398, output_min, output_max), + // h = 3 + F2QS(0.022595176, output_min, output_max), + F2QS(0.453836234, output_min, output_max), + F2QS(0.453836234, output_min, output_max), + F2QS(0.008312301, output_min, output_max), + F2QS(0.061420055, output_min, output_max), + + // c = 2 + // h = 0 + F2QS(0.875505904, output_min, output_max), + F2QS(0.118486839, output_min, output_max), + F2QS(0.000000099, output_min, output_max), + F2QS(0.000108046, output_min, output_max), + F2QS(0.005899112, output_min, output_max), + // h = 1 + F2QS(0.000008351, output_min, output_max), + F2QS(0.499990113, output_min, output_max), + F2QS(0.499990113, output_min, output_max), + F2QS(0.000008351, output_min, output_max), + F2QS(0.000003072, output_min, output_max), + // h = 2 + F2QS(0.000002245, output_min, output_max), + F2QS(0.000002245, output_min, output_max), + F2QS(0.993296627, output_min, output_max), + F2QS(0.000006103, output_min, output_max), + F2QS(0.006692780, output_min, output_max), + // h = 3 + F2QS(0.000000112, output_min, output_max), + F2QS(0.000334520, output_min, output_max), + F2QS(0.997191323, output_min, output_max), + F2QS(0.000002254, output_min, output_max), + F2QS(0.002471790, output_min, output_max)}, + {3, 3, 4, 5}, // Output shape. + output_min, output_max, // Output quantized range. + output_data); +} + TF_LITE_MICRO_TEST(SimpleTestQuantizedSigned4D) { using tflite::testing::F2QS; From 8c1ead8d919c7536c3d8f52fbab44e1dfcd9fd79 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Wed, 18 Mar 2020 07:25:45 -0700 Subject: [PATCH 139/492] Inherit parent name scope stack when building branches of control flow ops. 
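(Editorial note, not part of the original commit message.) Previously the branch
FuncGraphs of cond_v2/while_v2 started from an empty name scope; with this change
they inherit the scope of the enclosing graph at construction time. As a rough
sketch (names are illustrative only):

    with tf.name_scope("foo"):
      result = tf.cond(pred,
                       lambda: tf.add(x, 1.0, name="then_add"),  # scoped under foo/cond/...
                       lambda: tf.add(x, 2.0, name="else_add"))  # scoped under foo/cond/...

The new tests below assert the "foo/cond/then", "foo/cond/else" and
"foo/while/body" scopes directly.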
PiperOrigin-RevId: 301584650 Change-Id: Idb5e69190312ab36208ee4bed1f09dc555884ba6 --- tensorflow/python/framework/func_graph.py | 4 +++ .../python/framework/op_callbacks_test.py | 6 ++-- .../python/kernel_tests/cond_v2_test.py | 25 +++++++++++++++++ .../kernel_tests/control_flow_ops_py_test.py | 5 ++-- .../python/kernel_tests/while_v2_test.py | 28 +++++++++++++++++++ .../python/ops/control_flow_v2_func_graphs.py | 26 ++++++++++++----- 6 files changed, 82 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index d702771cef3..768dc3fee9c 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -1281,3 +1281,7 @@ def dismantle_func_graph(func_graph): """ func_graph.clear_captures() ops.dismantle_graph(func_graph) + + +def override_func_graph_name_scope(func_graph, name_scope): + func_graph._name_stack = name_scope # pylint: disable=protected-access diff --git a/tensorflow/python/framework/op_callbacks_test.py b/tensorflow/python/framework/op_callbacks_test.py index 8f21fff10f8..f04d85bba21 100644 --- a/tensorflow/python/framework/op_callbacks_test.py +++ b/tensorflow/python/framework/op_callbacks_test.py @@ -632,7 +632,7 @@ class OpCallbacksTest(test_util.TensorFlowTestCase): greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP] self.assertEqual(len(greater_op_outputs), 1) self.assertAllClose(greater_op_outputs[0], False) - pow_op_outputs = instrument.graph_internal_ndarrays[b"pow"] + pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"] self.assertEqual(len(pow_op_outputs), 1) self.assertAllClose(pow_op_outputs[0], -64.0) @@ -660,9 +660,9 @@ class OpCallbacksTest(test_util.TensorFlowTestCase): # Check the graph internal ndarrays recorded at runtime. read_variable_op_outputs = instrument.graph_internal_ndarrays[ - _READ_VARIABLE_OP] + b"while/" + _READ_VARIABLE_OP] self.assertAllClose(read_variable_op_outputs, [1.0, 2.0, 4.0, 8.0]) - less_op_outputs = instrument.graph_internal_ndarrays[_LESS_OP] + less_op_outputs = instrument.graph_internal_ndarrays[b"while/" + _LESS_OP] self.assertAllClose(less_op_outputs, [True, True, True, True, False]) # TODO(cais): The following isn't decorated with diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 94a0e73e64f..de8ea8d89d7 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -260,6 +260,31 @@ class CondV2Test(test.TestCase): self.assertRegexpMatches( cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*") + def testInheritParentNameScope(self): + + @def_function.function + def f(): + with ops.name_scope("foo"): + + def then_branch(): + with ops.name_scope("then"): + actual_name_scope = ops.get_name_scope() + expected_name_scope = "foo/cond/then" + self.assertEqual(actual_name_scope, expected_name_scope) + return 0. + + def else_branch(): + with ops.name_scope("else"): + actual_name_scope = ops.get_name_scope() + expected_name_scope = "foo/cond/else" + self.assertEqual(actual_name_scope, expected_name_scope) + return 0. 
+ + return cond_v2.cond_v2( + constant_op.constant(True), then_branch, else_branch) + + f() + @test_util.run_v1_only("b/120545219") def testDefunInCond(self): x = constant_op.constant(1.0, name="x") diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index ec9d97c4bcc..99fff136314 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -811,7 +811,7 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp( ValueError, - "Tensor true_branch:0 in true_fn is accessed from false_fn."): + "Tensor cond/true_branch:0 in true_fn is accessed from false_fn."): f() def testSwitchCaseAccessBranch1TensorInBranch4Raises(self): @@ -838,7 +838,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp( ValueError, - "Tensor br1_identity:0 in branch 1 is accessed from branch 4."): + "Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is " + "accessed from branch 4."): f() def testCondListOutput(self): diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 8f131723cd2..1fa6c179e7a 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -1175,6 +1175,34 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): Fn() + def testInheritParentNameScope(self): + + @def_function.function + def F(): + with ops.name_scope("foo"): + + def Cond(unused_i): + with ops.name_scope("cond"): + actual_name_scope = ops.get_name_scope() + expected_name_scope = "foo/while/cond" + assert actual_name_scope == expected_name_scope, ( + "%s does not match %s" % + (actual_name_scope, expected_name_scope)) + return False + + def Body(i): + with ops.name_scope("body"): + actual_name_scope = ops.get_name_scope() + expected_name_scope = "foo/while/body" + assert actual_name_scope == expected_name_scope, ( + "%s does not match %s" % + (actual_name_scope, expected_name_scope)) + return i + + return while_v2.while_loop(Cond, Body, [0.]) + + F() + def ScalarShape(): return ops.convert_to_tensor([], dtype=dtypes.int32) diff --git a/tensorflow/python/ops/control_flow_v2_func_graphs.py b/tensorflow/python/ops/control_flow_v2_func_graphs.py index 1a96d397c5b..537ad2b4b8a 100644 --- a/tensorflow/python/ops/control_flow_v2_func_graphs.py +++ b/tensorflow/python/ops/control_flow_v2_func_graphs.py @@ -18,28 +18,40 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework.func_graph import FuncGraph +from tensorflow.python.framework import func_graph -class CondBranchFuncGraph(FuncGraph): +class CondBranchFuncGraph(func_graph.FuncGraph): """FuncGraph for branches of tf.cond(). This is used to distinguish cond branches from other functions. """ - pass + + def __init__(self, *args, **kwargs): + super(CondBranchFuncGraph, self).__init__(*args, **kwargs) + func_graph.override_func_graph_name_scope(self, + self.outer_graph.get_name_scope()) -class WhileCondFuncGraph(FuncGraph): +class WhileCondFuncGraph(func_graph.FuncGraph): """FuncGraph for the condition of tf.while_loop(). This is used to distinguish while conditions from other functions. 
""" - pass + + def __init__(self, *args, **kwargs): + super(WhileCondFuncGraph, self).__init__(*args, **kwargs) + func_graph.override_func_graph_name_scope(self, + self.outer_graph.get_name_scope()) -class WhileBodyFuncGraph(FuncGraph): +class WhileBodyFuncGraph(func_graph.FuncGraph): """FuncGraph for the body of tf.while_loop(). This is used to distinguish while bodies from other functions. """ - pass + + def __init__(self, *args, **kwargs): + super(WhileBodyFuncGraph, self).__init__(*args, **kwargs) + func_graph.override_func_graph_name_scope(self, + self.outer_graph.get_name_scope()) From 3aa632b55220931f9fa536b9fbc88114651a58aa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 07:46:20 -0700 Subject: [PATCH 140/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301587193 Change-Id: I78f45f5c34b3e848c957106edc3eaaa9acd648e2 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 52a9bf9551b..6456f104ad3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From fe84fb19eb0df509a5efd2017641ab557e70ee6e Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Wed, 18 Mar 2020 08:44:12 -0700 Subject: [PATCH 141/492] [tf.data service] Add master and worker proto definitions. 
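For context only (not part of this change), a rough sketch of how a client
might drive the new master service, assuming the standard gRPC C++ codegen for
these protos; the generated header path, the "localhost:5000" address, and the
error handling below are assumptions for the example, not code added here.

  #include <iostream>
  #include <memory>

  #include "grpcpp/grpcpp.h"
  // Assumed name of the header generated from master.proto by cc_grpc_library.
  #include "tensorflow/core/data/service/master.grpc.pb.h"

  int main() {
    auto channel = grpc::CreateChannel("localhost:5000",
                                       grpc::InsecureChannelCredentials());
    auto stub = tensorflow::data::MasterService::NewStub(channel);

    // Register (or look up) a dataset. In a real client the DatasetDef would
    // carry the GraphDef describing the tf.data pipeline.
    tensorflow::data::GetOrRegisterDatasetRequest dataset_req;
    tensorflow::data::GetOrRegisterDatasetResponse dataset_resp;
    grpc::ClientContext ctx1;
    grpc::Status s =
        stub->GetOrRegisterDataset(&ctx1, dataset_req, &dataset_resp);
    if (!s.ok()) { std::cerr << s.error_message() << "\n"; return 1; }

    // Begin an epoch over the registered dataset.
    tensorflow::data::BeginEpochRequest epoch_req;
    epoch_req.set_dataset_id(dataset_resp.dataset_id());
    tensorflow::data::BeginEpochResponse epoch_resp;
    grpc::ClientContext ctx2;
    s = stub->BeginEpoch(&ctx2, epoch_req, &epoch_resp);
    if (!s.ok()) { std::cerr << s.error_message() << "\n"; return 1; }

    // List the tasks the master created for this epoch; each task names the
    // worker address a consumer would contact via WorkerService.GetElement.
    tensorflow::data::GetTasksRequest tasks_req;
    tasks_req.set_epoch_id(epoch_resp.epoch_id());
    tensorflow::data::GetTasksResponse tasks_resp;
    grpc::ClientContext ctx3;
    s = stub->GetTasks(&ctx3, tasks_req, &tasks_resp);
    if (!s.ok()) { std::cerr << s.error_message() << "\n"; return 1; }
    for (const auto& task : tasks_resp.task_info()) {
      std::cout << "task " << task.id() << " on " << task.worker_address()
                << "\n";
    }
    return 0;
  }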
PiperOrigin-RevId: 301596660 Change-Id: Ibe7e271e345919ab58c1be61abec81b0463f972f --- tensorflow/core/data/service/BUILD | 58 ++++++++++++++++++ tensorflow/core/data/service/common.proto | 40 +++++++++++++ tensorflow/core/data/service/master.proto | 73 +++++++++++++++++++++++ tensorflow/core/data/service/worker.proto | 31 ++++++++++ 4 files changed, 202 insertions(+) create mode 100644 tensorflow/core/data/service/BUILD create mode 100644 tensorflow/core/data/service/common.proto create mode 100644 tensorflow/core/data/service/master.proto create mode 100644 tensorflow/core/data/service/worker.proto diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD new file mode 100644 index 00000000000..6003362406f --- /dev/null +++ b/tensorflow/core/data/service/BUILD @@ -0,0 +1,58 @@ +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", +) + +package( + default_visibility = [ + "//tensorflow:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) + +tf_proto_library( + name = "common_proto", + srcs = ["common.proto"], + cc_api_version = 2, + protodeps = tf_additional_all_protos(), +) + +tf_proto_library( + name = "master_proto", + srcs = ["master.proto"], + has_services = 1, + cc_api_version = 2, + protodeps = tf_additional_all_protos() + [ + ":common_proto", + ], +) + +tf_proto_library( + name = "worker_proto", + srcs = ["worker.proto"], + has_services = 1, + cc_api_version = 2, + protodeps = tf_additional_all_protos() + [ + ":common_proto", + ], +) + +cc_grpc_library( + name = "master_cc_grpc_proto", + srcs = [":master_proto"], + generate_mocks = True, + grpc_only = True, + deps = [":master_proto_cc"], +) + +cc_grpc_library( + name = "worker_cc_grpc_proto", + srcs = [":worker_proto"], + generate_mocks = True, + grpc_only = True, + deps = [":worker_proto_cc"], +) diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto new file mode 100644 index 00000000000..0faaa661e08 --- /dev/null +++ b/tensorflow/core/data/service/common.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package tensorflow.data; + +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +message DatasetDef { + // We represent datasets as tensorflow GraphDefs which define the operations + // needed to create a tf.data dataset. + GraphDef graph = 1; +} + +message ComponentMetadata { + // The dtype of the component tensor. + .tensorflow.DataType dtype = 1; + // The shape of the component tensor. + .tensorflow.TensorShapeProto tensor_shape = 2; + // Size of the uncompressed tensor bytes. For tensors serialized as + // TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors, + // this is the size of the buffer underlying the Tensor. + int64 tensor_size_bytes = 3; +} + +message CompressedElement { + // Compressed tensor bytes for all components of the element. + bytes data = 1; + // Metadata for the components of the element. + repeated ComponentMetadata component_metadata = 2; +} + +message TaskDef { + // The dataset to iterate over. + // TODO(aaudibert): load the dataset from disk instead of passing it here. 
+ DatasetDef dataset = 1; + int64 dataset_id = 2; + int64 task_id = 3; + int64 epoch_id = 4; +} diff --git a/tensorflow/core/data/service/master.proto b/tensorflow/core/data/service/master.proto new file mode 100644 index 00000000000..03be51c79e7 --- /dev/null +++ b/tensorflow/core/data/service/master.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; + +package tensorflow.data; + +import "tensorflow/core/data/service/common.proto"; + +message RegisterWorkerRequest { + // The address of the registering worker. + string worker_address = 1; +} + +message RegisterWorkerResponse { + // An id for the worker. + int64 worker_id = 1; + // Tasks to begin processing. + repeated TaskDef tasks = 2; +} + +message GetOrRegisterDatasetRequest { + // The dataset to register. + DatasetDef dataset = 1; +} + +message GetOrRegisterDatasetResponse { + // The id for the registered dataset. + int64 dataset_id = 1; +} + +message BeginEpochRequest { + // The id of the dataset to iterate over. + int64 dataset_id = 1; +} + +message BeginEpochResponse { + // The id for the created epoch. + int64 epoch_id = 1; +} + +message GetTasksRequest { + // The epoch to look up tasks for. + int64 epoch_id = 1; +} + +message TaskInfo { + // The address of the worker processing the task. + string worker_address = 1; + // The task id. + int64 id = 2; +} + +message GetTasksResponse { + // A list of all tasks for an epoch. + repeated TaskInfo task_info = 1; +} + +service MasterService { + // Registers a worker with the master. + rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse); + + // Registers a dataset with the server, or returns its id if it is already + // registered. + // + // The dataset is constructed in a new graph, so it must not refer to + // external resources or variables. + rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest) + returns (GetOrRegisterDatasetResponse); + + // Begins an epoch over a dataset. + rpc BeginEpoch(BeginEpochRequest) returns (BeginEpochResponse); + + // Reports a list of all tasks for an epoch. + rpc GetTasks(GetTasksRequest) returns (GetTasksResponse); +} diff --git a/tensorflow/core/data/service/worker.proto b/tensorflow/core/data/service/worker.proto new file mode 100644 index 00000000000..04b8f03474c --- /dev/null +++ b/tensorflow/core/data/service/worker.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package tensorflow.data; + +import "tensorflow/core/data/service/common.proto"; + +message ProcessTaskRequest { + TaskDef task = 1; +} + +message ProcessTaskResponse {} + +message GetElementRequest { + // The task to fetch an element from. + int64 task_id = 1; +} + +message GetElementResponse { + // The produced element. + CompressedElement compressed_element = 3; + // Boolean to indicate whether the iterator has been exhausted. + bool end_of_sequence = 2; +} + +service WorkerService { + // Processes an task for a dataset, making elements available to clients. + rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse); + + // Gets the next dataset element. 
+ rpc GetElement(GetElementRequest) returns (GetElementResponse); +} From caa68bf2d03017a0e4828316d3fec0c1dd219ae9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Mar 2020 01:09:19 +0900 Subject: [PATCH 142/492] minor spelling tweaks --- .../acceleration_test_util_internal_test.cc | 2 +- tensorflow/lite/kernels/activations.cc | 8 ++--- .../bidirectional_sequence_lstm_test.cc | 10 +++---- .../bidirectional_sequence_rnn_test.cc | 2 +- tensorflow/lite/kernels/cpu_backend_context.h | 2 +- .../kernels/cpu_backend_gemm_custom_gemv.h | 6 ++-- .../lite/kernels/cpu_backend_gemm_eigen.cc | 2 +- .../kernels/detection_postprocess_test.cc | 2 +- tensorflow/lite/kernels/fully_connected.cc | 2 +- .../lite/kernels/fully_connected_test.cc | 2 +- ...epthwiseconv_per_channel_quantized_test.cc | 4 +-- .../internal/optimized/depthwiseconv_float.h | 30 +++++++++---------- .../internal/optimized/depthwiseconv_uint8.h | 28 ++++++++--------- .../depthwiseconv_uint8_3x3_filter.h | 2 +- .../optimized/integer_ops/depthwise_conv.h | 28 ++++++++--------- .../integer_ops/depthwise_conv_3x3_filter.h | 6 ++-- .../internal/optimized/integer_ops/mean.h | 2 +- .../internal/optimized/neon_tensor_utils.cc | 2 +- .../internal/optimized/optimized_ops.h | 14 ++++----- .../kernels/internal/quantization_util.cc | 2 +- .../internal/reference/binary_function.h | 2 +- .../lite/kernels/internal/spectrogram.cc | 2 +- .../lite/kernels/internal/tensor_utils.h | 12 ++++---- .../kernels/internal/tensor_utils_test.cc | 8 ++--- tensorflow/lite/kernels/kernel_util.cc | 2 +- tensorflow/lite/kernels/lstm.cc | 2 +- tensorflow/lite/kernels/lstm_eval.cc | 4 +-- tensorflow/lite/kernels/matrix_diag_test.cc | 2 +- tensorflow/lite/kernels/pad_test.cc | 18 +++++------ tensorflow/lite/kernels/rfft2d.cc | 2 +- tensorflow/lite/kernels/strided_slice_test.cc | 2 +- tensorflow/lite/kernels/subgraph_test_util.h | 2 +- .../lite/kernels/subgraph_test_util_test.cc | 8 ++--- tensorflow/lite/kernels/svdf_test.cc | 4 +-- tensorflow/lite/kernels/variable_ops_test.cc | 4 +-- tensorflow/lite/kernels/while_test.cc | 2 +- 36 files changed, 116 insertions(+), 116 deletions(-) diff --git a/tensorflow/lite/kernels/acceleration_test_util_internal_test.cc b/tensorflow/lite/kernels/acceleration_test_util_internal_test.cc index 0195f213616..71e0c9e9912 100644 --- a/tensorflow/lite/kernels/acceleration_test_util_internal_test.cc +++ b/tensorflow/lite/kernels/acceleration_test_util_internal_test.cc @@ -110,7 +110,7 @@ TEST_F(ReadAccelerationConfigTest, IgnoresCommentedLines) { EXPECT_TRUE(blacklist_.empty()); } -TEST_F(ReadAccelerationConfigTest, CommentCanHaveTralingBlanks) { +TEST_F(ReadAccelerationConfigTest, CommentCanHaveTrailingBlanks) { ReadAccelerationConfig(" #key,value", consumer_); EXPECT_TRUE(whitelist_.empty()); diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index 7e65fbb5306..df82d6919b7 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -809,7 +809,7 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { params.input_range_radius = data->input_range_radius; params.input_multiplier = data->input_multiplier; params.input_left_shift = data->input_left_shift; - optimized_ops::Tanh16bitPercision( + optimized_ops::Tanh16bitPrecision( params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { @@ -824,7 +824,7 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { 
params.input_range_radius = data->input_range_radius; params.input_multiplier = data->input_multiplier; params.input_left_shift = data->input_left_shift; - optimized_ops::Tanh16bitPercision( + optimized_ops::Tanh16bitPrecision( params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { @@ -881,7 +881,7 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { params.input_range_radius = data->input_range_radius; params.input_multiplier = data->input_multiplier; params.input_left_shift = data->input_left_shift; - optimized_ops::Logistic16bitPercision( + optimized_ops::Logistic16bitPrecision( params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { @@ -896,7 +896,7 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { params.input_range_radius = data->input_range_radius; params.input_multiplier = data->input_multiplier; params.input_left_shift = data->input_left_shift; - optimized_ops::Logistic16bitPercision( + optimized_ops::Logistic16bitPrecision( params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc index 9c397fefa9f..12b33c9661d 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -2766,11 +2766,11 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) { // Aux input and input are the same, so we should observe the same outputs // as there's no aux input. lstm.SetAuxInput(0, batch0_start, batch0_end); - std::vector dummpy_weights(n_cell * n_input, 0.0f); - lstm.SetAuxInputToInputWeights(dummpy_weights); - lstm.SetAuxInputToForgetWeights(dummpy_weights); - lstm.SetAuxInputToCellWeights(dummpy_weights); - lstm.SetAuxInputToOutputWeights(dummpy_weights); + std::vector dummy_weights(n_cell * n_input, 0.0f); + lstm.SetAuxInputToInputWeights(dummy_weights); + lstm.SetAuxInputToForgetWeights(dummy_weights); + lstm.SetAuxInputToCellWeights(dummy_weights); + lstm.SetAuxInputToOutputWeights(dummy_weights); lstm.Invoke(); diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc index a5210da243b..34441e2b300 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -1346,7 +1346,7 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestCrossLinkingAuxInputOnlyTimeMajor) { } // Same as BlackBox test, but the input tensor and weights tensor are split -// along the last dimension and passed to both regular and auxiliry inputs and +// along the last dimension and passed to both regular and auxiliary inputs and // weights. The output in this case is the same. To understand this, let's // define W and V as regular input weights matrix and auxiliary input weights // matrix correspondingly. 
It's easy to see that this is equivalent to a regular diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h index 82d990aa3ab..2d3d76deaea 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.h +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -55,7 +55,7 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { const std::unique_ptr ruy_context_; const std::unique_ptr gemmlowp_context_; - // The maxinum of threads used for parallelizing TfLite ops. However, + // The maximum of threads used for parallelizing TfLite ops. However, // cpu_backend_threadpool::Execute creates as many threads as it's // asked to, regardless of this. Typically a call site would query // cpu_backend_context->max_num_threads() and used that to determine diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h index 9b09123a979..b19d5bc990b 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h @@ -593,10 +593,10 @@ struct CustomGemvImplkeep_num_dims) { // When number of dimensions are kept the filter operates along the last - // dimenions. In other words, for an input tensor with shape + // dimentions. In other words, for an input tensor with shape // [batch_size, ..., n_inputs] and a filter of shape [n_inputs, n_units] // this Op produces an output of shape [batch_size, ..., n_units]. TF_LITE_ENSURE_EQ(context, input->dims->data[input->dims->size - 1], diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index 6c8337f17ab..1f671cae0fc 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -790,7 +790,7 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) { // The shuffled weights block shape is 4x16. The shape of the weights matrix // is: rows = output_depth, cols = input_depth. It must be a multiple of 4x16. - // This means that output_depth must be a multiple of 4, and input_deth must + // This means that output_depth must be a multiple of 4, and input_depth must // be a multiple of 16. for (int input_depth_numblocks : {1, 3}) { for (int output_depth_numblocks : {1, 3}) { diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc index 3fb824ca902..8336b63b0ba 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc @@ -290,7 +290,7 @@ void TryTestOneDepthwiseConv3x3Filter() { // It's hard to come up with a right multiplier, random guess basically makes // all the results saturated and becomes meaningfulless, so we first use // reference impl to poke the min/max value of the accumulation, then use that - // value as a guided suggestion for us to populate meaningful mulitplier & + // value as a guided suggestion for us to populate meaningful multiplier & // shift. 
PickReasonableMultiplier( params, output_activation_min, output_activation_max, output_depth, @@ -305,7 +305,7 @@ void TryTestOneDepthwiseConv3x3Filter() { dilation_width_factor, dilation_height_factor, pad_width, pad_height, depth_multiplier, output_shape_inference, 0, output_shift.data())); - // The following tests compare referene impl and Neon general impl agrees, + // The following tests compare reference impl and Neon general impl agrees, // and reference impl loosely agrees with fast kernel since they use different // rounding strategy. reference_integer_ops::DepthwiseConvPerChannel( diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h index 09d880f4cec..4995f480adf 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -787,37 +787,37 @@ void FloatDepthwiseConvAccumRow(int stride, int dilation_factor, for (int filter_x = 0; filter_x < filter_width; ++filter_x) { // For the current (filter_x, filter_y) point in the filter, // compute the boundaries of the corresponding output row segment. - int out_x_loop_start_unclampled = 0; - int out_x_loop_end_unclampled = 0; - if (kAllowStrided) { + int out_x_loop_start_unclamped = 0; + int out_x_loop_end_unclamped = 0; + if (kAllowStrided) if (stride == 2) { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + 1) / 2; - out_x_loop_end_unclampled = + out_x_loop_end_unclamped = (pad_width + input_width - dilation_factor * filter_x + 1) / 2; } else if (stride == 4) { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + 3) / 4; - out_x_loop_end_unclampled = + out_x_loop_end_unclamped = (pad_width + input_width - dilation_factor * filter_x + 3) / 4; } else { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + stride - 1) / stride; - out_x_loop_end_unclampled = (pad_width + input_width - - dilation_factor * filter_x + stride - 1) / - stride; + out_x_loop_end_unclamped = (pad_width + input_width - + dilation_factor * filter_x + stride - 1) / + stride; } } else { - out_x_loop_start_unclampled = pad_width - dilation_factor * filter_x; - out_x_loop_end_unclampled = + out_x_loop_start_unclamped = pad_width - dilation_factor * filter_x; + out_x_loop_end_unclamped = pad_width + input_width - dilation_factor * filter_x; } // The kernel will have to iterate on the segment of the // output row that starts at out_x_loop_start and out_x_loop_end. 
const int out_x_loop_start = - std::max(out_x_buffer_start, out_x_loop_start_unclampled); + std::max(out_x_buffer_start, out_x_loop_start_unclamped); const int out_x_loop_end = - std::min(out_x_buffer_end, out_x_loop_end_unclampled); + std::min(out_x_buffer_end, out_x_loop_end_unclamped); float* acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h index fe3f72e2536..a758929a25b 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -1496,37 +1496,37 @@ void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, for (int filter_x = 0; filter_x < filter_width; ++filter_x) { // For the current (filter_x, filter_y) point in the filter, // compute the boundaries of the corresponding output row segment. - int out_x_loop_start_unclampled = 0; - int out_x_loop_end_unclampled = 0; + int out_x_loop_start_unclamped = 0; + int out_x_loop_end_unclamped = 0; if (kAllowStrided) { if (stride == 2) { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + 1) / 2; - out_x_loop_end_unclampled = + out_x_loop_end_unclamped = (pad_width + input_width - dilation_factor * filter_x + 1) / 2; } else if (stride == 4) { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + 3) / 4; - out_x_loop_end_unclampled = + out_x_loop_end_unclamped = (pad_width + input_width - dilation_factor * filter_x + 3) / 4; } else { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + stride - 1) / stride; - out_x_loop_end_unclampled = (pad_width + input_width - - dilation_factor * filter_x + stride - 1) / - stride; + out_x_loop_end_unclamped = (pad_width + input_width - + dilation_factor * filter_x + stride - 1) / + stride; } } else { - out_x_loop_start_unclampled = pad_width - dilation_factor * filter_x; - out_x_loop_end_unclampled = + out_x_loop_start_unclamped = pad_width - dilation_factor * filter_x; + out_x_loop_end_unclamped = pad_width + input_width - dilation_factor * filter_x; } // The kernel will have to iterate on the segment of the // output row that starts at out_x_loop_start and out_x_loop_end. const int out_x_loop_start = - std::max(out_x_buffer_start, out_x_loop_start_unclampled); + std::max(out_x_buffer_start, out_x_loop_start_unclamped); const int out_x_loop_end = - std::min(out_x_buffer_end, out_x_loop_end_unclampled); + std::min(out_x_buffer_end, out_x_loop_end_unclamped); int32* acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 3dc863dcccd..6fd101d1ca6 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -13128,7 +13128,7 @@ inline void DepthwiseConvDotProduct3x3Impl( // "next" data, of at least 16 bytes, even when at the end of the workspace. // It is relatively expensive to detect the end micro block. It is also very // difficult to test for (to trigger) erroneous reads (past end of array) in - // the depth multplication case. + // the depth multiplication case. 
int workspace_width_micro_repeats = (has_depth_multiplication ? kDepthwiseConvScratchWorkspaceSize - kWorkspaceExtension diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h index 642d7577a1b..2d0568fa4c8 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h @@ -1441,37 +1441,37 @@ void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, for (int filter_x = 0; filter_x < filter_width; ++filter_x) { // For the current (filter_x, filter_y) point in the filter, // compute the boundaries of the corresponding output row segment. - int out_x_loop_start_unclampled = 0; - int out_x_loop_end_unclampled = 0; + int out_x_loop_start_unclamped = 0; + int out_x_loop_end_unclamped = 0; if (kAllowStrided) { if (stride == 2) { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + 1) / 2; - out_x_loop_end_unclampled = + out_x_loop_end_unclamped = (pad_width + input_width - dilation_factor * filter_x + 1) / 2; } else if (stride == 4) { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + 3) / 4; - out_x_loop_end_unclampled = + out_x_loop_end_unclamped = (pad_width + input_width - dilation_factor * filter_x + 3) / 4; } else { - out_x_loop_start_unclampled = + out_x_loop_start_unclamped = (pad_width - dilation_factor * filter_x + stride - 1) / stride; - out_x_loop_end_unclampled = (pad_width + input_width - - dilation_factor * filter_x + stride - 1) / - stride; + out_x_loop_end_unclamped = (pad_width + input_width - + dilation_factor * filter_x + stride - 1) / + stride; } } else { - out_x_loop_start_unclampled = pad_width - dilation_factor * filter_x; - out_x_loop_end_unclampled = + out_x_loop_start_unclamped = pad_width - dilation_factor * filter_x; + out_x_loop_end_unclamped = pad_width + input_width - dilation_factor * filter_x; } // The kernel will have to iterate on the segment of the // output row that starts at out_x_loop_start and out_x_loop_end. const int out_x_loop_start = - std::max(out_x_buffer_start, out_x_loop_start_unclampled); + std::max(out_x_buffer_start, out_x_loop_start_unclamped); const int out_x_loop_end = - std::min(out_x_buffer_end, out_x_loop_end_unclampled); + std::min(out_x_buffer_end, out_x_loop_end_unclamped); int32* acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h index 999f3e0d771..1efe6c7e0fd 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h @@ -179,10 +179,10 @@ struct DepthwiseConvWindowPerChannel tasks; // TODO(b/131746020) don't create new heap allocations every time. 
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 2de4c209209..cfe5ab10fb2 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -2339,7 +2339,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size, const int32x4_t f2i0_i32x4 = RoundToNearest(mul0_f32x4); const int32x4_t f2i1_i32x4 = RoundToNearest(mul1_f32x4); - // Implements the vectorized version of the folowing block: + // Implements the vectorized version of the following block: // quantized_values[i] = std::min(kScale, std::max(-kScale, // quantized_value)); int32x4_t max0_i32x4 = vmaxq_s32(f2i0_i32x4, neg_scale_i32x4); diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 15006d12c08..64ddaff8a3d 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -1123,7 +1123,7 @@ inline void Mean(const tflite::MeanParams& op_params, MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias, output_shape, output_data, 0, output_depth); } else { - // Instead parrallel for batch, we loop for the output_depth since batch + // Instead parallel for batch, we loop for the output_depth since batch // is typical 1. std::vector tasks; // TODO(b/131746020) don't create new heap allocations every time. @@ -5714,7 +5714,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, // .... // // In order to minimize the reload of the multipliers & shifts, once we load - // the multipliers & shifts, we load & quantize the raw accumualtrs for every + // the multipliers & shifts, we load & quantize the raw accumulators for every // row. #ifdef USE_NEON const int32x4_t output_offset_vec = vdupq_n_s32(output_zp); @@ -6369,7 +6369,7 @@ inline void HardSwish(const HardSwishParams& params, // Unfortunately, the Intel arm_neon_sse.h implementation of vqshl* is // buggy in the case of zero shift amounts, see b/137199585. That is why // this NEON code path is restricted to true ARM NEON, excluding - // arm_neon_sse.h. Anyway, the arm_neon_sse.h implemenation of saturating + // arm_neon_sse.h. Anyway, the arm_neon_sse.h implementation of saturating // left shifts is slow scalar code, so there may not be much benefit in // running that over just plain reference code. 
// @@ -7039,7 +7039,7 @@ inline void ClampWithRangeAndStore(int8_t* output_dst, int8x16_t input_val, #endif // GEMMLOWP_NEON -inline void Tanh16bitPercision(const TanhParams& params, +inline void Tanh16bitPrecision(const TanhParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& output_shape, @@ -7146,7 +7146,7 @@ inline void Tanh16bitPercision(const TanhParams& params, } } -inline void Tanh16bitPercision(const TanhParams& params, +inline void Tanh16bitPrecision(const TanhParams& params, const RuntimeShape& input_shape, const int8* input_data, const RuntimeShape& output_shape, @@ -7239,7 +7239,7 @@ inline void Tanh16bitPercision(const TanhParams& params, } } -inline void Logistic16bitPercision(const LogisticParams& params, +inline void Logistic16bitPrecision(const LogisticParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& output_shape, @@ -7331,7 +7331,7 @@ inline void Logistic16bitPercision(const LogisticParams& params, } } -inline void Logistic16bitPercision(const LogisticParams& params, +inline void Logistic16bitPrecision(const LogisticParams& params, const RuntimeShape& input_shape, const int8* input_data, const RuntimeShape& output_shape, diff --git a/tensorflow/lite/kernels/internal/quantization_util.cc b/tensorflow/lite/kernels/internal/quantization_util.cc index d94ca5beba9..8e28361f1f4 100644 --- a/tensorflow/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/lite/kernels/internal/quantization_util.cc @@ -372,7 +372,7 @@ void FakeQuantizeArray(const float nudged_scale, const float nudged_min, bool CheckedLog2(const float x, int* log2_result) { // Using TfLiteRound instead of std::round and std::log instead of - // std::log2 to work around these fuctions being missing in a toolchain + // std::log2 to work around these functions being missing in a toolchain // used in some TensorFlow tests as of May 2018. const float x_log2 = std::log(x) * (1.0f / std::log(2.0f)); const float x_log2_rounded = TfLiteRound(x_log2); diff --git a/tensorflow/lite/kernels/internal/reference/binary_function.h b/tensorflow/lite/kernels/internal/reference/binary_function.h index 82095af84a4..51d9e2b711a 100644 --- a/tensorflow/lite/kernels/internal/reference/binary_function.h +++ b/tensorflow/lite/kernels/internal/reference/binary_function.h @@ -26,7 +26,7 @@ namespace reference_ops { // TODO(ycling): Refactoring. Remove BroadcastLogical and use the more // generalized and efficient BroadcastBinaryFunction. // -// Also appears to duplicte MinimumMaximum. +// Also appears to duplicate MinimumMaximum. // // R: Result type. T1: Input 1 type. T2: Input 2 type. template diff --git a/tensorflow/lite/kernels/internal/spectrogram.cc b/tensorflow/lite/kernels/internal/spectrogram.cc index 784e4bc99ef..a832962a38d 100644 --- a/tensorflow/lite/kernels/internal/spectrogram.cc +++ b/tensorflow/lite/kernels/internal/spectrogram.cc @@ -175,7 +175,7 @@ bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( for (int i = 0; i < output_frequency_channels_; ++i) { // Similar to the Complex case, except storing the norm. // But the norm function is known to be a performance killer, - // so do it this way with explicit real and imagninary temps. + // so do it this way with explicit real and imaginary temps. const double re = fft_input_output_[2 * i]; const double im = fft_input_output_[2 * i + 1]; // Which finally converts double to float if it needs to. 
diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h index 775b2e58bea..9a418bfe6e7 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/lite/kernels/internal/tensor_utils.h @@ -161,8 +161,8 @@ void SparseMatrixBatchVectorMultiplyAccumulate( // - multiplier and shift combined gives the scale. // - assumes input zero point is 0. // - scratch is created for optimization purpose only. -// TODO(jianlijianli): this can be removed if some furture optimization -// work makes it unnecesssary. +// TODO(jianlijianli): this can be removed if some future optimization +// work makes it unnecessary. void MatrixBatchVectorMultiplyAccumulate( const int8_t* input, const int32_t* bias, const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift, @@ -192,8 +192,8 @@ void MatrixBatchVectorMultiplyAccumulate( // - multiplier and shift combined gives the scale. // - assumes input zero point is 0. // - scratch is created for optimization purpose only. -// TODO(jianlijianli): this can be removed if some furture optimization -// work makes it unnecesssary. +// TODO(jianlijianli): this can be removed if some future optimization +// work makes it unnecessary. void MatrixBatchVectorMultiplyAccumulate( const int8_t* input, const int32_t* bias, const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift, @@ -231,7 +231,7 @@ void MatrixBatchVectorMultiply(const int16_t* hidden, // - output: the 32bit output // Note: We do not need saturation because the int8 * int8 is safe from overflow // in (2^31-1) / (2^14) = 131072, which is bigger than the n_row. Non-zero -// initial output value is not exceiptionally large. +// initial output value is not exceptionally large. void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar, int32_t n_row, int32_t n_col, int32_t* output); @@ -372,7 +372,7 @@ inline void VectorVectorCwiseProduct(const T* __restrict__ vector1, } } -// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the +// Cwise product and accumulate of two vectors. Since it's a MAC operation, the // assumption here is that result array is initialized to valid values. 
template inline void VectorVectorCwiseProductAccumulate(const T* __restrict__ vector1, diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc index 3c34c435c2e..e039fb841ec 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc @@ -371,14 +371,14 @@ TEST(uKernels, QuantMatrixBatchVectorMultiplyAccumulate8x8_16Test) { const int32_t multiplier = 2080364544; const int32_t shift = -2; - std::vector scrach(2 * 9, 0); + std::vector scratch(2 * 9, 0); std::vector output = {10, 2, 33, 4, 5, 6, 65, 4, 3, 52, 1, 2, 8, -1, -2, 11, 17, -18}; MatrixBatchVectorMultiplyAccumulate( input.data(), input_zeropoint_times_weights.data(), input_to_gate_weights.data(), multiplier, shift, /*n_batch=*/2, /*n_input=*/30, /*n_output=*/9, /*output_zp=*/0, - scrach.data(), output.data(), &context); + scratch.data(), output.data(), &context); const std::vector expected_output = { -210, 331, 153, 139, -570, -657, 258, 515, -495, 91, -243, -73, 603, -744, -269, 169, -748, -174, @@ -497,11 +497,11 @@ TEST(uKernels, QuantMatrixBatchVectorMultiplyAccumulate8x8_8Test) { std::vector output = {1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1, 2, 8, -1, -2, 11, 17, 18}; - std::vector scrach(2 * 9, 0); + std::vector scratch(2 * 9, 0); MatrixBatchVectorMultiplyAccumulate( input.data(), input_zeropoint_times_weights.data(), input_to_gate_weights.data(), multiplier, shift, - /*n_batch=*/2, /*n_input=*/30, /*n_output=*/9, output_zp, scrach.data(), + /*n_batch=*/2, /*n_input=*/30, /*n_output=*/9, output_zp, scratch.data(), output.data(), &context); const std::vector expected_output = { 5, -9, -2, -30, -5, -11, -22, -18, 18, diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index f9c2352e95b..49700dc8d12 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -100,7 +100,7 @@ TfLiteStatus PopulateConvolutionQuantizationParams( context, input, filter, bias, output, &real_multiplier)); int exponent; - // Populate quantization parameteters with multiplier and shift. + // Populate quantization parameters with multiplier and shift. QuantizeMultiplier(real_multiplier, multiplier, &exponent); *shift = -exponent; } diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index fceea866fca..e7de22158a4 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -1248,7 +1248,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // Create a scratch buffer tensor for float case and hybrid case. - // TODO(jianlijianli): Create a is_float boolean and reorginze the temporary + // TODO(jianlijianli): Create a is_float boolean and reorganize the temporary // buffer allocation logic. if (!is_integer) { node->temporaries->data[0] = op_data->scratch_tensor_index; diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 85c3f506df4..454a223440e 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -861,7 +861,7 @@ inline void LstmStepHybrid( // // Layer norm coefficients of size 'n_cell', representing diagonal matrices. 
// layer_norm_input_weight_ptr - optional -// layer_norm_forput_weight_ptr - optional +// layer_norm_forget_weight_ptr - optional // layer_norm_cell_weight_ptr - optional // layer_norm_output_weight_ptr - optional // @@ -1187,7 +1187,7 @@ inline void LstmStepInteger( // // Layer norm coefficients of size 'n_cell', representing diagonal matrices. // layer_norm_input_weight_ptr - optional -// layer_norm_forput_weight_ptr - optional +// layer_norm_forget_weight_ptr - optional // layer_norm_cell_weight_ptr - optional // layer_norm_output_weight_ptr - optional // diff --git a/tensorflow/lite/kernels/matrix_diag_test.cc b/tensorflow/lite/kernels/matrix_diag_test.cc index 298ae264433..09a72e9b726 100644 --- a/tensorflow/lite/kernels/matrix_diag_test.cc +++ b/tensorflow/lite/kernels/matrix_diag_test.cc @@ -91,7 +91,7 @@ TEST(MatrixDiagTest, Int32TestTwoDimDiag) { EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32); } -TEST(MatrixDiagTest, DegenenerateCase) { +TEST(MatrixDiagTest, DegenerateCase) { MatrixDiagOpModel model({TensorType_UINT8, {1}}); model.PopulateTensor(model.input(), {1}); model.Invoke(); diff --git a/tensorflow/lite/kernels/pad_test.cc b/tensorflow/lite/kernels/pad_test.cc index 96500a5bf4c..8ef03290531 100644 --- a/tensorflow/lite/kernels/pad_test.cc +++ b/tensorflow/lite/kernels/pad_test.cc @@ -25,11 +25,11 @@ namespace { using ::testing::ElementsAreArray; using ::testing::Matcher; -template +template class PadOpModel : public SingleOpModel { public: - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); } template @@ -46,8 +46,8 @@ class PadOpModel : public SingleOpModel { PopulateTensor(paddings_, paddings); } - std::vector GetOutput() { - return ExtractVector(output_); + std::vector GetOutput() { + return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } @@ -128,17 +128,17 @@ class PadOpConstModel : public PadOpModel { }; // Test case where paddings is a non-const tensor. 
-template -class PadV2OpDynamicModel : public PadOpModel { +template +class PadV2OpDynamicModel : public PadOpModel { public: PadV2OpDynamicModel(const TensorData& input, std::initializer_list paddings_shape, - RegularInputOuput constant_values, + RegularInputOutput constant_values, const TensorData& output) { this->input_ = this->AddInput(input); this->paddings_ = this->AddInput(TensorType_INT32); this->constant_values_ = this->AddConstInput( - GetTensorType(), {constant_values}, {1}); + GetTensorType(), {constant_values}, {1}); this->output_ = this->AddOutput(output); this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options, diff --git a/tensorflow/lite/kernels/rfft2d.cc b/tensorflow/lite/kernels/rfft2d.cc index f46feccce66..a06b66735f6 100644 --- a/tensorflow/lite/kernels/rfft2d.cc +++ b/tensorflow/lite/kernels/rfft2d.cc @@ -360,7 +360,7 @@ TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) { double* fft_double_working_area_data = reinterpret_cast( GetTensorData(fft_double_working_area)); - // Process evert slice in the input buffer + // Process every slice in the input buffer for (int i = 0; i < num_slices; ++i) { PrepareInputBuffer(input_data, input_height, input_width, fft_height, fft_width, fft_input_output); diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index 8db98dba0e9..c687d0761fc 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -87,7 +87,7 @@ TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) { "StridedSlice op only supports 1D-5D input arrays."); } -TYPED_TEST(StridedSliceOpTest, UnssupportedArgs) { +TYPED_TEST(StridedSliceOpTest, UnsupportedArgs) { EXPECT_DEATH( StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), "ellipsis_mask is not implemented yet."); diff --git a/tensorflow/lite/kernels/subgraph_test_util.h b/tensorflow/lite/kernels/subgraph_test_util.h index 972f1381af2..95b7206fc29 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.h +++ b/tensorflow/lite/kernels/subgraph_test_util.h @@ -63,7 +63,7 @@ class SubgraphBuilder { void BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs); // An accumulate loop body subgraph. Used to produce triangle number - // seqeuence. 2 inputs and 2 outpus + // sequence. 
2 inputs and 2 outputs // Equivalent to (counter, value) -> (counter + 1, counter + 1 + value) void BuildAccumulateLoopBodySubgraph(Subgraph* subgraph); diff --git a/tensorflow/lite/kernels/subgraph_test_util_test.cc b/tensorflow/lite/kernels/subgraph_test_util_test.cc index 0a8646e8c4f..4bd0482da17 100644 --- a/tensorflow/lite/kernels/subgraph_test_util_test.cc +++ b/tensorflow/lite/kernels/subgraph_test_util_test.cc @@ -36,7 +36,7 @@ class SubgraphBuilderTest : public ::testing::Test { } protected: - void TestAccumelateLoopBody(int input1, int input2, int output1, + void TestAccumulateLoopBody(int input1, int input2, int output1, int output2) { interpreter_.reset(new Interpreter); builder_->BuildAccumulateLoopBodySubgraph( @@ -140,9 +140,9 @@ TEST_F(SubgraphBuilderTest, TestBuildLessEqualCondSubgraph) { } TEST_F(SubgraphBuilderTest, TestBuildAccumulateLoopBodySubgraph) { - TestAccumelateLoopBody(1, 1, 2, 3); - TestAccumelateLoopBody(2, 3, 3, 6); - TestAccumelateLoopBody(3, 6, 4, 10); + TestAccumulateLoopBody(1, 1, 2, 3); + TestAccumulateLoopBody(2, 3, 3, 6); + TestAccumulateLoopBody(3, 6, 4, 10); } TEST_F(SubgraphBuilderTest, TestBuildPadLoopBodySubgraph) { diff --git a/tensorflow/lite/kernels/svdf_test.cc b/tensorflow/lite/kernels/svdf_test.cc index f48e6f69e4d..1f5cfb040e7 100644 --- a/tensorflow/lite/kernels/svdf_test.cc +++ b/tensorflow/lite/kernels/svdf_test.cc @@ -547,7 +547,7 @@ TEST_F(SVDFOpTest, BlackBoxTestInteger) { svdf.SetBias({-0.0976817, 0.15294972, 0.39635518, -0.02702999}); - const std::vector> input_sequnces = { + const std::vector> input_sequences = { {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279}, {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406, 0.73463392}, @@ -585,7 +585,7 @@ TEST_F(SVDFOpTest, BlackBoxTestInteger) { }; for (int sequence_index = 0; sequence_index < 12; ++sequence_index) { - svdf.SetInput(input_sequnces[sequence_index]); + svdf.SetInput(input_sequences[sequence_index]); svdf.Invoke(); const std::vector res = svdf.GetOutput(); EXPECT_THAT(res, ElementsAreArray(expected_output[sequence_index])); diff --git a/tensorflow/lite/kernels/variable_ops_test.cc b/tensorflow/lite/kernels/variable_ops_test.cc index d6a3f916d12..2efac9d7d8f 100644 --- a/tensorflow/lite/kernels/variable_ops_test.cc +++ b/tensorflow/lite/kernels/variable_ops_test.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tflite { -// Forward declaraction for op kernels. +// Forward declaration for op kernels. namespace ops { namespace custom { @@ -104,7 +104,7 @@ TEST_F(VariableOpsTest, TestReadVariableBeforeAssign) { ASSERT_EQ(interpreter_.Invoke(), kTfLiteError); } -TEST_F(VariableOpsTest, TestReeasignToDifferentSize) { +TEST_F(VariableOpsTest, TestReassignToDifferentSize) { // 1st invocation. The variable is assigned as a scalar. { ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); diff --git a/tensorflow/lite/kernels/while_test.cc b/tensorflow/lite/kernels/while_test.cc index 1745f585ed0..dc69e496533 100644 --- a/tensorflow/lite/kernels/while_test.cc +++ b/tensorflow/lite/kernels/while_test.cc @@ -79,7 +79,7 @@ TEST_F(WhileTest, TestPadLoop) { TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); CheckIntTensor(output2, {11}, {0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0}); - // The extra invocation serves as a regiression test: There was a bug that + // The extra invocation serves as a regression test: There was a bug that // invoking a while loop with dynamic shaped body makes the interpreter // state uninvokable. 
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); From 37dfcf6e2a28c4435c227c2d708fd5c66bbcd5ae Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 18 Mar 2020 09:14:38 -0700 Subject: [PATCH 143/492] [TF] Expose C API in libtensorflow_framework. While at it, expose the associated header files from tensorflow/c/ in pip package. Note, we expose the subset of C API that doesn't require tensorflow/cc linkage; specifically the core operations that exclude building while loops and gradient ops, and also excluding the experimental API. The experimental API can also be added in the future, by factoring it into "core" and "non-core" targets. Similarly for the C eager API. PiperOrigin-RevId: 301601988 Change-Id: I97eac79e684fc42ce90e67ee901cdcf6f7e91cbe --- tensorflow/BUILD | 1 - tensorflow/c/BUILD | 76 +- tensorflow/c/c_api.cc | 2083 ++++++++++++++++++++++- tensorflow/c/c_api.h | 1423 +++++++++++++++- tensorflow/c/c_api_internal.h | 8 +- tensorflow/c/c_core_api.cc | 2193 ------------------------- tensorflow/c/c_core_api.h | 1456 ---------------- tensorflow/c/eager/BUILD | 2 +- tensorflow/c/eager/c_api.cc | 2 +- tensorflow/c/eager/c_api.h | 2 +- tensorflow/tools/pip_package/BUILD | 1 - tensorflow/tools/pip_package/setup.py | 1 - 12 files changed, 3509 insertions(+), 3739 deletions(-) delete mode 100644 tensorflow/c/c_core_api.cc delete mode 100644 tensorflow/c/c_core_api.h diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 005acff27f7..55406a5686a 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -644,7 +644,6 @@ tf_cc_shared_object( "//tensorflow/core:lib_internal_impl", "//tensorflow/core/profiler:profiler_impl", "//tensorflow/stream_executor:stream_executor_impl", - "//tensorflow/c:c_core_api_no_xla", "//tensorflow:tf_framework_version_script.lds", ] + tf_additional_binary_deps(), ) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 248bb826c28..c5574793b74 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -23,7 +23,6 @@ filegroup( srcs = [ "c_api.h", "c_api_experimental.h", - "c_core_api.h", "tf_attrtype.h", "tf_datatype.h", "tf_file_statistics.h", @@ -74,7 +73,6 @@ tf_cuda_library( hdrs = [ "c_api.h", "c_api_internal.h", - "c_core_api.h", "tf_datatype.h", "tf_tensor.h", ], @@ -118,41 +116,10 @@ cc_library( visibility = ["//visibility:public"], ) -tf_cuda_library( - name = "c_core_api", - hdrs = [ - "c_core_api.h", - "tf_attrtype.h", - "tf_datatype.h", - "tf_file_statistics.h", - "tf_status.h", - "tf_tensor.h", - ], - copts = tf_copts(), - visibility = [ - "//visibility:public", - ], - deps = [ - ":c_core_api_no_xla", - ":c_api_internal", - ":tf_attrtype", - ":tf_status_internal", - ":tf_file_statistics", - ":tf_tensor_internal", - ] + select({ - "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/jit", - ], - "//conditions:default": [], - }), -) - tf_cuda_library( name = "c_api", hdrs = [ "c_api.h", - "c_core_api.h", "tf_attrtype.h", "tf_datatype.h", "tf_file_statistics.h", @@ -162,7 +129,6 @@ tf_cuda_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - ":c_core_api", ":c_api_no_xla", ":c_api_internal", ":tf_attrtype", @@ -178,48 +144,11 @@ tf_cuda_library( }), ) -tf_cuda_library( - name = "c_core_api_no_xla", - srcs = [ - "c_api_function.cc", - "c_core_api.cc", - ], - hdrs = [ - "c_core_api.h", - ], - copts = tf_copts(), - visibility = ["//tensorflow:__subpackages__"], - deps = [ - ":c_api_internal", - ":tf_attrtype", - ":tf_datatype", - ":tf_status_internal", - ] + select({ - 
"//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", - ], - "//conditions:default": [ - ":tf_status", - ":tf_tensor", - "@com_google_absl//absl/strings", - "//tensorflow/cc/saved_model:loader_lite", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:op_gen_lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/distributed_runtime:server_lib", - ], - }), - alwayslink = 1, -) - tf_cuda_library( name = "c_api_no_xla", srcs = [ "c_api.cc", + "c_api_function.cc", ], hdrs = [ "c_api.h", @@ -230,7 +159,6 @@ tf_cuda_library( "//third_party/llvm/llvm-project:__subpackages__", ], deps = [ - ":c_core_api_no_xla", ":c_api_internal", ":tf_attrtype", ":tf_datatype", @@ -256,6 +184,8 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/kernels:logging_ops", ], }), alwayslink = 1, diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 3a110e4c9f2..bc1fbd3fcf5 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -29,6 +29,9 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/while_loop.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/framework/logging.h" #include "tensorflow/core/framework/op_gen_lib.h" #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/c/c_api_internal.h" @@ -96,14 +99,566 @@ using tensorflow::TensorBuffer; using tensorflow::TensorId; using tensorflow::TensorShape; using tensorflow::TensorShapeProto; -using tensorflow::ToTensorId; using tensorflow::VersionDef; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; using tensorflow::strings::StrCat; +extern "C" { + +// -------------------------------------------------------------------------- +const char* TF_Version() { return TF_VERSION_STRING; } + +// -------------------------------------------------------------------------- + +// -------------------------------------------------------------------------- +TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } +void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } + +void TF_SetTarget(TF_SessionOptions* options, const char* target) { + options->options.target = target; +} + +void TF_SetConfig(TF_SessionOptions* options, const void* proto, + size_t proto_len, TF_Status* status) { + if (!options->options.config.ParseFromArray(proto, proto_len)) { + status->status = InvalidArgument("Unparseable ConfigProto"); + } +} +// -------------------------------------------------------------------------- +TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; } + +TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { + void* copy = tensorflow::port::Malloc(proto_len); + memcpy(copy, proto, proto_len); + + TF_Buffer* buf = new TF_Buffer; + buf->data = copy; + buf->length = proto_len; + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; + return buf; +} + +void TF_DeleteBuffer(TF_Buffer* buffer) { + if (buffer == nullptr) return; + if (buffer->data_deallocator != nullptr) { + 
(*buffer->data_deallocator)(const_cast(buffer->data), + buffer->length); + } + delete buffer; +} + +TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } + +// -------------------------------------------------------------------------- + +TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, + TF_Status* status) { + Session* session; + status->status = NewSession(opt->options, &session); + if (status->status.ok()) { + return new TF_DeprecatedSession({session}); + } else { + DCHECK_EQ(nullptr, session); + return nullptr; + } +} + +void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { + status->status = s->session->Close(); +} + +void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { + status->status = Status::OK(); + if (s == nullptr) return; + delete s->session; + delete s; +} + +void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto, + size_t proto_len, TF_Status* status) { + GraphDef g; + if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { + status->status = InvalidArgument("Invalid GraphDef"); + return; + } + status->status = s->session->Extend(g); +} + +} // end extern "C" + +// Reset helper for converting character arrays to string vectors. +static void TF_Reset_Helper(const TF_SessionOptions* opt, + const char** containers, int ncontainers, + TF_Status* status) { + std::vector container_names(ncontainers); + for (int i = 0; i < ncontainers; ++i) { + container_names[i] = containers[i]; + } + + status->status = Reset(opt->options, container_names); +} + +extern "C" { + +void TF_Reset(const TF_SessionOptions* opt, const char** containers, + int ncontainers, TF_Status* status) { + TF_Reset_Helper(opt, containers, ncontainers, status); +} + +} // end extern "C" + +namespace tensorflow { + + +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out) { + if (out->data != nullptr) { + return InvalidArgument("Passing non-empty TF_Buffer is invalid."); + } + const size_t proto_size = in.ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return tensorflow::errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '", + in.GetTypeName(), "' and size ", proto_size); + } + if (!in.SerializeWithCachedSizesToArray(static_cast(buf))) { + port::Free(buf); + return InvalidArgument("Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", + proto_size, " bytes) is too large?"); + } + out->data = buf; + out->length = proto_size; + out->data_deallocator = [](void* data, size_t length) { port::Free(data); }; + return Status::OK(); +} + +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) { + // If any session has already run this node_id, mark this session as + // unrunnable. + for (auto it : graph->sessions) { + mutex_lock session_lock(it.first->mu); + if (it.first->last_num_graph_nodes > op.node.id()) { + it.second = strings::StrCat( + "Operation '", op.node.DebugString(), "' was changed by ", + mutation_type, + " after it was run by a session. This mutation will have no effect, " + "and will trigger an error in the future. Either don't modify " + "nodes after running them or create a new session."); + } + } +} + namespace { + +// Helper method that creates a shape handle for a shape described by dims. 
+tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( + tensorflow::shape_inference::InferenceContext* ic, int num_dims, + const int64_t* dims) { + if (num_dims != -1) { + std::vector dim_vec; + dim_vec.reserve(num_dims); + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + return ic->MakeShape(dim_vec); + } else { + return ic->UnknownShape(); + } +} + +} // namespace + +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + + auto shape_and_type_vec = + std::vector( + num_shapes_and_types); + for (int i = 0; i < num_shapes_and_types; ++i) { + tensorflow::shape_inference::ShapeHandle shape_handle = + ShapeHandleFromDims(ic, ranks[i], shapes[i]); + shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType( + shape_handle, static_cast(types[i])); + } + + ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec); +} + +// Helpers for loading a TensorFlow plugin (a .so file). +Status LoadLibrary(const char* library_filename, void** result, + const void** buf, size_t* len); + +// TODO(josh11b,mrry): Change Session to be able to use a Graph* +// directly, instead of requiring us to serialize to a GraphDef and +// call Session::Extend(). +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { + if (session->graph != nullptr) { + // Take the graph lock before the session lock to avoid deadlock. This is + // safe since session->graph does not change. + session->graph->mu.lock(); + mutex_lock session_lock(session->mu); + const Graph& graph = session->graph->graph; + + const string& mutation_warning = session->graph->sessions[session]; + if (!mutation_warning.empty()) { + // TODO(b/74949947): turn this back into an error status + LOG(WARNING) << mutation_warning; + session->graph->sessions[session].clear(); + } + + const auto num_nodes = graph.num_node_ids(); + if (session->last_num_graph_nodes < num_nodes) { + // TODO(nolivia): check this on a subset of the graph instead of all of + // it. + status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + + GraphDef graph_def; + *graph_def.mutable_versions() = graph.versions(); + // Fill graph_def with nodes with ids in the range + // [session->last_num_graph_nodes, num_nodes), that is the nodes + // added since the last TF_SessionRun() call. + for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { + Node* const node = graph.FindNodeId(id); + if (node != nullptr && node->IsOp()) { + NodeDef* const node_def = graph_def.add_node(); + *node_def = node->def(); + } + } + *graph_def.mutable_library() = graph.flib_def().ToProto(); + session->graph->mu.unlock(); + status->status = session->session->Extend(std::move(graph_def)); + if (!status->status.ok()) { + // Contract is we always delete input_values[i]. + return false; + } + // Note: session->session is not modified if Extend() fails, so + // we only set last_num_graph_nodes if it succeeds. 
+ session->last_num_graph_nodes = num_nodes; + } else { + session->graph->mu.unlock(); + } + } + return true; +} + +} // namespace tensorflow + +static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, + TF_Status* status) { + status->status = Status::OK(); + for (int i = 0; i < noutputs; ++i) { + c_outputs[i] = nullptr; + } +} + +static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, + std::vector>* input_pairs, + TF_Status* status) { + const int ninputs = input_pairs->size(); + for (int i = 0; i < ninputs; ++i) { + status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); + if (!status->status.ok()) return false; + } + return true; +} + +// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to +// result in a zero-sized tensor. +static TF_Tensor* EmptyTensor(TF_DataType dtype, + const tensorflow::TensorShape& shape) { + static char empty; + tensorflow::int64 nelems = 1; + std::vector dims; + for (int i = 0; i < shape.dims(); ++i) { + dims.push_back(shape.dim_size(i)); + nelems *= shape.dim_size(i); + } + CHECK_EQ(nelems, 0); + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + return TF_NewTensor( + dtype, reinterpret_cast(dims.data()), shape.dims(), + reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); +} + +static void TF_Run_Helper( + Session* session, const char* handle, const TF_Buffer* run_options, + // Input tensors + const std::vector>& input_pairs, + // Output tensors + const std::vector& output_tensor_names, TF_Tensor** c_outputs, + // Target nodes + const std::vector& target_oper_names, TF_Buffer* run_metadata, + TF_Status* status) { + const int noutputs = output_tensor_names.size(); + std::vector outputs(noutputs); + Status result; + + if (handle == nullptr) { + RunOptions run_options_proto; + if (run_options != nullptr && !run_options_proto.ParseFromArray( + run_options->data, run_options->length)) { + status->status = InvalidArgument("Unparseable RunOptions proto"); + return; + } + if (run_metadata != nullptr && run_metadata->data != nullptr) { + status->status = + InvalidArgument("Passing non-empty run_metadata is invalid."); + return; + } + + RunMetadata run_metadata_proto; + result = session->Run(run_options_proto, input_pairs, output_tensor_names, + target_oper_names, &outputs, &run_metadata_proto); + + // Serialize back to upstream client, who now owns the new buffer + if (run_metadata != nullptr) { + status->status = MessageToBuffer(run_metadata_proto, run_metadata); + if (!status->status.ok()) return; + } + } else { + // NOTE(zongheng): PRun does not support RunOptions yet. 
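// Illustrative sketch of the client-side partial-run flow that this branch
// serves (not part of this patch; `session` is an assumed TF_DeprecatedSession*,
// the tensor names, `input_tensors` and `output_tensors` are placeholders):
//
//   const char* feeds[] = {"a:0", "b:0"};
//   const char* fetches[] = {"c:0"};
//   const char* handle = nullptr;
//   TF_PRunSetup(session, feeds, 2, fetches, 1, nullptr, 0, &handle, status);
//   // Later, feed the declared inputs and fetch "c:0" against that handle.
//   TF_PRun(session, handle, feeds, input_tensors, 2, fetches, output_tensors,
//           1, nullptr, 0, status);
//   TF_DeletePRunHandle(handle);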
+ result = session->PRun(handle, input_pairs, output_tensor_names, &outputs); + } + if (!result.ok()) { + status->status = result; + return; + } + + // Store results in c_outputs[] + for (int i = 0; i < noutputs; ++i) { + const Tensor& src = outputs[i]; + if (!src.IsInitialized() || src.NumElements() == 0) { + c_outputs[i] = + EmptyTensor(static_cast(src.dtype()), src.shape()); + continue; + } + c_outputs[i] = TF_TensorFromTensor(src, &status->status); + if (!status->status.ok()) return; + } +} + +extern "C" { + +void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, + // Input tensors + const char** c_input_names, TF_Tensor** c_inputs, int ninputs, + // Output tensors + const char** c_output_names, TF_Tensor** c_outputs, int noutputs, + // Target nodes + const char** c_target_oper_names, int ntargets, + TF_Buffer* run_metadata, TF_Status* status) { + TF_Run_Setup(noutputs, c_outputs, status); + std::vector> input_pairs(ninputs); + if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; + for (int i = 0; i < ninputs; ++i) { + input_pairs[i].first = c_input_names[i]; + } + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = c_output_names[i]; + } + std::vector target_oper_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_oper_names[i] = c_target_oper_names[i]; + } + TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names, + c_outputs, target_oper_names, run_metadata, status); +} + +void TF_PRunSetup(TF_DeprecatedSession* s, + // Input names + const char** c_input_names, int ninputs, + // Output names + const char** c_output_names, int noutputs, + // Target nodes + const char** c_target_oper_names, int ntargets, + const char** handle, TF_Status* status) { + *handle = nullptr; + + std::vector input_names(ninputs); + std::vector output_names(noutputs); + std::vector target_oper_names(ntargets); + for (int i = 0; i < ninputs; ++i) { + input_names[i] = c_input_names[i]; + } + for (int i = 0; i < noutputs; ++i) { + output_names[i] = c_output_names[i]; + } + for (int i = 0; i < ntargets; ++i) { + target_oper_names[i] = c_target_oper_names[i]; + } + string new_handle; + status->status = s->session->PRunSetup(input_names, output_names, + target_oper_names, &new_handle); + if (status->status.ok()) { + char* buf = new char[new_handle.size() + 1]; + memcpy(buf, new_handle.c_str(), new_handle.size() + 1); + *handle = buf; + } +} + +void TF_PRun(TF_DeprecatedSession* s, const char* handle, + // Input tensors + const char** c_input_names, TF_Tensor** c_inputs, int ninputs, + // Output tensors + const char** c_output_names, TF_Tensor** c_outputs, int noutputs, + // Target nodes + const char** c_target_oper_names, int ntargets, + TF_Status* status) { + TF_Run_Setup(noutputs, c_outputs, status); + std::vector> input_pairs(ninputs); + if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; + for (int i = 0; i < ninputs; ++i) { + input_pairs[i].first = c_input_names[i]; + } + + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = c_output_names[i]; + } + std::vector target_oper_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_oper_names[i] = c_target_oper_names[i]; + } + TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names, + c_outputs, target_oper_names, nullptr, status); +} + +TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { + TF_Library* lib_handle = new TF_Library; + status->status = tensorflow::LoadLibrary( + 
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, + &lib_handle->op_list.length); + if (!status->status.ok()) { + delete lib_handle; + return nullptr; + } + return lib_handle; +} + +TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } + +void TF_DeleteLibraryHandle(TF_Library* lib_handle) { + if (lib_handle == nullptr) return; + tensorflow::port::Free(const_cast(lib_handle->op_list.data)); + delete lib_handle; +} + +TF_Buffer* TF_GetAllOpList() { + std::vector op_defs; + tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs); + tensorflow::OpList op_list; + for (const auto& op : op_defs) { + *(op_list.add_op()) = op; + } + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(op_list, ret)); + return ret; +} + +// -------------------------------------------------------------------------- +// ListDevices & SessionListDevices API + +void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; } + +TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) { + TF_DeviceList* response = new TF_DeviceList; + status->status = session->session->ListDevices(&response->response); + return response; +} + +TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session, + TF_Status* status) { + TF_DeviceList* response = new TF_DeviceList; + status->status = session->session->ListDevices(&response->response); + return response; +} + +int TF_DeviceListCount(const TF_DeviceList* list) { + return list->response.size(); +} + +#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \ + return_type method_name(const TF_DeviceList* list, const int index, \ + TF_Status* status) { \ + if (list == nullptr) { \ + status->status = InvalidArgument("list is null!"); \ + return err_val; \ + } \ + if (index < 0 || index >= list->response.size()) { \ + status->status = InvalidArgument("index out of bounds"); \ + return err_val; \ + } \ + status->status = Status::OK(); \ + return list->response[index].accessor; \ + } + +TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr); +TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(), + nullptr); +TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1); +TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0); + +#undef TF_DEVICELIST_METHOD + +} // end extern "C" + +// -------------------------------------------------------------------------- +// New Graph and Session API + +// Helper functions ----------------------------------------------------------- + +namespace { + +TF_Operation* ToOperation(Node* node) { + return static_cast(static_cast(node)); +} + +string OutputName(const TF_Output& output) { + return StrCat(output.oper->node.name(), ":", output.index); +} + +const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, + const char* attr_name, + TF_Status* status) { + const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); + if (attr == nullptr) { + status->status = InvalidArgument("Operation '", oper->node.name(), + "' has no attr named '", attr_name, "'."); + } + return attr; +} + +TensorId ToTensorId(const TF_Output& output) { + return TensorId(output.oper->node.name(), output.index); +} + #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) std::vector OutputsFromTFOutputs(TF_Output* tf_outputs, int n) { @@ -126,8 +681,1134 @@ void TFOutputsFromOutputs(const std::vector& outputs, } // namespace +// Shape functions 
----------------------------------------------------------- + +void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, + const int64_t* dims, const int num_dims, + TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + tensorflow::shape_inference::ShapeHandle new_shape = + tensorflow::ShapeHandleFromDims(ic, num_dims, dims); + status->status = graph->refiner.SetShape(node, output.index, new_shape); +} + +int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, + TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return -1; + } + + tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); + + // Unknown rank means the number of dimensions is -1. + if (!ic->RankKnown(shape)) { + return -1; + } + + return ic->Rank(shape); +} + +void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims, + int num_dims, TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + + tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); + + int rank = -1; + if (ic->RankKnown(shape)) { + rank = ic->Rank(shape); + } + + if (num_dims != rank) { + status->status = InvalidArgument("Expected rank is ", num_dims, + " but actual rank is ", rank); + return; + } + + if (num_dims == 0) { + // Output shape is a scalar. + return; + } + + // Rank is greater than 0, so fill in the values, if known, and + // -1 for unknown values. 
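// Typical caller pattern for TF_GraphGetTensorNumDims / TF_GraphGetTensorShape
// (illustrative sketch only; `graph`, `out` and `status` are assumed to be a
// valid TF_Graph*, TF_Output and TF_Status*):
//
//   int num_dims = TF_GraphGetTensorNumDims(graph, out, status);
//   if (TF_GetCode(status) == TF_OK && num_dims >= 0) {
//     std::vector<int64_t> dims(num_dims);
//     TF_GraphGetTensorShape(graph, out, dims.data(), num_dims, status);
//   }
//   // num_dims == -1 means the rank is unknown and no shape is filled in.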
+ for (int i = 0; i < num_dims; ++i) { + auto dim = ic->Dim(shape, i); + tensorflow::int64 value = -1; + if (ic->ValueKnown(dim)) { + value = ic->Value(dim); + } + dims[i] = value; + } +} + +// TF_OperationDescription functions ------------------------------------------ + extern "C" { +static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, + const char* op_type, + const char* oper_name) + TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + return new TF_OperationDescription(graph, op_type, oper_name); +} + +TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type, + const char* oper_name) { + mutex_lock l(graph->mu); + return TF_NewOperationLocked(graph, op_type, oper_name); +} + +void TF_SetDevice(TF_OperationDescription* desc, const char* device) { + desc->node_builder.Device(device); +} + +void TF_AddInput(TF_OperationDescription* desc, TF_Output input) { + desc->node_builder.Input(&input.oper->node, input.index); +} + +void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs, + int num_inputs) { + std::vector input_list; + input_list.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + input_list.emplace_back(&inputs[i].oper->node, inputs[i].index); + } + desc->node_builder.Input(input_list); +} + +void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { + desc->node_builder.ControlInput(&input->node); +} + +void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { + desc->colocation_constraints.emplace( + StrCat(tensorflow::kColocationGroupPrefix, op->node.name())); +} + +void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, + const void* value, size_t length) { + tensorflow::StringPiece s(static_cast(value), length); + desc->node_builder.Attr(attr_name, s); +} + +void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values) { + if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { + desc->colocation_constraints.clear(); + for (int i = 0; i < num_values; ++i) { + desc->colocation_constraints.emplace(static_cast(values[i]), + lengths[i]); + } + } else { + std::vector v; + v.reserve(num_values); + for (int i = 0; i < num_values; ++i) { + v.emplace_back(static_cast(values[i]), lengths[i]); + } + desc->node_builder.Attr(attr_name, v); + } +} + +void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, + int64_t value) { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + desc->node_builder.Attr(attr_name, static_cast(value)); +} + +void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name, + const int64_t* values, int num_values) { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + desc->node_builder.Attr( + attr_name, + ArraySlice( + reinterpret_cast(values), num_values)); +} + +void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name, + float value) { + desc->node_builder.Attr(attr_name, value); +} + +void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name, + const float* values, int num_values) { + desc->node_builder.Attr(attr_name, + ArraySlice(values, num_values)); +} + +void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, + unsigned char value) { + desc->node_builder.Attr(attr_name, static_cast(value)); +} + +void TF_SetAttrBoolList(TF_OperationDescription* desc, const 
char* attr_name, + const unsigned char* values, int num_values) { + std::unique_ptr b(new bool[num_values]); + for (int i = 0; i < num_values; ++i) { + b[i] = values[i]; + } + desc->node_builder.Attr(attr_name, + ArraySlice(b.get(), num_values)); +} + +void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, + TF_DataType value) { + desc->node_builder.Attr(attr_name, static_cast(value)); +} + +void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, + const TF_DataType* values, int num_values) { + desc->node_builder.Attr( + attr_name, ArraySlice( + reinterpret_cast(values), num_values)); +} + +void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name, + const char* placeholder) { + tensorflow::AttrValue attr_value; + attr_value.set_placeholder(placeholder); + desc->node_builder.Attr(attr_name, attr_value); +} + +void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, + const char* value, size_t length) { + tensorflow::NameAttrList func_name; + func_name.set_name(string(value, value + length)); + desc->node_builder.Attr(attr_name, func_name); +} + +void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, + const int64_t* dims, int num_dims) { + PartialTensorShape shape; + if (num_dims >= 0) { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + shape = PartialTensorShape(ArraySlice( + reinterpret_cast(dims), num_dims)); + } + desc->node_builder.Attr(attr_name, shape); +} + +void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name, + const int64_t* const* dims, const int* num_dims, + int num_shapes) { + std::vector shapes; + shapes.reserve(num_shapes); + for (int i = 0; i < num_shapes; ++i) { + if (num_dims[i] < 0) { + shapes.emplace_back(); + } else { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + shapes.emplace_back(ArraySlice( + reinterpret_cast(dims[i]), num_dims[i])); + } + } + desc->node_builder.Attr(attr_name, shapes); +} + +void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, + const char* attr_name, const void* proto, + size_t proto_len, TF_Status* status) { + // shape.ParseFromArray takes an int as length, this function takes size_t, + // make sure there is no information loss. 
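// Sketch of how the attr setters above are used when building a node
// (illustrative only; "Placeholder" with its standard "dtype"/"shape" attrs is
// just an example op, and `graph`/`status` are assumed to exist):
//
//   TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "x");
//   TF_SetAttrType(desc, "dtype", TF_FLOAT);
//   const int64_t dims[] = {-1, 28, 28};
//   TF_SetAttrShape(desc, "shape", dims, 3);
//   TF_Operation* op = TF_FinishOperation(desc, status);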
+ if (proto_len > std::numeric_limits::max()) { + status->status = InvalidArgument( + "proto_len (", proto_len, + " bytes) is too large to be parsed by the protocol buffer library"); + return; + } + TensorShapeProto shape; + if (shape.ParseFromArray(proto, static_cast(proto_len))) { + desc->node_builder.Attr(attr_name, shape); + status->status = Status::OK(); + } else { + status->status = InvalidArgument("Unparseable TensorShapeProto"); + } +} + +void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, + const char* attr_name, + const void* const* protos, + const size_t* proto_lens, int num_shapes, + TF_Status* status) { + std::vector shapes; + shapes.resize(num_shapes); + for (int i = 0; i < num_shapes; ++i) { + if (proto_lens[i] > std::numeric_limits::max()) { + status->status = InvalidArgument( + "length of element ", i, " in the list (", proto_lens[i], + " bytes) is too large to be parsed by the protocol buffer library"); + return; + } + if (!shapes[i].ParseFromArray(protos[i], static_cast(proto_lens[i]))) { + status->status = + InvalidArgument("Unparseable TensorShapeProto at index ", i); + return; + } + } + desc->node_builder.Attr(attr_name, shapes); + status->status = Status::OK(); +} + +void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, + TF_Tensor* value, TF_Status* status) { + Tensor t; + status->status = TF_TensorToTensor(value, &t); + if (status->status.ok()) desc->node_builder.Attr(attr_name, t); +} + +void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, + TF_Tensor* const* values, int num_values, + TF_Status* status) { + status->status = Status::OK(); + std::vector t; + t.reserve(num_values); + + for (int i = 0; i < num_values && status->status.ok(); ++i) { + Tensor v; + status->status = TF_TensorToTensor(values[i], &v); + t.emplace_back(v); + } + + if (status->status.ok()) desc->node_builder.Attr(attr_name, t); +} + +void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::AttrValue attr_value; + if (!attr_value.ParseFromArray(proto, proto_len)) { + status->status = InvalidArgument("Unparseable AttrValue proto"); + return; + } + + if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { + if (attr_value.value_case() != tensorflow::AttrValue::kList && + attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) { + status->status = + InvalidArgument("Expected \"list\" field for \"", + tensorflow::kColocationAttrName, "\" attribute"); + return; + } + desc->colocation_constraints.clear(); + for (const string& location : attr_value.list().s()) { + desc->colocation_constraints.insert(location); + } + } else { + desc->node_builder.Attr(attr_name, std::move(attr_value)); + } + + status->status = Status::OK(); +} + +static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, + TF_Status* status) + TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { + Node* ret = nullptr; + + if (desc->graph->name_map.count(desc->node_builder.node_name())) { + status->status = InvalidArgument("Duplicate node name in graph: '", + desc->node_builder.node_name(), "'"); + } else { + if (!desc->colocation_constraints.empty()) { + desc->node_builder.Attr( + tensorflow::kColocationAttrName, + std::vector(desc->colocation_constraints.begin(), + desc->colocation_constraints.end())); + } + status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret, + /*consume=*/true); + + if (status->status.ok()) { + // Run shape 
inference function for newly added node. + status->status = desc->graph->refiner.AddNode(ret); + } + if (status->status.ok()) { + // Add the node to the name-to-node mapping. + desc->graph->name_map[ret->name()] = ret; + } else if (ret != nullptr) { + desc->graph->graph.RemoveNode(ret); + ret = nullptr; + } + } + + delete desc; + + return ToOperation(ret); +} + +TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, + TF_Status* status) { + mutex_lock l(desc->graph->mu); + return TF_FinishOperationLocked(desc, status); +} + +// TF_Operation functions +// ---------------------------------------------------------- + +const char* TF_OperationName(TF_Operation* oper) { + return oper->node.name().c_str(); +} + +const char* TF_OperationOpType(TF_Operation* oper) { + return oper->node.type_string().c_str(); +} + +const char* TF_OperationDevice(TF_Operation* oper) { + return oper->node.requested_device().c_str(); +} + +int TF_OperationNumOutputs(TF_Operation* oper) { + return oper->node.num_outputs(); +} + +TF_DataType TF_OperationOutputType(TF_Output oper_out) { + return static_cast( + oper_out.oper->node.output_type(oper_out.index)); +} + +int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, + TF_Status* status) { + NameRangeMap name_ranges; + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); + if (!status->status.ok()) return -1; + auto iter = name_ranges.find(arg_name); + if (iter == name_ranges.end()) { + status->status = InvalidArgument("Output arg '", arg_name, "' not found"); + return -1; + } + return iter->second.second - iter->second.first; +} + +int TF_OperationNumInputs(TF_Operation* oper) { + return oper->node.num_inputs(); +} + +TF_DataType TF_OperationInputType(TF_Input oper_in) { + return static_cast(oper_in.oper->node.input_type(oper_in.index)); +} + +int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, + TF_Status* status) { + NameRangeMap name_ranges; + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); + if (!status->status.ok()) return -1; + auto iter = name_ranges.find(arg_name); + if (iter == name_ranges.end()) { + status->status = InvalidArgument("Input arg '", arg_name, "' not found"); + return -1; + } + return iter->second.second - iter->second.first; +} + +TF_Output TF_OperationInput(TF_Input oper_in) { + const tensorflow::Edge* edge; + Status s = oper_in.oper->node.input_edge(oper_in.index, &edge); + if (!s.ok()) { + return {nullptr, -1}; + } + + return {ToOperation(edge->src()), edge->src_output()}; +} + +void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs, + int max_inputs) { + for (auto* edge : oper->node.in_edges()) { + if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) { + inputs[edge->dst_input()] = {ToOperation(edge->src()), + edge->src_output()}; + } + } +} + +int TF_OperationOutputNumConsumers(TF_Output oper_out) { + int count = 0; + for (const auto* edge : oper_out.oper->node.out_edges()) { + if (edge->src_output() == oper_out.index) { + ++count; + } + } + return count; +} + +int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, + int max_consumers) { + int count = 0; + for (const auto* edge : oper_out.oper->node.out_edges()) { + if (edge->src_output() == oper_out.index) { + if (count < max_consumers) { + consumers[count] = {ToOperation(edge->dst()), edge->dst_input()}; + } + ++count; + } + } + return count; +} + +int TF_OperationNumControlInputs(TF_Operation* oper) { + int count = 
0; + for (const auto* edge : oper->node.in_edges()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { + ++count; + } + } + return count; +} + +int TF_OperationGetControlInputs(TF_Operation* oper, + TF_Operation** control_inputs, + int max_control_inputs) { + int count = 0; + for (const auto* edge : oper->node.in_edges()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { + if (count < max_control_inputs) { + control_inputs[count] = ToOperation(edge->src()); + } + ++count; + } + } + return count; +} + +int TF_OperationNumControlOutputs(TF_Operation* oper) { + int count = 0; + for (const auto* edge : oper->node.out_edges()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { + ++count; + } + } + return count; +} + +int TF_OperationGetControlOutputs(TF_Operation* oper, + TF_Operation** control_outputs, + int max_control_outputs) { + int count = 0; + for (const auto* edge : oper->node.out_edges()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { + if (count < max_control_outputs) { + control_outputs[count] = ToOperation(edge->dst()); + } + ++count; + } + } + return count; +} + +TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, + const char* attr_name, + TF_Status* status) { + TF_AttrMetadata metadata; + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return metadata; + switch (attr->value_case()) { +#define SINGLE_CASE(kK, attr_type, size_expr) \ + case tensorflow::AttrValue::kK: \ + metadata.is_list = 0; \ + metadata.list_size = -1; \ + metadata.type = attr_type; \ + metadata.total_size = size_expr; \ + break; + + SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); + SINGLE_CASE(kI, TF_ATTR_INT, -1); + SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); + SINGLE_CASE(kB, TF_ATTR_BOOL, -1); + SINGLE_CASE(kType, TF_ATTR_TYPE, -1); + SINGLE_CASE(kShape, TF_ATTR_SHAPE, + attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); + SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); +#undef SINGLE_CASE + + case tensorflow::AttrValue::kList: + metadata.is_list = 1; + metadata.list_size = 0; + metadata.total_size = -1; +#define LIST_CASE(field, attr_type, ...) \ + if (attr->list().field##_size() > 0) { \ + metadata.type = attr_type; \ + metadata.list_size = attr->list().field##_size(); \ + __VA_ARGS__; \ + break; \ + } + + LIST_CASE( + s, TF_ATTR_STRING, metadata.total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { metadata.total_size += attr->list().s(i).size(); }); + LIST_CASE(i, TF_ATTR_INT); + LIST_CASE(f, TF_ATTR_FLOAT); + LIST_CASE(b, TF_ATTR_BOOL); + LIST_CASE(type, TF_ATTR_TYPE); + LIST_CASE( + shape, TF_ATTR_SHAPE, metadata.total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); + }); + LIST_CASE(tensor, TF_ATTR_TENSOR); + LIST_CASE(tensor, TF_ATTR_FUNC); +#undef LIST_CASE + // All lists empty, determine the type from the OpDef. 
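// Caller-side sketch of how this metadata is consumed (illustrative; `oper`,
// `status` and the attr name "foo" are placeholders):
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(oper, "foo", status);
//   if (m.is_list && m.type == TF_ATTR_STRING) {
//     std::vector<void*> values(m.list_size);
//     std::vector<size_t> lengths(m.list_size);
//     std::vector<char> storage(m.total_size);
//     TF_OperationGetAttrStringList(oper, "foo", values.data(), lengths.data(),
//                                   static_cast<int>(m.list_size),
//                                   storage.data(), storage.size(), status);
//   }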
+ if (metadata.list_size == 0) { + for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { + const auto& a = oper->node.op_def().attr(i); + if (a.name() != attr_name) continue; + const string& typestr = a.type(); + if (typestr == "list(string)") { + metadata.type = TF_ATTR_STRING; + } else if (typestr == "list(int)") { + metadata.type = TF_ATTR_INT; + } else if (typestr == "list(float)") { + metadata.type = TF_ATTR_FLOAT; + } else if (typestr == "list(bool)") { + metadata.type = TF_ATTR_BOOL; + } else if (typestr == "list(type)") { + metadata.type = TF_ATTR_TYPE; + } else if (typestr == "list(shape)") { + metadata.type = TF_ATTR_SHAPE; + } else if (typestr == "list(tensor)") { + metadata.type = TF_ATTR_TENSOR; + } else if (typestr == "list(func)") { + metadata.type = TF_ATTR_FUNC; + } else { + status->status = InvalidArgument( + "Attribute '", attr_name, + "' has an empty value of an unrecognized type '", typestr, "'"); + return metadata; + } + } + } + break; + + case tensorflow::AttrValue::kPlaceholder: + metadata.is_list = 0; + metadata.list_size = -1; + metadata.type = TF_ATTR_PLACEHOLDER; + metadata.total_size = -1; + break; + + case tensorflow::AttrValue::kFunc: + metadata.is_list = 0; + metadata.list_size = -1; + metadata.type = TF_ATTR_FUNC; + metadata.total_size = -1; + break; + + case tensorflow::AttrValue::VALUE_NOT_SET: + status->status = + InvalidArgument("Attribute '", attr_name, "' has no value set"); + break; + } + return metadata; +} + +void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, + void* value, size_t max_length, + TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + if (attr->value_case() != tensorflow::AttrValue::kS) { + status->status = + InvalidArgument("Attribute '", attr_name, "' is not a string"); + return; + } + if (max_length <= 0) { + return; + } + const auto& s = attr->s(); + std::memcpy(value, s.data(), std::min(s.length(), max_length)); +} + +void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, + void** values, size_t* lengths, + int max_values, void* storage, + size_t storage_size, TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + if (attr->value_case() != tensorflow::AttrValue::kList) { + status->status = + InvalidArgument("Value for '", attr_name, "' is not a list"); + return; + } + const auto len = std::min(max_values, attr->list().s_size()); + char* p = static_cast(storage); + for (int i = 0; i < len; ++i) { + const string& s = attr->list().s(i); + values[i] = p; + lengths[i] = s.size(); + if ((p + s.size()) > (static_cast(storage) + storage_size)) { + status->status = InvalidArgument( + "Not enough storage to hold the requested list of strings"); + return; + } + memcpy(values[i], s.data(), s.size()); + p += s.size(); + } +} + +#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ + void func(TF_Operation* oper, const char* attr_name, c_type* value, \ + TF_Status* status) { \ + cpp_type v; \ + status->status = \ + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \ + *value = static_cast(v); \ + } \ + void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ + int max_values, TF_Status* status) { \ + const auto* attr = GetAttrValue(oper, attr_name, status); \ + if (!status->status.ok()) return; \ + if (attr->value_case() != tensorflow::AttrValue::kList) { \ + status->status = \ + InvalidArgument("Value for '", attr_name, "' is not 
a list."); \ + return; \ + } \ + const auto len = std::min(max_values, attr->list().list_field##_size()); \ + for (int i = 0; i < len; ++i) { \ + values[i] = static_cast(attr->list().list_field(i)); \ + } \ + } +DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); +DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); +DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b); +DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); +#undef DEFINE_GETATTR + +void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, + int64_t* value, int num_dims, TF_Status* status) { + PartialTensorShape shape; + status->status = + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); + if (!status->status.ok()) return; + auto len = std::min(shape.dims(), num_dims); + for (int i = 0; i < len; ++i) { + value[i] = shape.dim_size(i); + } +} + +void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, + int64_t** dims, int* num_dims, int num_shapes, + int64_t* storage, int storage_size, + TF_Status* status) { + std::vector shapes; + status->status = + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); + if (!status->status.ok()) return; + auto len = std::min(static_cast(shapes.size()), num_shapes); + int64_t* p = storage; + int storage_left = storage_size; + for (int i = 0; i < len; ++i) { + // shapes[i].dims() == -1 for shapes with an unknown rank. + int64_t n = shapes[i].dims(); + num_dims[i] = n; + dims[i] = p; + if (n < 0) { + continue; + } + if (storage_left < n) { + status->status = InvalidArgument( + "Not enough storage to hold the requested list of shapes"); + return; + } + storage_left -= n; + for (int j = 0; j < n; ++j, ++p) { + *p = shapes[i].dim_size(j); + } + } +} + +void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, + const char* attr_name, + TF_Buffer* value, TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + if (attr->value_case() != tensorflow::AttrValue::kShape) { + status->status = + InvalidArgument("Value for '", attr_name, "' is not a shape."); + return; + } + status->status = MessageToBuffer(attr->shape(), value); +} + +void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, + const char* attr_name, + TF_Buffer** values, int max_values, + TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + if (attr->value_case() != tensorflow::AttrValue::kList) { + status->status = + InvalidArgument("Value for '", attr_name, "' is not a list"); + return; + } + const auto len = std::min(max_values, attr->list().shape_size()); + for (int i = 0; i < len; ++i) { + values[i] = TF_NewBuffer(); + status->status = MessageToBuffer(attr->list().shape(i), values[i]); + if (!status->status.ok()) { + // Delete everything allocated to far, the operation has failed. 
+ for (int j = 0; j <= i; ++j) { + TF_DeleteBuffer(values[j]); + } + return; + } + } +} + +void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, + TF_Tensor** value, TF_Status* status) { + *value = nullptr; + Tensor t; + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); + if (!status->status.ok()) return; + *value = TF_TensorFromTensor(t, &status->status); +} + +void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, + TF_Tensor** values, int max_values, + TF_Status* status) { + std::vector ts; + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); + if (!status->status.ok()) return; + const auto len = std::min(max_values, static_cast(ts.size())); + for (int i = 0; i < len; ++i) { + values[i] = TF_TensorFromTensor(ts[i], &status->status); + } +} + +void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, + TF_Buffer* output_attr_value, + TF_Status* status) { + const auto* attr = GetAttrValue(oper, attr_name, status); + if (!status->status.ok()) return; + status->status = MessageToBuffer(*attr, output_attr_value); +} + +void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, + TF_Status* status) { + status->status = MessageToBuffer(oper->node.def(), output_node_def); +} + +// TF_Graph functions --------------------------------------------------------- + +TF_Graph::TF_Graph() + : graph(tensorflow::OpRegistry::Global()), + refiner(graph.versions().producer(), graph.op_registry()), + delete_requested(false), + parent(nullptr), + parent_inputs(nullptr) { + // Tell the shape refiner to also run shape inference on functions. + refiner.set_function_library_for_shape_inference(&graph.flib_def()); +} + +TF_Graph* TF_NewGraph() { return new TF_Graph; } + +void TF_DeleteGraph(TF_Graph* g) { + if (g == nullptr) return; + g->mu.lock(); + g->delete_requested = true; + const bool del = g->sessions.empty(); + g->mu.unlock(); + if (del) delete g; +} + +TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) { + mutex_lock l(graph->mu); + auto iter = graph->name_map.find(oper_name); + if (iter == graph->name_map.end()) { + return nullptr; + } else { + return ToOperation(iter->second); + } +} + +TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) { + if (*pos == 0) { + // Advance past the first sentinel nodes in every graph (the source & sink). + *pos += 2; + } else { + // Advance to the next node. + *pos += 1; + } + + mutex_lock l(graph->mu); + while (*pos < static_cast(graph->graph.num_node_ids())) { + Node* node = graph->graph.FindNodeId(*pos); + // FindNodeId() returns nullptr for nodes that have been deleted. + // We aren't currently allowing nodes to be deleted, but it is safer + // to still check. + if (node != nullptr) return ToOperation(node); + *pos += 1; + } + + // No more nodes. 
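// Iteration sketch for this function (illustrative; `graph` is assumed to be a
// valid TF_Graph* and DoSomethingWithOperation a hypothetical helper):
//
//   size_t pos = 0;
//   TF_Operation* oper;
//   while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//     DoSomethingWithOperation(oper);
//   }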
+ return nullptr; +} + +void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, + TF_Status* status) { + GraphDef def; + { + mutex_lock l(graph->mu); + graph->graph.ToGraphDef(&def); + } + status->status = MessageToBuffer(def, output_graph_def); +} + +void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, + TF_Buffer* output_op_def, TF_Status* status) { + const OpDef* op_def; + { + mutex_lock l(graph->mu); + status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); + if (!status->status.ok()) return; + } + status->status = MessageToBuffer(*op_def, output_op_def); +} + +void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def, + TF_Status* status) { + VersionDef versions; + { + mutex_lock l(graph->mu); + versions = graph->graph.versions(); + } + status->status = MessageToBuffer(versions, output_version_def); +} + +TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { + return new TF_ImportGraphDefOptions; +} +void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) { + delete opts; +} +void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, + const char* prefix) { + opts->opts.prefix = prefix; +} +void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts, + const char* device) { + opts->opts.default_device = device; +} + +void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_names) { + opts->opts.uniquify_names = uniquify_names; +} + +void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_prefix) { + opts->opts.uniquify_prefix = uniquify_prefix; +} + +void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, + const char* src_name, + int src_index, TF_Output dst) { + opts->tensor_id_data.push_back(src_name); + const string& src_name_str = opts->tensor_id_data.back(); + // We don't need to store dst's name in tensor_id_data, since `dst` must + // outlive the ImportGraphDef call. 
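// End-to-end sketch of the import options defined in this block (illustrative
// only; `graph`, `graph_def_buf`, `status` and the names are placeholders):
//
//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
//   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
//   TF_ImportGraphDefOptionsAddReturnOutput(opts, "logits", 0);
//   TF_ImportGraphDefResults* results = TF_GraphImportGraphDefWithResults(
//       graph, graph_def_buf, opts, status);
//   int num_outputs;
//   TF_Output* outputs;
//   TF_ImportGraphDefResultsReturnOutputs(results, &num_outputs, &outputs);
//   TF_DeleteImportGraphDefResults(results);
//   TF_DeleteImportGraphDefOptions(opts);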
+ opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst); +} + +void TF_ImportGraphDefOptionsRemapControlDependency( + TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) { + opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] = + TensorId(dst->node.name(), tensorflow::Graph::kControlSlot); +} + +extern void TF_ImportGraphDefOptionsAddControlDependency( + TF_ImportGraphDefOptions* opts, TF_Operation* oper) { + opts->opts.control_dependencies.push_back(oper->node.name()); +} + +void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts, + const char* oper_name, int index) { + opts->tensor_id_data.push_back(oper_name); + const string& oper_name_str = opts->tensor_id_data.back(); + opts->opts.return_tensors.emplace_back(oper_name_str, index); +} + +int TF_ImportGraphDefOptionsNumReturnOutputs( + const TF_ImportGraphDefOptions* opts) { + return opts->opts.return_tensors.size(); +} + +void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts, + const char* oper_name) { + opts->opts.return_nodes.push_back(oper_name); +} + +int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts) { + return opts->opts.return_nodes.size(); +} + +void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results, + int* num_outputs, + TF_Output** outputs) { + *num_outputs = results->return_tensors.size(); + *outputs = results->return_tensors.data(); +} + +void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, + int* num_opers, + TF_Operation*** opers) { + *num_opers = results->return_nodes.size(); + *opers = results->return_nodes.data(); +} + +void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, + const char*** src_names, int** src_indexes) { + *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); + *src_names = results->missing_unused_key_names.data(); + *src_indexes = results->missing_unused_key_indexes.data(); +} + +void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { + delete results; +} + +static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, + const TF_ImportGraphDefOptions* opts, + TF_ImportGraphDefResults* tf_results, + TF_Status* status) + TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + const int last_node_id = graph->graph.num_node_ids(); + tensorflow::ImportGraphDefResults results; + status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, + &graph->refiner, &results); + if (!status->status.ok()) return; + + // Add new nodes to name_map + for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { + auto* node = graph->graph.FindNodeId(i); + if (node != nullptr) graph->name_map[node->name()] = node; + } + + // Populate return_tensors + DCHECK(tf_results->return_tensors.empty()); + tf_results->return_tensors.resize(results.return_tensors.size()); + for (int i = 0; i < results.return_tensors.size(); ++i) { + tf_results->return_tensors[i].oper = + ToOperation(results.return_tensors[i].first); + tf_results->return_tensors[i].index = results.return_tensors[i].second; + } + + // Populate return_nodes + DCHECK(tf_results->return_nodes.empty()); + tf_results->return_nodes.resize(results.return_nodes.size()); + for (int i = 0; i < results.return_nodes.size(); ++i) { + tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); + } + + // Populate missing unused map keys + 
DCHECK(tf_results->missing_unused_key_names.empty()); + DCHECK(tf_results->missing_unused_key_indexes.empty()); + DCHECK(tf_results->missing_unused_key_names_data.empty()); + + size_t size = results.missing_unused_input_map_keys.size(); + tf_results->missing_unused_key_names.resize(size); + tf_results->missing_unused_key_indexes.resize(size); + + for (int i = 0; i < size; ++i) { + TensorId id = results.missing_unused_input_map_keys[i]; + tf_results->missing_unused_key_names_data.emplace_back(id.first); + tf_results->missing_unused_key_names[i] = + tf_results->missing_unused_key_names_data.back().c_str(); + tf_results->missing_unused_key_indexes[i] = id.second; + } +} + +TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status) { + GraphDef def; + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { + status->status = InvalidArgument("Invalid GraphDef"); + return nullptr; + } + auto results = new TF_ImportGraphDefResults(); + mutex_lock l(graph->mu); + GraphImportGraphDefLocked(graph, def, options, results, status); + if (!status->status.ok()) { + delete results; + return nullptr; + } + return results; +} + +void TF_GraphImportGraphDefWithReturnOutputs( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, + int num_return_outputs, TF_Status* status) { + if (num_return_outputs != options->opts.return_tensors.size()) { + status->status = InvalidArgument("Expected 'num_return_outputs' to be ", + options->opts.return_tensors.size(), + ", got ", num_return_outputs); + return; + } + if (num_return_outputs > 0 && return_outputs == nullptr) { + status->status = InvalidArgument( + "'return_outputs' must be preallocated to length ", num_return_outputs); + return; + } + GraphDef def; + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { + status->status = InvalidArgument("Invalid GraphDef"); + return; + } + TF_ImportGraphDefResults results; + mutex_lock l(graph->mu); + GraphImportGraphDefLocked(graph, def, options, &results, status); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); + memcpy(return_outputs, results.return_tensors.data(), + num_return_outputs * sizeof(TF_Output)); +} + +void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, + TF_Status* status) { + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, options, status); + TF_DeleteImportGraphDefResults(results); +} + // While loop functions ------------------------------------------------------- namespace { @@ -480,4 +2161,404 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } +// TF_Session functions ---------------------------------------------- + +TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g) + : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {} + +TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, + TF_Status* status) { + Session* session; + status->status = NewSession(opt->options, &session); + if (status->status.ok()) { + TF_Session* new_session = new TF_Session(session, graph); + if (graph != nullptr) { + mutex_lock l(graph->mu); + graph->sessions[new_session] = ""; + } + return new_session; + } else { + DCHECK_EQ(nullptr, session); 
+ return nullptr; + } +} + +TF_Session* TF_LoadSessionFromSavedModel( + const TF_SessionOptions* session_options, const TF_Buffer* run_options, + const char* export_dir, const char* const* tags, int tags_len, + TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) { +// TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring +// that the tensorflow/cc/saved_model:loader build target is mobile friendly. +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Loading a SavedModel is not supported on mobile. File a bug at " + "https://github.com/tensorflow/tensorflow/issues if this feature is " + "important to you"); + return nullptr; +#else + mutex_lock l(graph->mu); + if (!graph->name_map.empty()) { + status->status = InvalidArgument("Graph is non-empty."); + return nullptr; + } + + RunOptions run_options_proto; + if (run_options != nullptr && !run_options_proto.ParseFromArray( + run_options->data, run_options->length)) { + status->status = InvalidArgument("Unparseable RunOptions proto"); + return nullptr; + } + + std::unordered_set tag_set; + for (int i = 0; i < tags_len; i++) { + tag_set.insert(string(tags[i])); + } + + tensorflow::SavedModelBundle bundle; + status->status = + tensorflow::LoadSavedModel(session_options->options, run_options_proto, + export_dir, tag_set, &bundle); + if (!status->status.ok()) return nullptr; + + // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session + // extends using GraphDefs. The Graph instance is different, but equivalent + // to the one used to create the session. + // + // TODO(jhseu): When Session is modified to take Graphs instead of + // GraphDefs, return the Graph generated in LoadSavedModel(). + TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefResults results; + GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), + import_opts, &results, status); + TF_DeleteImportGraphDefOptions(import_opts); + if (!status->status.ok()) return nullptr; + + if (meta_graph_def != nullptr) { + status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def); + if (!status->status.ok()) return nullptr; + } + + TF_Session* session = new TF_Session(bundle.session.release(), graph); + + graph->sessions[session] = ""; + session->last_num_graph_nodes = graph->graph.num_node_ids(); + return session; +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +void TF_CloseSession(TF_Session* s, TF_Status* status) { + status->status = s->session->Close(); +} + +void TF_DeleteSession(TF_Session* s, TF_Status* status) { + status->status = Status::OK(); + if (s == nullptr) return; + TF_Graph* const graph = s->graph; + if (graph != nullptr) { + graph->mu.lock(); + graph->sessions.erase(s); + const bool del = graph->delete_requested && graph->sessions.empty(); + graph->mu.unlock(); + if (del) delete graph; + } + delete s->session; + delete s; +} + +void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, + const TF_Output* inputs, TF_Tensor* const* input_values, + int ninputs, const TF_Output* outputs, + TF_Tensor** output_values, int noutputs, + const TF_Operation* const* target_opers, int ntargets, + TF_Buffer* run_metadata, TF_Status* status) { + // TODO(josh11b,mrry): Change Session to be able to use a Graph* + // directly, instead of requiring us to serialize to a GraphDef and + // call Session::Extend(). 
+ if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; + } + + TF_Run_Setup(noutputs, output_values, status); + + // Convert from TF_Output and TF_Tensor to a string and Tensor. + std::vector> input_pairs(ninputs); + if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; + for (int i = 0; i < ninputs; ++i) { + input_pairs[i].first = OutputName(inputs[i]); + } + + // Convert from TF_Output to string names. + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = OutputName(outputs[i]); + } + + // Convert from TF_Operation* to string names. + std::vector target_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_names[i] = target_opers[i]->node.name(); + } + + // Actually run. + TF_Run_Helper(session->session, nullptr, run_options, input_pairs, + output_names, output_values, target_names, run_metadata, + status); +} + +void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, + int ninputs, const TF_Output* outputs, int noutputs, + const TF_Operation* const* target_opers, int ntargets, + const char** handle, TF_Status* status) { + *handle = nullptr; + + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; + } + + std::vector input_names(ninputs); + for (int i = 0; i < ninputs; ++i) { + input_names[i] = OutputName(inputs[i]); + } + + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = OutputName(outputs[i]); + } + + std::vector target_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_names[i] = target_opers[i]->node.name(); + } + + string new_handle; + status->status = session->session->PRunSetup(input_names, output_names, + target_names, &new_handle); + if (status->status.ok()) { + char* buf = new char[new_handle.size() + 1]; + memcpy(buf, new_handle.c_str(), new_handle.size() + 1); + *handle = buf; + } +} + +void TF_DeletePRunHandle(const char* handle) { + delete[] handle; + // TODO(suharshs): Free up any resources held by the partial run state. +} + +void TF_SessionPRun(TF_Session* session, const char* handle, + const TF_Output* inputs, TF_Tensor* const* input_values, + int ninputs, const TF_Output* outputs, + TF_Tensor** output_values, int noutputs, + const TF_Operation* const* target_opers, int ntargets, + TF_Status* status) { + // TODO(josh11b,mrry): Change Session to be able to use a Graph* + // directly, instead of requiring us to serialize to a GraphDef and + // call Session::Extend(). + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; + } + + TF_Run_Setup(noutputs, output_values, status); + + // Convert from TF_Output and TF_Tensor to a string and Tensor. + std::vector> input_pairs(ninputs); + if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; + for (int i = 0; i < ninputs; ++i) { + input_pairs[i].first = OutputName(inputs[i]); + } + + // Convert from TF_Output to string names. + std::vector output_names(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names[i] = OutputName(outputs[i]); + } + + // Convert from TF_Operation* to string names. 
+ std::vector target_names(ntargets); + for (int i = 0; i < ntargets; ++i) { + target_names[i] = target_opers[i]->node.name(); + } + + TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names, + output_values, target_names, nullptr, status); +} + +unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, + TF_Tensor** result, TF_Status* status) { + *result = nullptr; + mutex_lock l(graph->mu); + OutputTensor tensor(&output.oper->node, output.index); + bool evaluated; + Tensor result_tensor; + status->status = EvaluateConstantTensor( + tensor, graph->refiner, *graph->graph.op_registry(), + graph->graph.versions().producer(), &evaluated, &result_tensor); + if (evaluated) { + DCHECK(status->status.ok()); + *result = TF_TensorFromTensor(result_tensor, &status->status); + if (!status->status.ok()) evaluated = false; + } + return evaluated; +} + +TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { + tensorflow::OpList op_list; + if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { + status->status = InvalidArgument("Unparseable OpList"); + return nullptr; + } + status->status = Status::OK(); + return new TF_ApiDefMap(op_list); +} + +void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } + +void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, + size_t text_len, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported on mobile."); +#else + mutex_lock l(api_def_map->lock); + if (api_def_map->update_docs_called) { + status->status = FailedPrecondition( + "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " + "called."); + return; + } + string api_def_text(text, text_len); + status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, + size_t name_len, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported on mobile."); + return nullptr; +#else + mutex_lock l(api_def_map->lock); + if (!api_def_map->update_docs_called) { + api_def_map->api_def_map.UpdateDocs(); + api_def_map->update_docs_called = true; + } + string name_str(name, name_len); + const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); + if (api_def == nullptr) { + return nullptr; + } + + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(*api_def, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { + tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(kernel_list, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +} + +TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { + tensorflow::KernelList kernel_list = + tensorflow::GetRegisteredKernelsForOp(name); + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(kernel_list, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +} + +// TF_Server functions 
---------------------------------------------- + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +TF_Server::TF_Server(std::unique_ptr server) + : target(server->target()), server(std::move(server)) {} +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + +TF_Server* TF_NewServer(const void* proto, size_t proto_len, + TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported on mobile"); + return nullptr; +#else + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, static_cast(proto_len))) { + status->status = InvalidArgument( + "Could not parse provided bytes into a ServerDef protocol buffer"); + return nullptr; + } + + std::unique_ptr out_server; + status->status = tensorflow::NewServer(server_def, &out_server); + if (!status->status.ok()) return nullptr; + + return new TF_Server(std::move(out_server)); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +void TF_ServerStart(TF_Server* server, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported on mobile"); +#else + status->status = server->server->Start(); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +void TF_ServerStop(TF_Server* server, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported on mobile"); +#else + status->status = server->server->Stop(); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +void TF_ServerJoin(TF_Server* server, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported on mobile"); +#else + status->status = server->server->Join(); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + +const char* TF_ServerTarget(TF_Server* server) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + return nullptr; +#else + return server->target.c_str(); +#endif +} + +void TF_DeleteServer(TF_Server* server) { +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + delete server; +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +} + +void TF_RegisterLogListener(void (*listener)(const char*)) { +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + tensorflow::logging::RegisterListener(listener); +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +} + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index f9942239eec..0c413f6ebae 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -19,21 +19,884 @@ limitations under the License. #include #include -#include "tensorflow/c/c_core_api.h" #include "tensorflow/c/tf_attrtype.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" // -------------------------------------------------------------------------- -// Non-core C API for TensorFlow. +// C API for TensorFlow. // -// This file contains the non-core C API for TensorFlow. Most of the -// API documentation and functionality resides in c_core_api.h. +// The API leans towards simplicity and uniformity instead of convenience +// since most usage will be by language specific wrappers. 
+// +// Conventions: +// * We use the prefix TF_ for everything in the API. +// * Objects are always passed around as pointers to opaque structs +// and these structs are allocated/deallocated via the API. +// * TF_Status holds error information. It is an object type +// and therefore is passed around as a pointer to an opaque +// struct as mentioned above. +// * Every call that has a TF_Status* argument clears it on success +// and fills it with error info on failure. +// * unsigned char is used for booleans (instead of the 'bool' type). +// In C++ bool is a keyword while in C99 bool is a macro defined +// in stdbool.h. It is possible for the two to be inconsistent. +// For example, neither the C99 nor the C++11 standard force a byte +// size on the bool type, so the macro defined in stdbool.h could +// be inconsistent with the bool keyword in C++. Thus, the use +// of stdbool.h is avoided and unsigned char is used instead. +// * size_t is used to represent byte sizes of objects that are +// materialized in the address space of the calling process. +// * int is used as an index into arrays. +// * Deletion functions are safe to call on nullptr. +// +// Questions left to address: +// * Might at some point need a way for callers to provide their own Env. +// * Maybe add TF_TensorShape that encapsulates dimension info. +// +// Design decisions made: +// * Backing store for tensor memory has an associated deallocation +// function. This deallocation function will point to client code +// for tensors populated by the client. So the client can do things +// like shadowing a numpy array. +// * We do not provide TF_OK since it is not strictly necessary and we +// are not optimizing for convenience. +// * We make assumption that one session has one graph. This should be +// fine since we have the ability to run sub-graphs. +// * We could allow NULL for some arguments (e.g., NULL options arg). +// However since convenience is not a primary goal, we don't do this. +// * Devices are not in this API. Instead, they are created/used internally +// and the API just provides high level controls over the number of +// devices of each type. + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes. +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + #ifdef __cplusplus extern "C" { #endif +// -------------------------------------------------------------------------- +// TF_Version returns a string describing version information of the +// TensorFlow library. TensorFlow using semantic versioning. +TF_CAPI_EXPORT extern const char* TF_Version(void); + +// -------------------------------------------------------------------------- +// TF_Buffer holds a pointer to a block of data and its associated length. +// Typically, the data consists of a serialized protocol buffer, but other data +// may also be held in a buffer. +// +// By default, TF_Buffer itself does not do any memory management of the +// pointed-to block. 
If need be, users of this struct should specify how to +// deallocate the block by setting the `data_deallocator` function pointer. +typedef struct TF_Buffer { + const void* data; + size_t length; + void (*data_deallocator)(void* data, size_t length); +} TF_Buffer; + +// Makes a copy of the input and sets an appropriate deallocator. Useful for +// passing in read-only, input protobufs. +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, + size_t proto_len); + +// Useful for passing *out* a protobuf. +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); + +TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); + +TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); + +// -------------------------------------------------------------------------- +// TF_SessionOptions holds options that can be passed during session creation. +typedef struct TF_SessionOptions TF_SessionOptions; + +// Return a new options object. +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); + +// Set the target in TF_SessionOptions.options. +// target can be empty, a single entry, or a comma separated list of entries. +// Each entry is in one of the following formats : +// "local" +// ip:port +// host:port +TF_CAPI_EXPORT extern void TF_SetTarget(TF_SessionOptions* options, + const char* target); + +// Set the config in TF_SessionOptions.options. +// config should be a serialized tensorflow.ConfigProto proto. +// If config was not parsed successfully as a ConfigProto, record the +// error information in *status. +TF_CAPI_EXPORT extern void TF_SetConfig(TF_SessionOptions* options, + const void* proto, size_t proto_len, + TF_Status* status); + +// Destroy an options object. +TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); + +// TODO(jeff,sanjay): +// - export functions to set Config fields + +// -------------------------------------------------------------------------- +// The new graph construction API, still under development. + +// Represents a computation graph. Graphs may be shared between sessions. +// Graphs are thread-safe when used as directed below. +typedef struct TF_Graph TF_Graph; + +// Return a new graph object. +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); + +// Destroy an options object. Graph will be deleted once no more +// TFSession's are referencing it. +TF_CAPI_EXPORT extern void TF_DeleteGraph(TF_Graph*); + +// Operation being built. The underlying graph must outlive this. +typedef struct TF_OperationDescription TF_OperationDescription; + +// Operation that has been added to the graph. Valid until the graph is +// deleted -- in particular adding a new operation to the graph does not +// invalidate old TF_Operation* pointers. +typedef struct TF_Operation TF_Operation; + +// Represents a specific input of an operation. +typedef struct TF_Input { + TF_Operation* oper; + int index; // The index of the input within oper. +} TF_Input; + +// Represents a specific output of an operation. +typedef struct TF_Output { + TF_Operation* oper; + int index; // The index of the output within oper. +} TF_Output; + +// TF_Function is a grouping of operations with defined inputs and outputs. +// Once created and added to graphs, functions can be invoked by creating an +// operation whose operation type matches the function name. +typedef struct TF_Function TF_Function; + +// Function definition options. 
TODO(iga): Define and implement +typedef struct TF_FunctionOptions TF_FunctionOptions; + +// Sets the shape of the Tensor referenced by `output` in `graph` to +// the shape described by `dims` and `num_dims`. +// +// If the number of dimensions is unknown, `num_dims` must be set to +// -1 and `dims` can be null. If a dimension is unknown, the +// corresponding entry in the `dims` array must be -1. +// +// This does not overwrite the existing shape associated with `output`, +// but merges the input shape with the existing shape. For example, +// setting a shape of [-1, 2] with an existing shape [2, -1] would set +// a final shape of [2, 2] based on shape merging semantics. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +// * An invalid shape is being set (e.g., the shape being set +// is incompatible with the existing shape). +TF_CAPI_EXPORT extern void TF_GraphSetTensorShape(TF_Graph* graph, + TF_Output output, + const int64_t* dims, + const int num_dims, + TF_Status* status); + +// Returns the number of dimensions of the Tensor referenced by `output` +// in `graph`. +// +// If the number of dimensions in the shape is unknown, returns -1. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +TF_CAPI_EXPORT extern int TF_GraphGetTensorNumDims(TF_Graph* graph, + TF_Output output, + TF_Status* status); + +// Returns the shape of the Tensor referenced by `output` in `graph` +// into `dims`. `dims` must be an array large enough to hold `num_dims` +// entries (e.g., the return value of TF_GraphGetTensorNumDims). +// +// If the number of dimensions in the shape is unknown or the shape is +// a scalar, `dims` will remain untouched. Otherwise, each element of +// `dims` will be set corresponding to the size of the dimension. An +// unknown dimension is represented by `-1`. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +// * `num_dims` does not match the actual number of dimensions. +TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, + TF_Output output, + int64_t* dims, int num_dims, + TF_Status* status); + +// Operation will only be added to *graph when TF_FinishOperation() is +// called (assuming TF_FinishOperation() does not return an error). +// *graph must not be deleted until after TF_FinishOperation() is +// called. +TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperation( + TF_Graph* graph, const char* op_type, const char* oper_name); + +// Specify the device for `desc`. Defaults to empty, meaning unconstrained. +TF_CAPI_EXPORT extern void TF_SetDevice(TF_OperationDescription* desc, + const char* device); + +// The calls to TF_AddInput and TF_AddInputList must match (in number, +// order, and type) the op declaration. For example, the "Concat" op +// has registration: +// REGISTER_OP("Concat") +// .Input("concat_dim: int32") +// .Input("values: N * T") +// .Output("output: T") +// .Attr("N: int >= 2") +// .Attr("T: type"); +// that defines two inputs, "concat_dim" and "values" (in that order). 
+// You must use TF_AddInput() for the first input (since it takes a +// single tensor), and TF_AddInputList() for the second input (since +// it takes a list, even if you were to pass a list with a single +// tensor), as in: +// TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "c"); +// TF_Output concat_dim_input = {...}; +// TF_AddInput(desc, concat_dim_input); +// TF_Output values_inputs[5] = {{...}, ..., {...}}; +// TF_AddInputList(desc, values_inputs, 5); + +// For inputs that take a single tensor. +TF_CAPI_EXPORT extern void TF_AddInput(TF_OperationDescription* desc, + TF_Output input); + +// For inputs that take a list of tensors. +// inputs must point to TF_Output[num_inputs]. +TF_CAPI_EXPORT extern void TF_AddInputList(TF_OperationDescription* desc, + const TF_Output* inputs, + int num_inputs); + +// Call once per control input to `desc`. +TF_CAPI_EXPORT extern void TF_AddControlInput(TF_OperationDescription* desc, + TF_Operation* input); + +// Request that `desc` be co-located on the device where `op` +// is placed. +// +// Use of this is discouraged since the implementation of device placement is +// subject to change. Primarily intended for internal libraries +TF_CAPI_EXPORT extern void TF_ColocateWith(TF_OperationDescription* desc, + TF_Operation* op); + +// Call some TF_SetAttr*() function for every attr that is not +// inferred from an input and doesn't have a default value you wish to +// keep. + +// `value` must point to a string of length `length` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrString(TF_OperationDescription* desc, + const char* attr_name, + const void* value, size_t length); +// `values` and `lengths` each must have lengths `num_values`. +// `values[i]` must point to a string of length `lengths[i]` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrStringList(TF_OperationDescription* desc, + const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrInt(TF_OperationDescription* desc, + const char* attr_name, int64_t value); +TF_CAPI_EXPORT extern void TF_SetAttrIntList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrFloat(TF_OperationDescription* desc, + const char* attr_name, float value); +TF_CAPI_EXPORT extern void TF_SetAttrFloatList(TF_OperationDescription* desc, + const char* attr_name, + const float* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrBool(TF_OperationDescription* desc, + const char* attr_name, + unsigned char value); +TF_CAPI_EXPORT extern void TF_SetAttrBoolList(TF_OperationDescription* desc, + const char* attr_name, + const unsigned char* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrType(TF_OperationDescription* desc, + const char* attr_name, + TF_DataType value); +TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, + const char* attr_name, + const TF_DataType* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc, + const char* attr_name, + const char* placeholder); + +// Set a 'func' attribute to the specified name. +// `value` must point to a string of length `length` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, + const char* attr_name, + const char* value, size_t length); + +// Set `num_dims` to -1 to represent "unknown rank". Otherwise, +// `dims` points to an array of length `num_dims`. 
`dims[i]` must be +// >= -1, with -1 meaning "unknown dimension". +TF_CAPI_EXPORT extern void TF_SetAttrShape(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* dims, int num_dims); +// `dims` and `num_dims` must point to arrays of length `num_shapes`. +// Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise, +// `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]` +// must be >= -1, with -1 meaning "unknown dimension". +TF_CAPI_EXPORT extern void TF_SetAttrShapeList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* const* dims, + const int* num_dims, + int num_shapes); +// `proto` must point to an array of `proto_len` bytes representing a +// binary-serialized TensorShapeProto. +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProto( + TF_OperationDescription* desc, const char* attr_name, const void* proto, + size_t proto_len, TF_Status* status); +// `protos` and `proto_lens` must point to arrays of length `num_shapes`. +// `protos[i]` must point to an array of `proto_lens[i]` bytes +// representing a binary-serialized TensorShapeProto. +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProtoList( + TF_OperationDescription* desc, const char* attr_name, + const void* const* protos, const size_t* proto_lens, int num_shapes, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_SetAttrTensor(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* value, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorList(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* const* values, + int num_values, + TF_Status* status); + +// `proto` should point to a sequence of bytes of length `proto_len` +// representing a binary serialization of an AttrValue protocol +// buffer. +TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// If this function succeeds: +// * *status is set to an OK value, +// * a TF_Operation is added to the graph, +// * a non-null value pointing to the added operation is returned -- +// this value is valid until the underlying graph is deleted. +// Otherwise: +// * *status is set to a non-OK value, +// * the graph is not modified, +// * a null value is returned. +// In either case, it deletes `desc`. +TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperation( + TF_OperationDescription* desc, TF_Status* status); + +// TF_Operation functions. Operations are immutable once created, so +// these are all query functions. 
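
As a hedged aside, not part of the header itself: the declarations above form a builder pattern, and a minimal sketch of one complete construction may help before the query functions that follow. Every name below other than the API calls (the helper MakeScalarConst, its parameters, and its locals) is illustrative only; error checking on `status` is omitted, and the tensor helpers come from tf_tensor.h, which this header already pulls in.

#include <string.h>
#include "tensorflow/c/c_api.h"

/* Illustrative helper (not part of the API): add a scalar float "Const"
   operation named `name` to `graph`. */
static TF_Operation* MakeScalarConst(TF_Graph* graph, const char* name,
                                     float value, TF_Status* status) {
  /* A scalar tensor has zero dimensions and sizeof(float) bytes of data. */
  TF_Tensor* t =
      TF_AllocateTensor(TF_FLOAT, /*dims=*/NULL, /*num_dims=*/0, sizeof(float));
  memcpy(TF_TensorData(t), &value, sizeof(float));

  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
  TF_SetAttrTensor(desc, "value", t, status);
  TF_SetAttrType(desc, "dtype", TF_FLOAT);

  /* Consumes `desc`; returns NULL and leaves the error in `status` on failure. */
  TF_Operation* op = TF_FinishOperation(desc, status);

  /* The description does not take ownership of `t`, so release it here. */
  TF_DeleteTensor(t);
  return op;
}

Once TF_FinishOperation succeeds, the returned TF_Operation* is immutable and can only be inspected, using the query functions declared next.
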
+ +TF_CAPI_EXPORT extern const char* TF_OperationName(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationOpType(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationDevice(TF_Operation* oper); + +TF_CAPI_EXPORT extern int TF_OperationNumOutputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationOutputType(TF_Output oper_out); +TF_CAPI_EXPORT extern int TF_OperationOutputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); + +TF_CAPI_EXPORT extern int TF_OperationNumInputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationInputType(TF_Input oper_in); +TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); + +// In this code: +// TF_Output producer = TF_OperationInput(consumer); +// There is an edge from producer.oper's output (given by +// producer.index) to consumer.oper's input (given by consumer.index). +TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in); + +// Get list of all inputs of a specific operation. `inputs` must point to +// an array of length at least `max_inputs` (ideally set to +// TF_OperationNumInputs(oper)). Beware that a concurrent +// modification of the graph can increase the number of inputs of +// an operation. +TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper, + TF_Output* inputs, + int max_inputs); + +// Get the number of current consumers of a specific output of an +// operation. Note that this number can change when new operations +// are added to the graph. +TF_CAPI_EXPORT extern int TF_OperationOutputNumConsumers(TF_Output oper_out); + +// Get list of all current consumers of a specific output of an +// operation. `consumers` must point to an array of length at least +// `max_consumers` (ideally set to +// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent +// modification of the graph can increase the number of consumers of +// an operation. Returns the number of output consumers (should match +// TF_OperationOutputNumConsumers(oper_out)). +TF_CAPI_EXPORT extern int TF_OperationOutputConsumers(TF_Output oper_out, + TF_Input* consumers, + int max_consumers); + +// Get the number of control inputs to an operation. +TF_CAPI_EXPORT extern int TF_OperationNumControlInputs(TF_Operation* oper); + +// Get list of all control inputs to an operation. `control_inputs` must +// point to an array of length `max_control_inputs` (ideally set to +// TF_OperationNumControlInputs(oper)). Returns the number of control +// inputs (should match TF_OperationNumControlInputs(oper)). +TF_CAPI_EXPORT extern int TF_OperationGetControlInputs( + TF_Operation* oper, TF_Operation** control_inputs, int max_control_inputs); + +// Get the number of operations that have `*oper` as a control input. +// Note that this number can change when new operations are added to +// the graph. +TF_CAPI_EXPORT extern int TF_OperationNumControlOutputs(TF_Operation* oper); + +// Get the list of operations that have `*oper` as a control input. +// `control_outputs` must point to an array of length at least +// `max_control_outputs` (ideally set to +// TF_OperationNumControlOutputs(oper)). Beware that a concurrent +// modification of the graph can increase the number of control +// outputs. Returns the number of control outputs (should match +// TF_OperationNumControlOutputs(oper)). 
+TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( + TF_Operation* oper, TF_Operation** control_outputs, + int max_control_outputs); + +// TF_AttrMetadata describes the value of an attribute on an operation. +typedef struct TF_AttrMetadata { + // A boolean: 1 if the attribute value is a list, 0 otherwise. + unsigned char is_list; + + // Length of the list if is_list is true. Undefined otherwise. + int64_t list_size; + + // Type of elements of the list if is_list != 0. + // Type of the single value stored in the attribute if is_list == 0. + TF_AttrType type; + + // Total size the attribute value. + // The units of total_size depend on is_list and type. + // (1) If type == TF_ATTR_STRING and is_list == 0 + // then total_size is the byte size of the string + // valued attribute. + // (2) If type == TF_ATTR_STRING and is_list == 1 + // then total_size is the cumulative byte size + // of all the strings in the list. + // (3) If type == TF_ATTR_SHAPE and is_list == 0 + // then total_size is the number of dimensions + // of the shape valued attribute, or -1 + // if its rank is unknown. + // (4) If type == TF_ATTR_SHAPE and is_list == 1 + // then total_size is the cumulative number + // of dimensions of all shapes in the list. + // (5) Otherwise, total_size is undefined. + int64_t total_size; +} TF_AttrMetadata; + +// Returns metadata about the value of the attribute `attr_name` of `oper`. +TF_CAPI_EXPORT extern TF_AttrMetadata TF_OperationGetAttrMetadata( + TF_Operation* oper, const char* attr_name, TF_Status* status); + +// Fills in `value` with the value of the attribute `attr_name`. `value` must +// point to an array of length at least `max_length` (ideally set to +// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrString(TF_Operation* oper, + const char* attr_name, + void* value, + size_t max_length, + TF_Status* status); + +// Get the list of strings in the value of the attribute `attr_name`. Fills in +// `values` and `lengths`, each of which must point to an array of length at +// least `max_values`. +// +// The elements of values will point to addresses in `storage` which must be at +// least `storage_size` bytes in length. Ideally, max_values would be set to +// TF_AttrMetadata.list_size and `storage` would be at least +// TF_AttrMetadata.total_size, obtained from TF_OperationGetAttrMetadata(oper, +// attr_name). +// +// Fails if storage_size is too small to hold the requested number of strings. +TF_CAPI_EXPORT extern void TF_OperationGetAttrStringList( + TF_Operation* oper, const char* attr_name, void** values, size_t* lengths, + int max_values, void* storage, size_t storage_size, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrInt(TF_Operation* oper, + const char* attr_name, + int64_t* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrIntList(TF_Operation* oper, + const char* attr_name, + int64_t* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloat(TF_Operation* oper, + const char* attr_name, + float* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. 
+// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloatList(TF_Operation* oper, + const char* attr_name, + float* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrBool(TF_Operation* oper, + const char* attr_name, + unsigned char* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrBoolList(TF_Operation* oper, + const char* attr_name, + unsigned char* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrType(TF_Operation* oper, + const char* attr_name, + TF_DataType* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTypeList(TF_Operation* oper, + const char* attr_name, + TF_DataType* values, + int max_values, + TF_Status* status); + +// Fills in `value` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `num_dims` (ideally set to +// TF_Attr_Meta.size from TF_OperationGetAttrMetadata(oper, attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrShape(TF_Operation* oper, + const char* attr_name, + int64_t* value, + int num_dims, + TF_Status* status); + +// Fills in `dims` with the list of shapes in the attribute `attr_name` of +// `oper` and `num_dims` with the corresponding number of dimensions. On return, +// for every i where `num_dims[i]` > 0, `dims[i]` will be an array of +// `num_dims[i]` elements. A value of -1 for `num_dims[i]` indicates that the +// i-th shape in the list is unknown. +// +// The elements of `dims` will point to addresses in `storage` which must be +// large enough to hold at least `storage_size` int64_ts. Ideally, `num_shapes` +// would be set to TF_AttrMetadata.list_size and `storage_size` would be set to +// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, +// attr_name). +// +// Fails if storage_size is insufficient to hold the requested shapes. +TF_CAPI_EXPORT extern void TF_OperationGetAttrShapeList( + TF_Operation* oper, const char* attr_name, int64_t** dims, int* num_dims, + int num_shapes, int64_t* storage, int storage_size, TF_Status* status); + +// Sets `value` to the binary-serialized TensorShapeProto of the value of +// `attr_name` attribute of `oper`'. +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* value, + TF_Status* status); + +// Fills in `values` with binary-serialized TensorShapeProto values of the +// attribute `attr_name` of `oper`. `values` must point to an array of length at +// least `num_values` (ideally set to TF_AttrMetadata.list_size from +// TF_OperationGetAttrMetadata(oper, attr_name)). 
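
Since every list-valued getter here follows the same metadata-then-fetch pattern, a hedged sketch of that pattern for a list(string) attribute may help; it is not part of the header, and the helper name PrintStringListAttr, like its locals, is illustrative only.

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "tensorflow/c/c_api.h"

/* Illustrative only: fetch and print a list(string) attribute in two steps. */
static void PrintStringListAttr(TF_Operation* oper, const char* attr_name,
                                TF_Status* status) {
  TF_AttrMetadata meta = TF_OperationGetAttrMetadata(oper, attr_name, status);
  if (TF_GetCode(status) != TF_OK || !meta.is_list ||
      meta.type != TF_ATTR_STRING) {
    return;
  }
  /* Size the output arrays from the metadata, then fetch. */
  void** values = (void**)malloc(meta.list_size * sizeof(void*));
  size_t* lengths = (size_t*)malloc(meta.list_size * sizeof(size_t));
  void* storage = malloc((size_t)meta.total_size);
  TF_OperationGetAttrStringList(oper, attr_name, values, lengths,
                                (int)meta.list_size, storage,
                                (size_t)meta.total_size, status);
  if (TF_GetCode(status) == TF_OK) {
    for (int64_t i = 0; i < meta.list_size; ++i) {
      /* values[i] points into `storage`; the strings are not NUL-terminated. */
      printf("%s[%lld]: %.*s\n", attr_name, (long long)i, (int)lengths[i],
             (const char*)values[i]);
    }
  }
  free(values);
  free(lengths);
  free(storage);
}

The same shape applies to the other list getters: TF_AttrMetadata.list_size sizes the per-element arrays and TF_AttrMetadata.total_size sizes the shared storage.
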
+TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProtoList( + TF_Operation* oper, const char* attr_name, TF_Buffer** values, + int max_values, TF_Status* status); + +// Gets the TF_Tensor valued attribute of `attr_name` of `oper`. +// +// Allocates a new TF_Tensor which the caller is expected to take +// ownership of (and can deallocate using TF_DeleteTensor). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensor(TF_Operation* oper, + const char* attr_name, + TF_Tensor** value, + TF_Status* status); + +// Fills in `values` with the TF_Tensor values of the attribute `attr_name` of +// `oper`. `values` must point to an array of TF_Tensor* of length at least +// `max_values` (ideally set to TF_AttrMetadata.list_size from +// TF_OperationGetAttrMetadata(oper, attr_name)). +// +// The caller takes ownership of all the non-null TF_Tensor* entries in `values` +// (which can be deleted using TF_DeleteTensor(values[i])). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorList(TF_Operation* oper, + const char* attr_name, + TF_Tensor** values, + int max_values, + TF_Status* status); + +// Sets `output_attr_value` to the binary-serialized AttrValue proto +// representation of the value of the `attr_name` attr of `oper`. +TF_CAPI_EXPORT extern void TF_OperationGetAttrValueProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); + +// Returns the operation in the graph with `oper_name`. Returns nullptr if +// no operation found. +TF_CAPI_EXPORT extern TF_Operation* TF_GraphOperationByName( + TF_Graph* graph, const char* oper_name); + +// Iterate through the operations of a graph. To use: +// size_t pos = 0; +// TF_Operation* oper; +// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { +// DoSomethingWithOperation(oper); +// } +TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, + size_t* pos); + +// Write out a serialized representation of `graph` (as a GraphDef protocol +// message) to `output_graph_def` (allocated by TF_NewBuffer()). +// `output_graph_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. +// +// May fail on very large graphs in the future. +TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, + TF_Buffer* output_graph_def, + TF_Status* status); + +// Returns the serialized OpDef proto with name `op_name`, or a bad status if no +// such op exists. This can return OpDefs of functions copied into the graph. +TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph, + const char* op_name, + TF_Buffer* output_op_def, + TF_Status* status); + +// Returns the serialized VersionDef proto for this graph. +TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, + TF_Buffer* output_version_def, + TF_Status* status); + +// TF_ImportGraphDefOptions holds options that can be passed to +// TF_GraphImportGraphDef. +typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; + +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( + TF_ImportGraphDefOptions* opts); + +// Set the prefix to be prepended to the names of nodes in `graph_def` that will +// be imported into `graph`. `prefix` is copied and has no lifetime +// requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( + TF_ImportGraphDefOptions* opts, const char* prefix); + +// Set the execution device for nodes in `graph_def`. 
+// Only applies to nodes where a device was not already explicitly specified. +// `device` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetDefaultDevice( + TF_ImportGraphDefOptions* opts, const char* device); + +// Set whether to uniquify imported operation names. If true, imported operation +// names will be modified if their name already exists in the graph. If false, +// conflicting names will be treated as an error. Note that this option has no +// effect if a prefix is set, since the prefix will guarantee all names are +// unique. Defaults to false. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); + +// If true, the specified prefix will be modified if it already exists as an +// operation name or prefix in the graph. If false, a conflicting prefix will be +// treated as an error. This option has no effect if no prefix is specified. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); + +// Set any imported nodes with input `src_name:src_index` to have that input +// replaced with `dst`. `src_name` refers to a node in the graph to be imported, +// `dst` references a node already existing in the graph being imported into. +// `src_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( + TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, + TF_Output dst); + +// Set any imported nodes with control input `src_name` to have that input +// replaced with `dst`. `src_name` refers to a node in the graph to be imported, +// `dst` references an operation already existing in the graph being imported +// into. `src_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( + TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); + +// Cause the imported graph to have a control dependency on `oper`. `oper` +// should exist in the graph being imported into. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( + TF_ImportGraphDefOptions* opts, TF_Operation* oper); + +// Add an output in `graph_def` to be returned via the `return_outputs` output +// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input +// mapping, the corresponding existing tensor in `graph` will be returned. +// `oper_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( + TF_ImportGraphDefOptions* opts, const char* oper_name, int index); + +// Returns the number of return outputs added via +// TF_ImportGraphDefOptionsAddReturnOutput(). +TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( + const TF_ImportGraphDefOptions* opts); + +// Add an operation in `graph_def` to be returned via the `return_opers` output +// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no +// lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( + TF_ImportGraphDefOptions* opts, const char* oper_name); + +// Returns the number of return operations added via +// TF_ImportGraphDefOptionsAddReturnOperation(). 
+TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts); + +// TF_ImportGraphDefResults holds results that are generated by +// TF_GraphImportGraphDefWithResults(). +typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults; + +// Fetches the return outputs requested via +// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is +// returned in `num_outputs`. The array of return outputs is returned in +// `outputs`. `*outputs` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs( + TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs); + +// Fetches the return operations requested via +// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched +// operations is returned in `num_opers`. The array of return operations is +// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( + TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); + +// Fetches any input mappings requested via +// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef +// and weren't used as input to any node in the imported graph def. The number +// of fetched mappings is returned in `num_missing_unused_input_mappings`. The +// array of each mapping's source node name is returned in `src_names`, and the +// array of each mapping's source index is returned in `src_indexes`. +// +// `*src_names`, `*src_indexes`, and the memory backing each string in +// `src_names` are owned by and have the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, + const char*** src_names, int** src_indexes); + +// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults( + TF_ImportGraphDefResults* results); + +// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and +// a bad status on error. Otherwise, returns a populated +// TF_ImportGraphDefResults instance. The returned instance must be deleted via +// TF_DeleteImportGraphDefResults(). +TF_CAPI_EXPORT extern TF_ImportGraphDefResults* +TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, + TF_Status* status); + +// Import the graph serialized in `graph_def` into `graph`. +// Convenience function for when only return outputs are needed. +// +// `num_return_outputs` must be the number of return outputs added (i.e. the +// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If +// `num_return_outputs` is non-zero, `return_outputs` must be of length +// `num_return_outputs`. Otherwise it can be null. +TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, + int num_return_outputs, TF_Status* status); + +// Import the graph serialized in `graph_def` into `graph`. +// Convenience function for when no results are needed. +TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status); + +// Adds a copy of function `func` and optionally its gradient function `grad` +// to `g`. 
Once `func`/`grad` is added to `g`, it can be called by creating +// an operation using the function's name. +// Any changes to `func`/`grad` (including deleting it) done after this method +// returns, won't affect the copy of `func`/`grad` in `g`. +// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no +// effect on them, but can establish the function->gradient relationship +// between them if `func` does not already have a gradient. If `func` already +// has a gradient different from `grad`, an error is returned. +// +// `func` must not be null. +// If `grad` is null and `func` is not in `g`, `func` is added without a +// gradient. +// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop. +// `grad` must have appropriate signature as described in the doc of +// GradientDef in tensorflow/core/framework/function.proto. +// +// If successful, status is set to OK and `func` and `grad` are added to `g`. +// Otherwise, status is set to the encountered error and `g` is unmodified. +TF_CAPI_EXPORT extern void TF_GraphCopyFunction(TF_Graph* g, + const TF_Function* func, + const TF_Function* grad, + TF_Status* status); + +// Returns the number of TF_Functions registered in `g`. +TF_CAPI_EXPORT extern int TF_GraphNumFunctions(TF_Graph* g); + +// Fills in `funcs` with the TF_Function* registered in `g`. +// `funcs` must point to an array of TF_Function* of length at least +// `max_func`. In usual usage, max_func should be set to the result of +// TF_GraphNumFunctions(g). In this case, all the functions registered in +// `g` will be returned. Else, an unspecified subset. +// +// If successful, returns the number of TF_Function* successfully set in +// `funcs` and sets status to OK. The caller takes ownership of +// all the returned TF_Functions. They must be deleted with TF_DeleteFunction. +// On error, returns 0, sets status to the encountered error, and the contents +// of funcs will be undefined. +TF_CAPI_EXPORT extern int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, + int max_func, TF_Status* status); + +// Note: The following function may fail on very large protos in the future. + +TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, + TF_Buffer* output_node_def, + TF_Status* status); + typedef struct TF_WhileParams { // The number of inputs to the while loop, i.e. the number of loop variables. // This is the size of cond_inputs, body_inputs, and body_outputs. @@ -149,6 +1012,558 @@ TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* dx, TF_Status* status, TF_Output* dy); +// Create a TF_Function from a TF_Graph +// +// Params: +// fn_body - the graph whose operations (or subset of whose operations) will be +// converted to TF_Function. +// fn_name - the name of the new TF_Function. Should match the operation +// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. +// If `append_hash_to_fn_name` is false, `fn_name` must be distinct +// from other function and operation names (at least those +// registered in graphs where this function will be used). +// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name +// of the function will be `fn_name` appended with +// '_'. +// If set to 0, the function's name will be `fn_name`. +// num_opers - `num_opers` contains the number of elements in the `opers` array +// or a special value of -1 meaning that no array is given. 
+// The distinction between an empty array of operations and no +// array of operations is necessary to distinguish the case of +// creating a function with no body (e.g. identity or permutation) +// and the case of creating a function whose body contains all +// the nodes in the graph (except for the automatic skipping, see +// below). +// opers - Array of operations to become the body of the function or null. +// - If no array is given (`num_opers` = -1), all the +// operations in `fn_body` will become part of the function +// except operations referenced in `inputs`. These operations +// must have a single output (these operations are typically +// placeholders created for the sole purpose of representing +// an input. We can relax this constraint if there are +// compelling use cases). +// - If an array is given (`num_opers` >= 0), all operations +// in it will become part of the function. In particular, no +// automatic skipping of dummy input operations is performed. +// ninputs - number of elements in `inputs` array +// inputs - array of TF_Outputs that specify the inputs to the function. +// If `ninputs` is zero (the function takes no inputs), `inputs` +// can be null. The names used for function inputs are normalized +// names of the operations (usually placeholders) pointed to by +// `inputs`. These operation names should start with a letter. +// Normalization will convert all letters to lowercase and +// non-alphanumeric characters to '_' to make resulting names match +// the "[a-z][a-z0-9_]*" pattern for operation argument names. +// `inputs` cannot contain the same tensor twice. +// noutputs - number of elements in `outputs` array +// outputs - array of TF_Outputs that specify the outputs of the function. +// If `noutputs` is zero (the function returns no outputs), `outputs` +// can be null. `outputs` can contain the same tensor more than once. +// output_names - The names of the function's outputs. `output_names` array +// must either have the same length as `outputs` +// (i.e. `noutputs`) or be null. In the former case, +// the names should match the regular expression for ArgDef +// names - "[a-z][a-z0-9_]*". In the latter case, +// names for outputs will be generated automatically. +// opts - various options for the function, e.g. XLA's inlining control. +// description - optional human-readable description of this function. +// status - Set to OK on success and an appropriate error on failure. +// +// Note that when the same TF_Output is listed as both an input and an output, +// the corresponding function's output will equal to this input, +// instead of the original node's output. +// +// Callers must also satisfy the following constraints: +// - `inputs` cannot refer to TF_Outputs within a control flow context. For +// example, one cannot use the output of "switch" node as input. +// - `inputs` and `outputs` cannot have reference types. Reference types are +// not exposed through C API and are being replaced with Resources. We support +// reference types inside function's body to support legacy code. Do not +// use them in new code. +// - Every node in the function's body must have all of its inputs (including +// control inputs). In other words, for every node in the body, each input +// must be either listed in `inputs` or must come from another node in +// the body. In particular, it is an error to have a control edge going from +// a node outside of the body into a node in the body. 
This applies to control +// edges going from nodes referenced in `inputs` to nodes in the body when +// the former nodes are not in the body (automatically skipped or not +// included in explicitly specified body). +// +// Returns: +// On success, a newly created TF_Function instance. It must be deleted by +// calling TF_DeleteFunction. +// +// On failure, null. +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + const TF_FunctionOptions* opts, const char* description, TF_Status* status); + +// Similar to TF_GraphToFunction but allows specifying control outputs of the +// function. +// +// The arguments of TF_GraphToFunction have the same meaning, but the new +// arguments are as follows: +// +// ncontrol_outputs: Number of control outputs of the function. +// control_outputs: vector of TF_Operation objects to be marked as control +// outputs of the function. Operations marked as control outputs are +// guaranteed to execute. +// control_output_names: Optional. If not nullptr, vector of strings, one +// per control output, with their names to be added to the function's +// OpDef. +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status); + +// Returns the name of the graph function. +// The return value points to memory that is only usable until the next +// mutation to *func. +TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); + +// Write out a serialized representation of `func` (as a FunctionDef protocol +// message) to `output_func_def` (allocated by TF_NewBuffer()). +// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. +// +// May fail on very large graphs in the future. +TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, + TF_Buffer* output_func_def, + TF_Status* status); + +// Construct and return the function whose FunctionDef representation is +// serialized in `proto`. `proto_len` must equal the number of bytes +// pointed to by `proto`. +// Returns: +// On success, a newly created TF_Function instance. It must be deleted by +// calling TF_DeleteFunction. +// +// On failure, null. +TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( + const void* proto, size_t proto_len, TF_Status* status); + +// Sets function attribute named `attr_name` to value stored in `proto`. +// If this attribute is already set to another value, it is overridden. +// `proto` should point to a sequence of bytes of length `proto_len` +// representing a binary serialization of an AttrValue protocol +// buffer. +TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Sets `output_attr_value` to the binary-serialized AttrValue proto +// representation of the value of the `attr_name` attr of `func`. 
+// If `attr_name` attribute is not present, status is set to an error. +TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( + TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); + +// Frees the memory used by the `func` struct. +// TF_DeleteFunction is a noop if `func` is null. +// Deleting a function does not remove it from any graphs it was copied to. +TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); + +// Attempts to evaluate `output`. This will only be possible if `output` doesn't +// depend on any graph inputs (this function is safe to call if this isn't the +// case though). +// +// If the evaluation is successful, this function returns true and `output`s +// value is returned in `result`. Otherwise returns false. An error status is +// returned if something is wrong with the graph or input. Note that this may +// return false even if no error status is set. +TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph, + TF_Output output, + TF_Tensor** result, + TF_Status* status); + +// TODO(josh11b): Register OpDef, available to all operations added +// to this graph. + +// -------------------------------------------------------------------------- +// API for driving Graph execution. + +typedef struct TF_Session TF_Session; + +// Return a new execution session with the associated graph, or NULL on +// error. Does not take ownership of any input parameters. +// +// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be +// kept alive for the lifetime of the returned TF_Session. New nodes can still +// be added to `graph` after this call. +TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, + const TF_SessionOptions* opts, + TF_Status* status); + +// This function creates a new TF_Session (which is created on success) using +// `session_options`, and then initializes state (restoring tensors and other +// assets) using `run_options`. +// +// Any NULL and non-NULL value combinations for (`run_options, `meta_graph_def`) +// are valid. +// +// - `export_dir` must be set to the path of the exported SavedModel. +// - `tags` must include the set of tags used to identify one MetaGraphDef in +// the SavedModel. +// - `graph` must be a graph newly allocated with TF_NewGraph(). +// +// If successful, populates `graph` with the contents of the Graph and +// `meta_graph_def` with the MetaGraphDef of the loaded model. +TF_CAPI_EXPORT extern TF_Session* TF_LoadSessionFromSavedModel( + const TF_SessionOptions* session_options, const TF_Buffer* run_options, + const char* export_dir, const char* const* tags, int tags_len, + TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status); + +// Close a session. +// +// Contacts any other processes associated with the session, if applicable. +// May not be called after TF_DeleteSession(). +TF_CAPI_EXPORT extern void TF_CloseSession(TF_Session*, TF_Status* status); + +// Destroy a session object. +// +// Even if error information is recorded in *status, this call discards all +// local resources associated with the session. The session may not be used +// during or after this call (and the session drops its reference to the +// corresponding graph). +TF_CAPI_EXPORT extern void TF_DeleteSession(TF_Session*, TF_Status* status); + +// Run the graph associated with the session starting with the supplied inputs +// (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). 
+// +// Any NULL and non-NULL value combinations for (`run_options`, +// `run_metadata`) are valid. +// +// - `run_options` may be NULL, in which case it will be ignored; or +// non-NULL, in which case it must point to a `TF_Buffer` containing the +// serialized representation of a `RunOptions` protocol buffer. +// - `run_metadata` may be NULL, in which case it will be ignored; or +// non-NULL, in which case it must point to an empty, freshly allocated +// `TF_Buffer` that may be updated to contain the serialized representation +// of a `RunMetadata` protocol buffer. +// +// The caller retains ownership of `input_values` (which can be deleted using +// TF_DeleteTensor). The caller also retains ownership of `run_options` and/or +// `run_metadata` (when not NULL) and should manually call TF_DeleteBuffer on +// them. +// +// On success, the tensors corresponding to outputs[0,noutputs-1] are placed in +// output_values[]. Ownership of the elements of output_values[] is transferred +// to the caller, which must eventually call TF_DeleteTensor on them. +// +// On failure, output_values[] contains NULLs. +TF_CAPI_EXPORT extern void TF_SessionRun( + TF_Session* session, + // RunOptions + const TF_Buffer* run_options, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // RunMetadata + TF_Buffer* run_metadata, + // Output status + TF_Status*); + +// Set up the graph with the intended feeds (inputs) and fetches (outputs) for a +// sequence of partial run calls. +// +// On success, returns a handle that is used for subsequent PRun calls. The +// handle should be deleted with TF_DeletePRunHandle when it is no longer +// needed. +// +// On failure, out_status contains a tensorflow::Status with an error +// message. *handle is set to nullptr. +TF_CAPI_EXPORT extern void TF_SessionPRunSetup( + TF_Session*, + // Input names + const TF_Output* inputs, int ninputs, + // Output names + const TF_Output* outputs, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output handle + const char** handle, + // Output status + TF_Status*); + +// Continue to run the graph with additional feeds and fetches. The +// execution state is uniquely identified by the handle. +TF_CAPI_EXPORT extern void TF_SessionPRun( + TF_Session*, const char* handle, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output status + TF_Status*); + +// Deletes a handle allocated by TF_SessionPRunSetup. +// Once called, no more calls to TF_SessionPRun should be made. +TF_CAPI_EXPORT extern void TF_DeletePRunHandle(const char* handle); + +// -------------------------------------------------------------------------- +// The deprecated session API. Please switch to the above instead of +// TF_ExtendGraph(). This deprecated API can be removed at any time without +// notice. 
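For orientation only, here is a minimal sketch of the TF_SessionRun flow documented above (the deprecated API continues below). The helper name RunOnce and the operation names "x" and "y" are assumptions for illustration, not part of this patch; the graph is presumed to already contain those operations, for example after TF_GraphImportGraphDef.

#include <stdio.h>
#include "tensorflow/c/c_api.h"

/* Sketch: create a session for an existing graph, feed one scalar float into
   the operation named "x", and fetch output 0 of the operation named "y". */
static void RunOnce(TF_Graph* graph) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* session = TF_NewSession(graph, opts, status);
  TF_DeleteSessionOptions(opts);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "TF_NewSession failed: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    return;
  }

  /* Feed: output 0 of "x", bound to a freshly allocated scalar float tensor. */
  TF_Output feed = {TF_GraphOperationByName(graph, "x"), 0};
  TF_Tensor* feed_value = TF_AllocateTensor(TF_FLOAT, /*dims=*/NULL, 0, sizeof(float));
  *(float*)TF_TensorData(feed_value) = 1.0f;

  /* Fetch: output 0 of "y"; TF_SessionRun fills in fetch_value on success. */
  TF_Output fetch = {TF_GraphOperationByName(graph, "y"), 0};
  TF_Tensor* fetch_value = NULL;

  TF_SessionRun(session, /*run_options=*/NULL,
                &feed, &feed_value, 1,   /* inputs  */
                &fetch, &fetch_value, 1, /* outputs */
                NULL, 0,                 /* no target operations */
                /*run_metadata=*/NULL, status);
  if (TF_GetCode(status) == TF_OK) {
    printf("y = %f\n", *(float*)TF_TensorData(fetch_value));
    TF_DeleteTensor(fetch_value);  /* ownership of outputs moves to the caller */
  } else {
    fprintf(stderr, "TF_SessionRun failed: %s\n", TF_Message(status));
  }
  TF_DeleteTensor(feed_value);     /* the caller retains ownership of inputs */
  TF_CloseSession(session, status);
  TF_DeleteSession(session, status);
  TF_DeleteStatus(status);
}

Ownership follows the comments above: input tensors stay with the caller, while fetched output tensors must be released with TF_DeleteTensor.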
+ +typedef struct TF_DeprecatedSession TF_DeprecatedSession; + +TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession( + const TF_SessionOptions*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_Reset(const TF_SessionOptions* opt, + const char** containers, int ncontainers, + TF_Status* status); +// Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and +// add the nodes in that GraphDef to the graph for the session. +// +// Prefer use of TF_Session and TF_GraphImportGraphDef over this. +TF_CAPI_EXPORT extern void TF_ExtendGraph(TF_DeprecatedSession*, + const void* proto, size_t proto_len, + TF_Status*); + +// See TF_SessionRun() above. +TF_CAPI_EXPORT extern void TF_Run(TF_DeprecatedSession*, + const TF_Buffer* run_options, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Buffer* run_metadata, TF_Status*); + +// See TF_SessionPRunSetup() above. +TF_CAPI_EXPORT extern void TF_PRunSetup(TF_DeprecatedSession*, + const char** input_names, int ninputs, + const char** output_names, int noutputs, + const char** target_oper_names, + int ntargets, const char** handle, + TF_Status*); + +// See TF_SessionPRun above. +TF_CAPI_EXPORT extern void TF_PRun(TF_DeprecatedSession*, const char* handle, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Status*); + +typedef struct TF_DeviceList TF_DeviceList; + +// Lists all devices in a TF_Session. +// +// Caller takes ownership of the returned TF_DeviceList* which must eventually +// be freed with a call to TF_DeleteDeviceList. +TF_CAPI_EXPORT extern TF_DeviceList* TF_SessionListDevices(TF_Session* session, + TF_Status* status); + +// Lists all devices in a TF_Session. +// +// Caller takes ownership of the returned TF_DeviceList* which must eventually +// be freed with a call to TF_DeleteDeviceList. +TF_CAPI_EXPORT extern TF_DeviceList* TF_DeprecatedSessionListDevices( + TF_DeprecatedSession* session, TF_Status* status); + +// Deallocates the device list. +TF_CAPI_EXPORT extern void TF_DeleteDeviceList(TF_DeviceList* list); + +// Counts the number of elements in the device list. +TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); + +// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) +// The return value will be a pointer to a null terminated string. The caller +// must not modify or delete the string. It will be deallocated upon a call to +// TF_DeleteDeviceList. +// +// If index is out of bounds, an error code will be set in the status object, +// and a null pointer will be returned. +TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, + int index, + TF_Status* status); + +// Retrieves the type of the device at the given index. +// +// The caller must not modify or delete the string. It will be deallocated upon +// a call to TF_DeleteDeviceList. +// +// If index is out of bounds, an error code will be set in the status object, +// and a null pointer will be returned. 
+TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, + int index, + TF_Status* status); + +// Retrieve the amount of memory associated with a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and -1 will be returned. +TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( + const TF_DeviceList* list, int index, TF_Status* status); + +// Retrieve the incarnation number of a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and 0 will be returned. +TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation( + const TF_DeviceList* list, int index, TF_Status* status); + +// -------------------------------------------------------------------------- +// Load plugins containing custom ops and kernels + +// TF_Library holds information about dynamically loaded TensorFlow plugins. +typedef struct TF_Library TF_Library; + +// Load the library specified by library_filename and register the ops and +// kernels present in that library. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, place OK in status and return the newly created library handle. +// The caller owns the library handle. +// +// On failure, place an error status in status and return NULL. +TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename, + TF_Status* status); + +// Get the OpList of OpDefs defined in the library pointed by lib_handle. +// +// Returns a TF_Buffer. The memory pointed to by the result is owned by +// lib_handle. The data in the buffer will be the serialized OpList proto for +// ops defined in the library. +TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); + +// Frees the memory associated with the library handle. +// Does NOT unload the library. +TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); + +// Get the OpList of all OpDefs defined in this address space. +// Returns a TF_Buffer, ownership of which is transferred to the caller +// (and can be freed using TF_DeleteBuffer). +// +// The data in the buffer will be the serialized OpList proto for ops registered +// in this address space. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); + +// TF_ApiDefMap encapsulates a collection of API definitions for an operation. +// +// This object maps the name of a TensorFlow operation to a description of the +// API to generate for it, as defined by the ApiDef protocol buffer ( +// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) +// +// The ApiDef messages are typically used to generate convenience wrapper +// functions for TensorFlow operations in various language bindings. +typedef struct TF_ApiDefMap TF_ApiDefMap; + +// Creates a new TF_ApiDefMap instance. +// +// Params: +// op_list_buffer - TF_Buffer instance containing serialized OpList +// protocol buffer. (See +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto +// for the OpList proto definition). +// status - Set to OK on success and an appropriate error on failure. +TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, + TF_Status* status); + +// Deallocates a TF_ApiDefMap. +TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap); + +// Add ApiDefs to the map. 
+// +// `text` corresponds to a text representation of an ApiDefs protocol message. +// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). +// +// The provided ApiDefs will be merged with existing ones in the map, with +// precedence given to the newly added version in case of conflicts with +// previous calls to TF_ApiDefMapPut. +TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, + const char* text, size_t text_len, + TF_Status* status); + +// Returns a serialized ApiDef protocol buffer for the TensorFlow operation +// named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, + const char* name, + size_t name_len, + TF_Status* status); + +// -------------------------------------------------------------------------- +// Kernel definition information. + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// registered kernels. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// kernels registered for the operation named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( + const char* name, TF_Status* status); + +// -------------------------------------------------------------------------- +// In-process TensorFlow server functionality, for use in distributed training. +// A Server instance encapsulates a set of devices and a Session target that +// can participate in distributed training. A server belongs to a cluster +// (specified by a ClusterSpec), and corresponds to a particular task in a +// named job. The server can communicate with any other server in the same +// cluster. + +// In-process TensorFlow server. +typedef struct TF_Server TF_Server; + +// Creates a new in-process TensorFlow server configured using a serialized +// ServerDef protocol buffer provided via `proto` and `proto_len`. +// +// The server will not serve any requests until TF_ServerStart is invoked. +// The server will stop serving requests once TF_ServerStop or +// TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, + size_t proto_len, + TF_Status* status); + +// Starts an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); + +// Stops an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); + +// Blocks until the server has been successfully stopped (via TF_ServerStop or +// TF_ServerClose). +TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); + +// Returns the target string that can be provided to TF_SetTarget() to connect +// a TF_Session to `server`. +// +// The returned string is valid only until TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); + +// Destroy an in-process TensorFlow server, frees memory. If server is running +// it will be stopped and joined. +TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); + +// Register a listener method that processes printed messages. +// +// If any listeners are registered, the print operator will call all listeners +// with the printed messages and immediately return without writing to the +// logs. 
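As a small hedged illustration of the listener mechanism described above (the actual TF_RegisterLogListener declaration follows immediately below), a program could register a callback such as the hypothetical EchoToStderr:

#include <stdio.h>
#include "tensorflow/c/c_api.h"

/* Hypothetical listener: once registered, it receives every message produced
   by print operators instead of those messages being written to the logs. */
static void EchoToStderr(const char* message) {
  fprintf(stderr, "[tf print] %s", message);
}

void InstallPrintListener(void) {
  /* After this call, printed messages are delivered to EchoToStderr. */
  TF_RegisterLogListener(EchoToStderr);
}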
+TF_CAPI_EXPORT extern void TF_RegisterLogListener( + void (*listener)(const char*)); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 11fb7705625..32880378c2b 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_C_C_API_INTERNAL_H_ #define TENSORFLOW_C_C_API_INTERNAL_H_ +#include "tensorflow/c/c_api.h" + #include #include #include #include #include -#include "tensorflow/c/c_core_api.h" - // clang-format off // Required for IS_MOBILE_PLATFORM #include "tensorflow/core/platform/platform.h" @@ -217,10 +217,6 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) std::string getTF_OutputDebugString(TF_Output node); -TF_Operation* ToOperation(Node* node); - -TensorId ToTensorId(const TF_Output& output); - } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_core_api.cc b/tensorflow/c/c_core_api.cc deleted file mode 100644 index 67daaef08ac..00000000000 --- a/tensorflow/c/c_core_api.cc +++ /dev/null @@ -1,2193 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/c/c_api.h" - -#include -#include -#include -#include - -#include "absl/strings/match.h" -// Required for IS_MOBILE_PLATFORM -#include "tensorflow/core/platform/platform.h" // NOLINT - -#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -#include "tensorflow/cc/saved_model/loader.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/framework/logging.h" -#include "tensorflow/core/framework/op_gen_lib.h" -#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -#include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/c/tf_tensor.h" -#include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/common_runtime/eval_const_tensor.h" -#include "tensorflow/core/common_runtime/shape_refiner.h" -#include "tensorflow/core/framework/allocation_description.pb.h" -#include "tensorflow/core/framework/kernel_def.pb.h" -#include "tensorflow/core/framework/log_memory.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor.pb.h" // NOLINT -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/graph/validate.h" -#include 
"tensorflow/core/lib/core/coding.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/public/version.h" - -// The implementation below is at the top level instead of the -// brain namespace because we are defining 'extern "C"' functions. -using tensorflow::AllocationDescription; -using tensorflow::DataType; -using tensorflow::ExtendSessionGraphHelper; -using tensorflow::Graph; -using tensorflow::GraphDef; -using tensorflow::mutex_lock; -using tensorflow::NameRangeMap; -using tensorflow::NameRangesForNode; -using tensorflow::NewSession; -using tensorflow::Node; -using tensorflow::NodeBuilder; -using tensorflow::NodeDef; -using tensorflow::OpDef; -using tensorflow::OpRegistry; -using tensorflow::OutputTensor; -using tensorflow::PartialTensorShape; -using tensorflow::RunMetadata; -using tensorflow::RunOptions; -using tensorflow::Session; -using tensorflow::Status; -using tensorflow::string; -using tensorflow::Tensor; -using tensorflow::TensorBuffer; -using tensorflow::TensorId; -using tensorflow::TensorShape; -using tensorflow::TensorShapeProto; -using tensorflow::ToTensorId; -using tensorflow::VersionDef; -using tensorflow::errors::FailedPrecondition; -using tensorflow::errors::InvalidArgument; -using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; - -extern "C" { - -// -------------------------------------------------------------------------- -const char* TF_Version() { return TF_VERSION_STRING; } - -// -------------------------------------------------------------------------- - -// -------------------------------------------------------------------------- -TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } -void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } - -void TF_SetTarget(TF_SessionOptions* options, const char* target) { - options->options.target = target; -} - -void TF_SetConfig(TF_SessionOptions* options, const void* proto, - size_t proto_len, TF_Status* status) { - if (!options->options.config.ParseFromArray(proto, proto_len)) { - status->status = InvalidArgument("Unparseable ConfigProto"); - } -} -// -------------------------------------------------------------------------- -TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; } - -TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { - void* copy = tensorflow::port::Malloc(proto_len); - memcpy(copy, proto, proto_len); - - TF_Buffer* buf = new TF_Buffer; - buf->data = copy; - buf->length = proto_len; - buf->data_deallocator = [](void* data, size_t length) { - tensorflow::port::Free(data); - }; - return buf; -} - -void TF_DeleteBuffer(TF_Buffer* buffer) { - if (buffer == nullptr) return; - if (buffer->data_deallocator != nullptr) { - (*buffer->data_deallocator)(const_cast(buffer->data), - buffer->length); - } - delete buffer; -} - -TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } - -// -------------------------------------------------------------------------- - 
-TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, - TF_Status* status) { - Session* session; - status->status = NewSession(opt->options, &session); - if (status->status.ok()) { - return new TF_DeprecatedSession({session}); - } else { - DCHECK_EQ(nullptr, session); - return nullptr; - } -} - -void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { - status->status = s->session->Close(); -} - -void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { - status->status = Status::OK(); - if (s == nullptr) return; - delete s->session; - delete s; -} - -void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto, - size_t proto_len, TF_Status* status) { - GraphDef g; - if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { - status->status = InvalidArgument("Invalid GraphDef"); - return; - } - status->status = s->session->Extend(g); -} - -} // end extern "C" - -// Reset helper for converting character arrays to string vectors. -static void TF_Reset_Helper(const TF_SessionOptions* opt, - const char** containers, int ncontainers, - TF_Status* status) { - std::vector container_names(ncontainers); - for (int i = 0; i < ncontainers; ++i) { - container_names[i] = containers[i]; - } - - status->status = Reset(opt->options, container_names); -} - -extern "C" { - -void TF_Reset(const TF_SessionOptions* opt, const char** containers, - int ncontainers, TF_Status* status) { - TF_Reset_Helper(opt, containers, ncontainers, status); -} - -} // end extern "C" - -namespace tensorflow { - - -Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, - TF_Buffer* out) { - if (out->data != nullptr) { - return InvalidArgument("Passing non-empty TF_Buffer is invalid."); - } - const size_t proto_size = in.ByteSizeLong(); - void* buf = port::Malloc(proto_size); - if (buf == nullptr) { - return tensorflow::errors::ResourceExhausted( - "Failed to allocate memory to serialize message of type '", - in.GetTypeName(), "' and size ", proto_size); - } - if (!in.SerializeWithCachedSizesToArray(static_cast(buf))) { - port::Free(buf); - return InvalidArgument("Unable to serialize ", in.GetTypeName(), - " protocol buffer, perhaps the serialized size (", - proto_size, " bytes) is too large?"); - } - out->data = buf; - out->length = proto_size; - out->data_deallocator = [](void* data, size_t length) { port::Free(data); }; - return Status::OK(); -} - -void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type) { - // If any session has already run this node_id, mark this session as - // unrunnable. - for (auto it : graph->sessions) { - mutex_lock session_lock(it.first->mu); - if (it.first->last_num_graph_nodes > op.node.id()) { - it.second = strings::StrCat( - "Operation '", op.node.DebugString(), "' was changed by ", - mutation_type, - " after it was run by a session. This mutation will have no effect, " - "and will trigger an error in the future. Either don't modify " - "nodes after running them or create a new session."); - } - } -} - -namespace { - -// Helper method that creates a shape handle for a shape described by dims. 
-tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( - tensorflow::shape_inference::InferenceContext* ic, int num_dims, - const int64_t* dims) { - if (num_dims != -1) { - std::vector dim_vec; - dim_vec.reserve(num_dims); - for (int i = 0; i < num_dims; ++i) { - dim_vec.push_back(ic->MakeDim(dims[i])); - } - return ic->MakeShape(dim_vec); - } else { - return ic->UnknownShape(); - } -} - -} // namespace - -void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, - int num_shapes_and_types, - const int64_t** shapes, - const int* ranks, - const TF_DataType* types, - TF_Status* status) { - Node* node = &output.oper->node; - - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - if (ic == nullptr) { - status->status = - InvalidArgument("Node ", node->name(), " was not found in the graph"); - return; - } - - auto shape_and_type_vec = - std::vector( - num_shapes_and_types); - for (int i = 0; i < num_shapes_and_types; ++i) { - tensorflow::shape_inference::ShapeHandle shape_handle = - ShapeHandleFromDims(ic, ranks[i], shapes[i]); - shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType( - shape_handle, static_cast(types[i])); - } - - ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec); -} - -// Helpers for loading a TensorFlow plugin (a .so file). -Status LoadLibrary(const char* library_filename, void** result, - const void** buf, size_t* len); - -// TODO(josh11b,mrry): Change Session to be able to use a Graph* -// directly, instead of requiring us to serialize to a GraphDef and -// call Session::Extend(). -bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { - if (session->graph != nullptr) { - // Take the graph lock before the session lock to avoid deadlock. This is - // safe since session->graph does not change. - session->graph->mu.lock(); - mutex_lock session_lock(session->mu); - const Graph& graph = session->graph->graph; - - const string& mutation_warning = session->graph->sessions[session]; - if (!mutation_warning.empty()) { - // TODO(b/74949947): turn this back into an error status - LOG(WARNING) << mutation_warning; - session->graph->sessions[session].clear(); - } - - const auto num_nodes = graph.num_node_ids(); - if (session->last_num_graph_nodes < num_nodes) { - // TODO(nolivia): check this on a subset of the graph instead of all of - // it. - status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; - } - - GraphDef graph_def; - *graph_def.mutable_versions() = graph.versions(); - // Fill graph_def with nodes with ids in the range - // [session->last_num_graph_nodes, num_nodes), that is the nodes - // added since the last TF_SessionRun() call. - for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { - Node* const node = graph.FindNodeId(id); - if (node != nullptr && node->IsOp()) { - NodeDef* const node_def = graph_def.add_node(); - *node_def = node->def(); - } - } - *graph_def.mutable_library() = graph.flib_def().ToProto(); - session->graph->mu.unlock(); - status->status = session->session->Extend(std::move(graph_def)); - if (!status->status.ok()) { - // Contract is we always delete input_values[i]. - return false; - } - // Note: session->session is not modified if Extend() fails, so - // we only set last_num_graph_nodes if it succeeds. 
- session->last_num_graph_nodes = num_nodes; - } else { - session->graph->mu.unlock(); - } - } - return true; -} - -} // namespace tensorflow - -static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, - TF_Status* status) { - status->status = Status::OK(); - for (int i = 0; i < noutputs; ++i) { - c_outputs[i] = nullptr; - } -} - -static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, - std::vector>* input_pairs, - TF_Status* status) { - const int ninputs = input_pairs->size(); - for (int i = 0; i < ninputs; ++i) { - status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); - if (!status->status.ok()) return false; - } - return true; -} - -// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to -// result in a zero-sized tensor. -static TF_Tensor* EmptyTensor(TF_DataType dtype, - const tensorflow::TensorShape& shape) { - static char empty; - tensorflow::int64 nelems = 1; - std::vector dims; - for (int i = 0; i < shape.dims(); ++i) { - dims.push_back(shape.dim_size(i)); - nelems *= shape.dim_size(i); - } - CHECK_EQ(nelems, 0); - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - return TF_NewTensor( - dtype, reinterpret_cast(dims.data()), shape.dims(), - reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); -} - -static void TF_Run_Helper( - Session* session, const char* handle, const TF_Buffer* run_options, - // Input tensors - const std::vector>& input_pairs, - // Output tensors - const std::vector& output_tensor_names, TF_Tensor** c_outputs, - // Target nodes - const std::vector& target_oper_names, TF_Buffer* run_metadata, - TF_Status* status) { - const int noutputs = output_tensor_names.size(); - std::vector outputs(noutputs); - Status result; - - if (handle == nullptr) { - RunOptions run_options_proto; - if (run_options != nullptr && !run_options_proto.ParseFromArray( - run_options->data, run_options->length)) { - status->status = InvalidArgument("Unparseable RunOptions proto"); - return; - } - if (run_metadata != nullptr && run_metadata->data != nullptr) { - status->status = - InvalidArgument("Passing non-empty run_metadata is invalid."); - return; - } - - RunMetadata run_metadata_proto; - result = session->Run(run_options_proto, input_pairs, output_tensor_names, - target_oper_names, &outputs, &run_metadata_proto); - - // Serialize back to upstream client, who now owns the new buffer - if (run_metadata != nullptr) { - status->status = MessageToBuffer(run_metadata_proto, run_metadata); - if (!status->status.ok()) return; - } - } else { - // NOTE(zongheng): PRun does not support RunOptions yet. 
- result = session->PRun(handle, input_pairs, output_tensor_names, &outputs); - } - if (!result.ok()) { - status->status = result; - return; - } - - // Store results in c_outputs[] - for (int i = 0; i < noutputs; ++i) { - const Tensor& src = outputs[i]; - if (!src.IsInitialized() || src.NumElements() == 0) { - c_outputs[i] = - EmptyTensor(static_cast(src.dtype()), src.shape()); - continue; - } - c_outputs[i] = TF_TensorFromTensor(src, &status->status); - if (!status->status.ok()) return; - } -} - -extern "C" { - -void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, - // Input tensors - const char** c_input_names, TF_Tensor** c_inputs, int ninputs, - // Output tensors - const char** c_output_names, TF_Tensor** c_outputs, int noutputs, - // Target nodes - const char** c_target_oper_names, int ntargets, - TF_Buffer* run_metadata, TF_Status* status) { - TF_Run_Setup(noutputs, c_outputs, status); - std::vector> input_pairs(ninputs); - if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; - for (int i = 0; i < ninputs; ++i) { - input_pairs[i].first = c_input_names[i]; - } - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = c_output_names[i]; - } - std::vector target_oper_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_oper_names[i] = c_target_oper_names[i]; - } - TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names, - c_outputs, target_oper_names, run_metadata, status); -} - -void TF_PRunSetup(TF_DeprecatedSession* s, - // Input names - const char** c_input_names, int ninputs, - // Output names - const char** c_output_names, int noutputs, - // Target nodes - const char** c_target_oper_names, int ntargets, - const char** handle, TF_Status* status) { - *handle = nullptr; - - std::vector input_names(ninputs); - std::vector output_names(noutputs); - std::vector target_oper_names(ntargets); - for (int i = 0; i < ninputs; ++i) { - input_names[i] = c_input_names[i]; - } - for (int i = 0; i < noutputs; ++i) { - output_names[i] = c_output_names[i]; - } - for (int i = 0; i < ntargets; ++i) { - target_oper_names[i] = c_target_oper_names[i]; - } - string new_handle; - status->status = s->session->PRunSetup(input_names, output_names, - target_oper_names, &new_handle); - if (status->status.ok()) { - char* buf = new char[new_handle.size() + 1]; - memcpy(buf, new_handle.c_str(), new_handle.size() + 1); - *handle = buf; - } -} - -void TF_PRun(TF_DeprecatedSession* s, const char* handle, - // Input tensors - const char** c_input_names, TF_Tensor** c_inputs, int ninputs, - // Output tensors - const char** c_output_names, TF_Tensor** c_outputs, int noutputs, - // Target nodes - const char** c_target_oper_names, int ntargets, - TF_Status* status) { - TF_Run_Setup(noutputs, c_outputs, status); - std::vector> input_pairs(ninputs); - if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; - for (int i = 0; i < ninputs; ++i) { - input_pairs[i].first = c_input_names[i]; - } - - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = c_output_names[i]; - } - std::vector target_oper_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_oper_names[i] = c_target_oper_names[i]; - } - TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names, - c_outputs, target_oper_names, nullptr, status); -} - -TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { - TF_Library* lib_handle = new TF_Library; - status->status = tensorflow::LoadLibrary( - 
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, - &lib_handle->op_list.length); - if (!status->status.ok()) { - delete lib_handle; - return nullptr; - } - return lib_handle; -} - -TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } - -void TF_DeleteLibraryHandle(TF_Library* lib_handle) { - if (lib_handle == nullptr) return; - tensorflow::port::Free(const_cast(lib_handle->op_list.data)); - delete lib_handle; -} - -TF_Buffer* TF_GetAllOpList() { - std::vector op_defs; - tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs); - tensorflow::OpList op_list; - for (const auto& op : op_defs) { - *(op_list.add_op()) = op; - } - TF_Buffer* ret = TF_NewBuffer(); - TF_CHECK_OK(MessageToBuffer(op_list, ret)); - return ret; -} - -// -------------------------------------------------------------------------- -// ListDevices & SessionListDevices API - -void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; } - -TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) { - TF_DeviceList* response = new TF_DeviceList; - status->status = session->session->ListDevices(&response->response); - return response; -} - -TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session, - TF_Status* status) { - TF_DeviceList* response = new TF_DeviceList; - status->status = session->session->ListDevices(&response->response); - return response; -} - -int TF_DeviceListCount(const TF_DeviceList* list) { - return list->response.size(); -} - -#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \ - return_type method_name(const TF_DeviceList* list, const int index, \ - TF_Status* status) { \ - if (list == nullptr) { \ - status->status = InvalidArgument("list is null!"); \ - return err_val; \ - } \ - if (index < 0 || index >= list->response.size()) { \ - status->status = InvalidArgument("index out of bounds"); \ - return err_val; \ - } \ - status->status = Status::OK(); \ - return list->response[index].accessor; \ - } - -TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr); -TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(), - nullptr); -TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1); -TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0); - -#undef TF_DEVICELIST_METHOD - -} // end extern "C" - -// -------------------------------------------------------------------------- -// New Graph and Session API - -// Helper functions ----------------------------------------------------------- - -namespace tensorflow { - -TF_Operation* ToOperation(Node* node) { - return static_cast(static_cast(node)); -} - -TensorId ToTensorId(const TF_Output& output) { - return TensorId(output.oper->node.name(), output.index); -} - -} // namespace tensorflow - -namespace { - -string OutputName(const TF_Output& output) { - return StrCat(output.oper->node.name(), ":", output.index); -} - -const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, - const char* attr_name, - TF_Status* status) { - const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); - if (attr == nullptr) { - status->status = InvalidArgument("Operation '", oper->node.name(), - "' has no attr named '", attr_name, "'."); - } - return attr; -} - -} // namespace - -// Shape functions ----------------------------------------------------------- - -void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, - const int64_t* dims, const int num_dims, - 
TF_Status* status) { - Node* node = &output.oper->node; - - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - if (ic == nullptr) { - status->status = - InvalidArgument("Node ", node->name(), " was not found in the graph"); - return; - } - tensorflow::shape_inference::ShapeHandle new_shape = - tensorflow::ShapeHandleFromDims(ic, num_dims, dims); - status->status = graph->refiner.SetShape(node, output.index, new_shape); -} - -int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, - TF_Status* status) { - Node* node = &output.oper->node; - - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - if (ic == nullptr) { - status->status = - InvalidArgument("Node ", node->name(), " was not found in the graph"); - return -1; - } - - tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); - - // Unknown rank means the number of dimensions is -1. - if (!ic->RankKnown(shape)) { - return -1; - } - - return ic->Rank(shape); -} - -void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims, - int num_dims, TF_Status* status) { - Node* node = &output.oper->node; - - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - if (ic == nullptr) { - status->status = - InvalidArgument("Node ", node->name(), " was not found in the graph"); - return; - } - - tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); - - int rank = -1; - if (ic->RankKnown(shape)) { - rank = ic->Rank(shape); - } - - if (num_dims != rank) { - status->status = InvalidArgument("Expected rank is ", num_dims, - " but actual rank is ", rank); - return; - } - - if (num_dims == 0) { - // Output shape is a scalar. - return; - } - - // Rank is greater than 0, so fill in the values, if known, and - // -1 for unknown values. 
- for (int i = 0; i < num_dims; ++i) { - auto dim = ic->Dim(shape, i); - tensorflow::int64 value = -1; - if (ic->ValueKnown(dim)) { - value = ic->Value(dim); - } - dims[i] = value; - } -} - -// TF_OperationDescription functions ------------------------------------------ - -extern "C" { - -static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, - const char* op_type, - const char* oper_name) - TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - return new TF_OperationDescription(graph, op_type, oper_name); -} - -TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type, - const char* oper_name) { - mutex_lock l(graph->mu); - return TF_NewOperationLocked(graph, op_type, oper_name); -} - -void TF_SetDevice(TF_OperationDescription* desc, const char* device) { - desc->node_builder.Device(device); -} - -void TF_AddInput(TF_OperationDescription* desc, TF_Output input) { - desc->node_builder.Input(&input.oper->node, input.index); -} - -void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs, - int num_inputs) { - std::vector input_list; - input_list.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - input_list.emplace_back(&inputs[i].oper->node, inputs[i].index); - } - desc->node_builder.Input(input_list); -} - -void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { - desc->node_builder.ControlInput(&input->node); -} - -void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { - desc->colocation_constraints.emplace( - StrCat(tensorflow::kColocationGroupPrefix, op->node.name())); -} - -void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, - const void* value, size_t length) { - tensorflow::StringPiece s(static_cast(value), length); - desc->node_builder.Attr(attr_name, s); -} - -void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, - const void* const* values, const size_t* lengths, - int num_values) { - if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { - desc->colocation_constraints.clear(); - for (int i = 0; i < num_values; ++i) { - desc->colocation_constraints.emplace(static_cast(values[i]), - lengths[i]); - } - } else { - std::vector v; - v.reserve(num_values); - for (int i = 0; i < num_values; ++i) { - v.emplace_back(static_cast(values[i]), lengths[i]); - } - desc->node_builder.Attr(attr_name, v); - } -} - -void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, - int64_t value) { - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - desc->node_builder.Attr(attr_name, static_cast(value)); -} - -void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name, - const int64_t* values, int num_values) { - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - desc->node_builder.Attr( - attr_name, - ArraySlice( - reinterpret_cast(values), num_values)); -} - -void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name, - float value) { - desc->node_builder.Attr(attr_name, value); -} - -void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name, - const float* values, int num_values) { - desc->node_builder.Attr(attr_name, - ArraySlice(values, num_values)); -} - -void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, - unsigned char value) { - desc->node_builder.Attr(attr_name, static_cast(value)); -} - -void TF_SetAttrBoolList(TF_OperationDescription* desc, const 
char* attr_name, - const unsigned char* values, int num_values) { - std::unique_ptr b(new bool[num_values]); - for (int i = 0; i < num_values; ++i) { - b[i] = values[i]; - } - desc->node_builder.Attr(attr_name, - ArraySlice(b.get(), num_values)); -} - -void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, - TF_DataType value) { - desc->node_builder.Attr(attr_name, static_cast(value)); -} - -void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, - const TF_DataType* values, int num_values) { - desc->node_builder.Attr( - attr_name, ArraySlice( - reinterpret_cast(values), num_values)); -} - -void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name, - const char* placeholder) { - tensorflow::AttrValue attr_value; - attr_value.set_placeholder(placeholder); - desc->node_builder.Attr(attr_name, attr_value); -} - -void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, - const char* value, size_t length) { - tensorflow::NameAttrList func_name; - func_name.set_name(string(value, value + length)); - desc->node_builder.Attr(attr_name, func_name); -} - -void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, - const int64_t* dims, int num_dims) { - PartialTensorShape shape; - if (num_dims >= 0) { - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - shape = PartialTensorShape(ArraySlice( - reinterpret_cast(dims), num_dims)); - } - desc->node_builder.Attr(attr_name, shape); -} - -void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name, - const int64_t* const* dims, const int* num_dims, - int num_shapes) { - std::vector shapes; - shapes.reserve(num_shapes); - for (int i = 0; i < num_shapes; ++i) { - if (num_dims[i] < 0) { - shapes.emplace_back(); - } else { - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - shapes.emplace_back(ArraySlice( - reinterpret_cast(dims[i]), num_dims[i])); - } - } - desc->node_builder.Attr(attr_name, shapes); -} - -void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, - const char* attr_name, const void* proto, - size_t proto_len, TF_Status* status) { - // shape.ParseFromArray takes an int as length, this function takes size_t, - // make sure there is no information loss. 
- if (proto_len > std::numeric_limits::max()) { - status->status = InvalidArgument( - "proto_len (", proto_len, - " bytes) is too large to be parsed by the protocol buffer library"); - return; - } - TensorShapeProto shape; - if (shape.ParseFromArray(proto, static_cast(proto_len))) { - desc->node_builder.Attr(attr_name, shape); - status->status = Status::OK(); - } else { - status->status = InvalidArgument("Unparseable TensorShapeProto"); - } -} - -void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, - const char* attr_name, - const void* const* protos, - const size_t* proto_lens, int num_shapes, - TF_Status* status) { - std::vector shapes; - shapes.resize(num_shapes); - for (int i = 0; i < num_shapes; ++i) { - if (proto_lens[i] > std::numeric_limits::max()) { - status->status = InvalidArgument( - "length of element ", i, " in the list (", proto_lens[i], - " bytes) is too large to be parsed by the protocol buffer library"); - return; - } - if (!shapes[i].ParseFromArray(protos[i], static_cast(proto_lens[i]))) { - status->status = - InvalidArgument("Unparseable TensorShapeProto at index ", i); - return; - } - } - desc->node_builder.Attr(attr_name, shapes); - status->status = Status::OK(); -} - -void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, - TF_Tensor* value, TF_Status* status) { - Tensor t; - status->status = TF_TensorToTensor(value, &t); - if (status->status.ok()) desc->node_builder.Attr(attr_name, t); -} - -void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, - TF_Tensor* const* values, int num_values, - TF_Status* status) { - status->status = Status::OK(); - std::vector t; - t.reserve(num_values); - - for (int i = 0; i < num_values && status->status.ok(); ++i) { - Tensor v; - status->status = TF_TensorToTensor(values[i], &v); - t.emplace_back(v); - } - - if (status->status.ok()) desc->node_builder.Attr(attr_name, t); -} - -void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, - const void* proto, size_t proto_len, - TF_Status* status) { - tensorflow::AttrValue attr_value; - if (!attr_value.ParseFromArray(proto, proto_len)) { - status->status = InvalidArgument("Unparseable AttrValue proto"); - return; - } - - if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { - if (attr_value.value_case() != tensorflow::AttrValue::kList && - attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) { - status->status = - InvalidArgument("Expected \"list\" field for \"", - tensorflow::kColocationAttrName, "\" attribute"); - return; - } - desc->colocation_constraints.clear(); - for (const string& location : attr_value.list().s()) { - desc->colocation_constraints.insert(location); - } - } else { - desc->node_builder.Attr(attr_name, std::move(attr_value)); - } - - status->status = Status::OK(); -} - -static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, - TF_Status* status) - TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { - Node* ret = nullptr; - - if (desc->graph->name_map.count(desc->node_builder.node_name())) { - status->status = InvalidArgument("Duplicate node name in graph: '", - desc->node_builder.node_name(), "'"); - } else { - if (!desc->colocation_constraints.empty()) { - desc->node_builder.Attr( - tensorflow::kColocationAttrName, - std::vector(desc->colocation_constraints.begin(), - desc->colocation_constraints.end())); - } - status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret, - /*consume=*/true); - - if (status->status.ok()) { - // Run shape 
inference function for newly added node. - status->status = desc->graph->refiner.AddNode(ret); - } - if (status->status.ok()) { - // Add the node to the name-to-node mapping. - desc->graph->name_map[ret->name()] = ret; - } else if (ret != nullptr) { - desc->graph->graph.RemoveNode(ret); - ret = nullptr; - } - } - - delete desc; - - return ToOperation(ret); -} - -TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, - TF_Status* status) { - mutex_lock l(desc->graph->mu); - return TF_FinishOperationLocked(desc, status); -} - -// TF_Operation functions -// ---------------------------------------------------------- - -const char* TF_OperationName(TF_Operation* oper) { - return oper->node.name().c_str(); -} - -const char* TF_OperationOpType(TF_Operation* oper) { - return oper->node.type_string().c_str(); -} - -const char* TF_OperationDevice(TF_Operation* oper) { - return oper->node.requested_device().c_str(); -} - -int TF_OperationNumOutputs(TF_Operation* oper) { - return oper->node.num_outputs(); -} - -TF_DataType TF_OperationOutputType(TF_Output oper_out) { - return static_cast( - oper_out.oper->node.output_type(oper_out.index)); -} - -int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, - TF_Status* status) { - NameRangeMap name_ranges; - status->status = - NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); - if (!status->status.ok()) return -1; - auto iter = name_ranges.find(arg_name); - if (iter == name_ranges.end()) { - status->status = InvalidArgument("Output arg '", arg_name, "' not found"); - return -1; - } - return iter->second.second - iter->second.first; -} - -int TF_OperationNumInputs(TF_Operation* oper) { - return oper->node.num_inputs(); -} - -TF_DataType TF_OperationInputType(TF_Input oper_in) { - return static_cast(oper_in.oper->node.input_type(oper_in.index)); -} - -int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, - TF_Status* status) { - NameRangeMap name_ranges; - status->status = - NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); - if (!status->status.ok()) return -1; - auto iter = name_ranges.find(arg_name); - if (iter == name_ranges.end()) { - status->status = InvalidArgument("Input arg '", arg_name, "' not found"); - return -1; - } - return iter->second.second - iter->second.first; -} - -TF_Output TF_OperationInput(TF_Input oper_in) { - const tensorflow::Edge* edge; - Status s = oper_in.oper->node.input_edge(oper_in.index, &edge); - if (!s.ok()) { - return {nullptr, -1}; - } - - return {ToOperation(edge->src()), edge->src_output()}; -} - -void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs, - int max_inputs) { - for (auto* edge : oper->node.in_edges()) { - if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) { - inputs[edge->dst_input()] = {ToOperation(edge->src()), - edge->src_output()}; - } - } -} - -int TF_OperationOutputNumConsumers(TF_Output oper_out) { - int count = 0; - for (const auto* edge : oper_out.oper->node.out_edges()) { - if (edge->src_output() == oper_out.index) { - ++count; - } - } - return count; -} - -int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, - int max_consumers) { - int count = 0; - for (const auto* edge : oper_out.oper->node.out_edges()) { - if (edge->src_output() == oper_out.index) { - if (count < max_consumers) { - consumers[count] = {ToOperation(edge->dst()), edge->dst_input()}; - } - ++count; - } - } - return count; -} - -int TF_OperationNumControlInputs(TF_Operation* oper) { - int count = 
0; - for (const auto* edge : oper->node.in_edges()) { - if (edge->IsControlEdge() && !edge->src()->IsSource()) { - ++count; - } - } - return count; -} - -int TF_OperationGetControlInputs(TF_Operation* oper, - TF_Operation** control_inputs, - int max_control_inputs) { - int count = 0; - for (const auto* edge : oper->node.in_edges()) { - if (edge->IsControlEdge() && !edge->src()->IsSource()) { - if (count < max_control_inputs) { - control_inputs[count] = ToOperation(edge->src()); - } - ++count; - } - } - return count; -} - -int TF_OperationNumControlOutputs(TF_Operation* oper) { - int count = 0; - for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge() && !edge->dst()->IsSink()) { - ++count; - } - } - return count; -} - -int TF_OperationGetControlOutputs(TF_Operation* oper, - TF_Operation** control_outputs, - int max_control_outputs) { - int count = 0; - for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge() && !edge->dst()->IsSink()) { - if (count < max_control_outputs) { - control_outputs[count] = ToOperation(edge->dst()); - } - ++count; - } - } - return count; -} - -TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, - const char* attr_name, - TF_Status* status) { - TF_AttrMetadata metadata; - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return metadata; - switch (attr->value_case()) { -#define SINGLE_CASE(kK, attr_type, size_expr) \ - case tensorflow::AttrValue::kK: \ - metadata.is_list = 0; \ - metadata.list_size = -1; \ - metadata.type = attr_type; \ - metadata.total_size = size_expr; \ - break; - - SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); - SINGLE_CASE(kI, TF_ATTR_INT, -1); - SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); - SINGLE_CASE(kB, TF_ATTR_BOOL, -1); - SINGLE_CASE(kType, TF_ATTR_TYPE, -1); - SINGLE_CASE(kShape, TF_ATTR_SHAPE, - attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); - SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); -#undef SINGLE_CASE - - case tensorflow::AttrValue::kList: - metadata.is_list = 1; - metadata.list_size = 0; - metadata.total_size = -1; -#define LIST_CASE(field, attr_type, ...) \ - if (attr->list().field##_size() > 0) { \ - metadata.type = attr_type; \ - metadata.list_size = attr->list().field##_size(); \ - __VA_ARGS__; \ - break; \ - } - - LIST_CASE( - s, TF_ATTR_STRING, metadata.total_size = 0; - for (int i = 0; i < attr->list().s_size(); - ++i) { metadata.total_size += attr->list().s(i).size(); }); - LIST_CASE(i, TF_ATTR_INT); - LIST_CASE(f, TF_ATTR_FLOAT); - LIST_CASE(b, TF_ATTR_BOOL); - LIST_CASE(type, TF_ATTR_TYPE); - LIST_CASE( - shape, TF_ATTR_SHAPE, metadata.total_size = 0; - for (int i = 0; i < attr->list().shape_size(); ++i) { - const auto& s = attr->list().shape(i); - metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); - }); - LIST_CASE(tensor, TF_ATTR_TENSOR); - LIST_CASE(tensor, TF_ATTR_FUNC); -#undef LIST_CASE - // All lists empty, determine the type from the OpDef. 
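As a rough illustration of how the metadata computed above is consumed, here is a small helper (a sketch, not part of this change) that prints whether an attribute holds a single value or a list:

#include <stdio.h>
#include "tensorflow/c/c_api.h"

// Prints a one-line summary of an attribute using TF_OperationGetAttrMetadata.
static void DescribeAttr(TF_Operation* op, const char* attr_name) {
  TF_Status* status = TF_NewStatus();
  TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, attr_name, status);
  if (TF_GetCode(status) == TF_OK) {
    if (m.is_list) {
      printf("%s: list of %lld elements, element type %d\n", attr_name,
             (long long)m.list_size, (int)m.type);
    } else {
      printf("%s: single value, type %d, total_size %lld\n", attr_name,
             (int)m.type, (long long)m.total_size);
    }
  } else {
    fprintf(stderr, "%s: %s\n", attr_name, TF_Message(status));
  }
  TF_DeleteStatus(status);
}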
- if (metadata.list_size == 0) { - for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { - const auto& a = oper->node.op_def().attr(i); - if (a.name() != attr_name) continue; - const string& typestr = a.type(); - if (typestr == "list(string)") { - metadata.type = TF_ATTR_STRING; - } else if (typestr == "list(int)") { - metadata.type = TF_ATTR_INT; - } else if (typestr == "list(float)") { - metadata.type = TF_ATTR_FLOAT; - } else if (typestr == "list(bool)") { - metadata.type = TF_ATTR_BOOL; - } else if (typestr == "list(type)") { - metadata.type = TF_ATTR_TYPE; - } else if (typestr == "list(shape)") { - metadata.type = TF_ATTR_SHAPE; - } else if (typestr == "list(tensor)") { - metadata.type = TF_ATTR_TENSOR; - } else if (typestr == "list(func)") { - metadata.type = TF_ATTR_FUNC; - } else { - status->status = InvalidArgument( - "Attribute '", attr_name, - "' has an empty value of an unrecognized type '", typestr, "'"); - return metadata; - } - } - } - break; - - case tensorflow::AttrValue::kPlaceholder: - metadata.is_list = 0; - metadata.list_size = -1; - metadata.type = TF_ATTR_PLACEHOLDER; - metadata.total_size = -1; - break; - - case tensorflow::AttrValue::kFunc: - metadata.is_list = 0; - metadata.list_size = -1; - metadata.type = TF_ATTR_FUNC; - metadata.total_size = -1; - break; - - case tensorflow::AttrValue::VALUE_NOT_SET: - status->status = - InvalidArgument("Attribute '", attr_name, "' has no value set"); - break; - } - return metadata; -} - -void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, - void* value, size_t max_length, - TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - if (attr->value_case() != tensorflow::AttrValue::kS) { - status->status = - InvalidArgument("Attribute '", attr_name, "' is not a string"); - return; - } - if (max_length <= 0) { - return; - } - const auto& s = attr->s(); - std::memcpy(value, s.data(), std::min(s.length(), max_length)); -} - -void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, - void** values, size_t* lengths, - int max_values, void* storage, - size_t storage_size, TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - if (attr->value_case() != tensorflow::AttrValue::kList) { - status->status = - InvalidArgument("Value for '", attr_name, "' is not a list"); - return; - } - const auto len = std::min(max_values, attr->list().s_size()); - char* p = static_cast(storage); - for (int i = 0; i < len; ++i) { - const string& s = attr->list().s(i); - values[i] = p; - lengths[i] = s.size(); - if ((p + s.size()) > (static_cast(storage) + storage_size)) { - status->status = InvalidArgument( - "Not enough storage to hold the requested list of strings"); - return; - } - memcpy(values[i], s.data(), s.size()); - p += s.size(); - } -} - -#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ - void func(TF_Operation* oper, const char* attr_name, c_type* value, \ - TF_Status* status) { \ - cpp_type v; \ - status->status = \ - tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \ - *value = static_cast(v); \ - } \ - void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ - int max_values, TF_Status* status) { \ - const auto* attr = GetAttrValue(oper, attr_name, status); \ - if (!status->status.ok()) return; \ - if (attr->value_case() != tensorflow::AttrValue::kList) { \ - status->status = \ - InvalidArgument("Value for '", attr_name, "' is not 
a list."); \ - return; \ - } \ - const auto len = std::min(max_values, attr->list().list_field##_size()); \ - for (int i = 0; i < len; ++i) { \ - values[i] = static_cast(attr->list().list_field(i)); \ - } \ - } -DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); -DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); -DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b); -DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); -#undef DEFINE_GETATTR - -void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, - int64_t* value, int num_dims, TF_Status* status) { - PartialTensorShape shape; - status->status = - tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); - if (!status->status.ok()) return; - auto len = std::min(shape.dims(), num_dims); - for (int i = 0; i < len; ++i) { - value[i] = shape.dim_size(i); - } -} - -void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, - int64_t** dims, int* num_dims, int num_shapes, - int64_t* storage, int storage_size, - TF_Status* status) { - std::vector shapes; - status->status = - tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); - if (!status->status.ok()) return; - auto len = std::min(static_cast(shapes.size()), num_shapes); - int64_t* p = storage; - int storage_left = storage_size; - for (int i = 0; i < len; ++i) { - // shapes[i].dims() == -1 for shapes with an unknown rank. - int64_t n = shapes[i].dims(); - num_dims[i] = n; - dims[i] = p; - if (n < 0) { - continue; - } - if (storage_left < n) { - status->status = InvalidArgument( - "Not enough storage to hold the requested list of shapes"); - return; - } - storage_left -= n; - for (int j = 0; j < n; ++j, ++p) { - *p = shapes[i].dim_size(j); - } - } -} - -void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, - const char* attr_name, - TF_Buffer* value, TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - if (attr->value_case() != tensorflow::AttrValue::kShape) { - status->status = - InvalidArgument("Value for '", attr_name, "' is not a shape."); - return; - } - status->status = MessageToBuffer(attr->shape(), value); -} - -void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, - const char* attr_name, - TF_Buffer** values, int max_values, - TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - if (attr->value_case() != tensorflow::AttrValue::kList) { - status->status = - InvalidArgument("Value for '", attr_name, "' is not a list"); - return; - } - const auto len = std::min(max_values, attr->list().shape_size()); - for (int i = 0; i < len; ++i) { - values[i] = TF_NewBuffer(); - status->status = MessageToBuffer(attr->list().shape(i), values[i]); - if (!status->status.ok()) { - // Delete everything allocated to far, the operation has failed. 
- for (int j = 0; j <= i; ++j) { - TF_DeleteBuffer(values[j]); - } - return; - } - } -} - -void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, - TF_Tensor** value, TF_Status* status) { - *value = nullptr; - Tensor t; - status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); - if (!status->status.ok()) return; - *value = TF_TensorFromTensor(t, &status->status); -} - -void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, - TF_Tensor** values, int max_values, - TF_Status* status) { - std::vector ts; - status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); - if (!status->status.ok()) return; - const auto len = std::min(max_values, static_cast(ts.size())); - for (int i = 0; i < len; ++i) { - values[i] = TF_TensorFromTensor(ts[i], &status->status); - } -} - -void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, - TF_Buffer* output_attr_value, - TF_Status* status) { - const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; - status->status = MessageToBuffer(*attr, output_attr_value); -} - -void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, - TF_Status* status) { - status->status = MessageToBuffer(oper->node.def(), output_node_def); -} - -// TF_Graph functions --------------------------------------------------------- - -TF_Graph::TF_Graph() - : graph(tensorflow::OpRegistry::Global()), - refiner(graph.versions().producer(), graph.op_registry()), - delete_requested(false), - parent(nullptr), - parent_inputs(nullptr) { - // Tell the shape refiner to also run shape inference on functions. - refiner.set_function_library_for_shape_inference(&graph.flib_def()); -} - -TF_Graph* TF_NewGraph() { return new TF_Graph; } - -void TF_DeleteGraph(TF_Graph* g) { - if (g == nullptr) return; - g->mu.lock(); - g->delete_requested = true; - const bool del = g->sessions.empty(); - g->mu.unlock(); - if (del) delete g; -} - -TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) { - mutex_lock l(graph->mu); - auto iter = graph->name_map.find(oper_name); - if (iter == graph->name_map.end()) { - return nullptr; - } else { - return ToOperation(iter->second); - } -} - -TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) { - if (*pos == 0) { - // Advance past the first sentinel nodes in every graph (the source & sink). - *pos += 2; - } else { - // Advance to the next node. - *pos += 1; - } - - mutex_lock l(graph->mu); - while (*pos < static_cast(graph->graph.num_node_ids())) { - Node* node = graph->graph.FindNodeId(*pos); - // FindNodeId() returns nullptr for nodes that have been deleted. - // We aren't currently allowing nodes to be deleted, but it is safer - // to still check. - if (node != nullptr) return ToOperation(node); - *pos += 1; - } - - // No more nodes. 
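To show how TF_GraphNextOperation and TF_GraphOperationByName are typically used together, a small sketch of assumed usage (the node name "input" is hypothetical):

#include <stdio.h>
#include "tensorflow/c/c_api.h"

// Lists every operation in `graph`, then looks one up by name.
static void DumpGraph(TF_Graph* graph) {
  size_t pos = 0;
  TF_Operation* op;
  while ((op = TF_GraphNextOperation(graph, &pos)) != nullptr) {
    printf("%-30s %s\n", TF_OperationName(op), TF_OperationOpType(op));
  }

  // Returns nullptr when no node with that name exists in the graph.
  TF_Operation* input = TF_GraphOperationByName(graph, "input");
  printf("lookup of 'input': %s\n", input ? "found" : "not found");
}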
- return nullptr; -} - -void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, - TF_Status* status) { - GraphDef def; - { - mutex_lock l(graph->mu); - graph->graph.ToGraphDef(&def); - } - status->status = MessageToBuffer(def, output_graph_def); -} - -void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, - TF_Buffer* output_op_def, TF_Status* status) { - const OpDef* op_def; - { - mutex_lock l(graph->mu); - status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); - if (!status->status.ok()) return; - } - status->status = MessageToBuffer(*op_def, output_op_def); -} - -void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def, - TF_Status* status) { - VersionDef versions; - { - mutex_lock l(graph->mu); - versions = graph->graph.versions(); - } - status->status = MessageToBuffer(versions, output_version_def); -} - -TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { - return new TF_ImportGraphDefOptions; -} -void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) { - delete opts; -} -void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, - const char* prefix) { - opts->opts.prefix = prefix; -} -void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts, - const char* device) { - opts->opts.default_device = device; -} - -void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, - unsigned char uniquify_names) { - opts->opts.uniquify_names = uniquify_names; -} - -void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, - unsigned char uniquify_prefix) { - opts->opts.uniquify_prefix = uniquify_prefix; -} - -void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, - const char* src_name, - int src_index, TF_Output dst) { - opts->tensor_id_data.push_back(src_name); - const string& src_name_str = opts->tensor_id_data.back(); - // We don't need to store dst's name in tensor_id_data, since `dst` must - // outlive the ImportGraphDef call. 
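A brief sketch of the export path implemented by TF_GraphToGraphDef: serialize the graph into a caller-owned TF_Buffer and release it afterwards (the error handling shown is illustrative).

#include <stdio.h>
#include "tensorflow/c/c_api.h"

// Serializes `graph` into a freshly allocated TF_Buffer and reports its size.
static void ExportGraphDef(TF_Graph* graph) {
  TF_Status* status = TF_NewStatus();
  TF_Buffer* def = TF_NewBuffer();
  TF_GraphToGraphDef(graph, def, status);
  if (TF_GetCode(status) == TF_OK) {
    printf("GraphDef is %zu bytes\n", def->length);
  } else {
    fprintf(stderr, "export failed: %s\n", TF_Message(status));
  }
  TF_DeleteBuffer(def);
  TF_DeleteStatus(status);
}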
- opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst); -} - -void TF_ImportGraphDefOptionsRemapControlDependency( - TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) { - opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] = - TensorId(dst->node.name(), tensorflow::Graph::kControlSlot); -} - -extern void TF_ImportGraphDefOptionsAddControlDependency( - TF_ImportGraphDefOptions* opts, TF_Operation* oper) { - opts->opts.control_dependencies.push_back(oper->node.name()); -} - -void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts, - const char* oper_name, int index) { - opts->tensor_id_data.push_back(oper_name); - const string& oper_name_str = opts->tensor_id_data.back(); - opts->opts.return_tensors.emplace_back(oper_name_str, index); -} - -int TF_ImportGraphDefOptionsNumReturnOutputs( - const TF_ImportGraphDefOptions* opts) { - return opts->opts.return_tensors.size(); -} - -void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts, - const char* oper_name) { - opts->opts.return_nodes.push_back(oper_name); -} - -int TF_ImportGraphDefOptionsNumReturnOperations( - const TF_ImportGraphDefOptions* opts) { - return opts->opts.return_nodes.size(); -} - -void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results, - int* num_outputs, - TF_Output** outputs) { - *num_outputs = results->return_tensors.size(); - *outputs = results->return_tensors.data(); -} - -void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, - int* num_opers, - TF_Operation*** opers) { - *num_opers = results->return_nodes.size(); - *opers = results->return_nodes.data(); -} - -void TF_ImportGraphDefResultsMissingUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, - const char*** src_names, int** src_indexes) { - *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); - *src_names = results->missing_unused_key_names.data(); - *src_indexes = results->missing_unused_key_indexes.data(); -} - -void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { - delete results; -} - -static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, - const TF_ImportGraphDefOptions* opts, - TF_ImportGraphDefResults* tf_results, - TF_Status* status) - TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - const int last_node_id = graph->graph.num_node_ids(); - tensorflow::ImportGraphDefResults results; - status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, - &graph->refiner, &results); - if (!status->status.ok()) return; - - // Add new nodes to name_map - for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { - auto* node = graph->graph.FindNodeId(i); - if (node != nullptr) graph->name_map[node->name()] = node; - } - - // Populate return_tensors - DCHECK(tf_results->return_tensors.empty()); - tf_results->return_tensors.resize(results.return_tensors.size()); - for (int i = 0; i < results.return_tensors.size(); ++i) { - tf_results->return_tensors[i].oper = - ToOperation(results.return_tensors[i].first); - tf_results->return_tensors[i].index = results.return_tensors[i].second; - } - - // Populate return_nodes - DCHECK(tf_results->return_nodes.empty()); - tf_results->return_nodes.resize(results.return_nodes.size()); - for (int i = 0; i < results.return_nodes.size(); ++i) { - tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); - } - - // Populate missing unused map keys - 
DCHECK(tf_results->missing_unused_key_names.empty()); - DCHECK(tf_results->missing_unused_key_indexes.empty()); - DCHECK(tf_results->missing_unused_key_names_data.empty()); - - size_t size = results.missing_unused_input_map_keys.size(); - tf_results->missing_unused_key_names.resize(size); - tf_results->missing_unused_key_indexes.resize(size); - - for (int i = 0; i < size; ++i) { - TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.emplace_back(id.first); - tf_results->missing_unused_key_names[i] = - tf_results->missing_unused_key_names_data.back().c_str(); - tf_results->missing_unused_key_indexes[i] = id.second; - } -} - -TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, TF_Status* status) { - GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, - graph_def->length)) { - status->status = InvalidArgument("Invalid GraphDef"); - return nullptr; - } - auto results = new TF_ImportGraphDefResults(); - mutex_lock l(graph->mu); - GraphImportGraphDefLocked(graph, def, options, results, status); - if (!status->status.ok()) { - delete results; - return nullptr; - } - return results; -} - -void TF_GraphImportGraphDefWithReturnOutputs( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, - int num_return_outputs, TF_Status* status) { - if (num_return_outputs != options->opts.return_tensors.size()) { - status->status = InvalidArgument("Expected 'num_return_outputs' to be ", - options->opts.return_tensors.size(), - ", got ", num_return_outputs); - return; - } - if (num_return_outputs > 0 && return_outputs == nullptr) { - status->status = InvalidArgument( - "'return_outputs' must be preallocated to length ", num_return_outputs); - return; - } - GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, - graph_def->length)) { - status->status = InvalidArgument("Invalid GraphDef"); - return; - } - TF_ImportGraphDefResults results; - mutex_lock l(graph->mu); - GraphImportGraphDefLocked(graph, def, options, &results, status); - DCHECK_EQ(results.return_tensors.size(), num_return_outputs); - memcpy(return_outputs, results.return_tensors.data(), - num_return_outputs * sizeof(TF_Output)); -} - -void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, - TF_Status* status) { - TF_ImportGraphDefResults* results = - TF_GraphImportGraphDefWithResults(graph, graph_def, options, status); - TF_DeleteImportGraphDefResults(results); -} - -// TF_Session functions ---------------------------------------------- - -TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g) - : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {} - -TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, - TF_Status* status) { - Session* session; - status->status = NewSession(opt->options, &session); - if (status->status.ok()) { - TF_Session* new_session = new TF_Session(session, graph); - if (graph != nullptr) { - mutex_lock l(graph->mu); - graph->sessions[new_session] = ""; - } - return new_session; - } else { - DCHECK_EQ(nullptr, session); - return nullptr; - } -} - -TF_Session* TF_LoadSessionFromSavedModel( - const TF_SessionOptions* session_options, const TF_Buffer* run_options, - const char* export_dir, const char* const* tags, int tags_len, - TF_Graph* graph, TF_Buffer* meta_graph_def, 
TF_Status* status) { -// TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring -// that the tensorflow/cc/saved_model:loader build target is mobile friendly. -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Loading a SavedModel is not supported on mobile. File a bug at " - "https://github.com/tensorflow/tensorflow/issues if this feature is " - "important to you"); - return nullptr; -#else - mutex_lock l(graph->mu); - if (!graph->name_map.empty()) { - status->status = InvalidArgument("Graph is non-empty."); - return nullptr; - } - - RunOptions run_options_proto; - if (run_options != nullptr && !run_options_proto.ParseFromArray( - run_options->data, run_options->length)) { - status->status = InvalidArgument("Unparseable RunOptions proto"); - return nullptr; - } - - std::unordered_set tag_set; - for (int i = 0; i < tags_len; i++) { - tag_set.insert(string(tags[i])); - } - - tensorflow::SavedModelBundle bundle; - status->status = - tensorflow::LoadSavedModel(session_options->options, run_options_proto, - export_dir, tag_set, &bundle); - if (!status->status.ok()) return nullptr; - - // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session - // extends using GraphDefs. The Graph instance is different, but equivalent - // to the one used to create the session. - // - // TODO(jhseu): When Session is modified to take Graphs instead of - // GraphDefs, return the Graph generated in LoadSavedModel(). - TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions(); - TF_ImportGraphDefResults results; - GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), - import_opts, &results, status); - TF_DeleteImportGraphDefOptions(import_opts); - if (!status->status.ok()) return nullptr; - - if (meta_graph_def != nullptr) { - status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def); - if (!status->status.ok()) return nullptr; - } - - TF_Session* session = new TF_Session(bundle.session.release(), graph); - - graph->sessions[session] = ""; - session->last_num_graph_nodes = graph->graph.num_node_ids(); - return session; -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -void TF_CloseSession(TF_Session* s, TF_Status* status) { - status->status = s->session->Close(); -} - -void TF_DeleteSession(TF_Session* s, TF_Status* status) { - status->status = Status::OK(); - if (s == nullptr) return; - TF_Graph* const graph = s->graph; - if (graph != nullptr) { - graph->mu.lock(); - graph->sessions.erase(s); - const bool del = graph->delete_requested && graph->sessions.empty(); - graph->mu.unlock(); - if (del) delete graph; - } - delete s->session; - delete s; -} - -void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, - const TF_Output* inputs, TF_Tensor* const* input_values, - int ninputs, const TF_Output* outputs, - TF_Tensor** output_values, int noutputs, - const TF_Operation* const* target_opers, int ntargets, - TF_Buffer* run_metadata, TF_Status* status) { - // TODO(josh11b,mrry): Change Session to be able to use a Graph* - // directly, instead of requiring us to serialize to a GraphDef and - // call Session::Extend(). - if (session->extend_before_run && - !ExtendSessionGraphHelper(session, status)) { - return; - } - - TF_Run_Setup(noutputs, output_values, status); - - // Convert from TF_Output and TF_Tensor to a string and Tensor. 
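For orientation, a sketch of the import entry points above: a serialized GraphDef held in a TF_Buffer is imported under a name prefix. The buffer is assumed to have been produced elsewhere (for example by TF_GraphToGraphDef); the prefix "sub" is illustrative.

#include <stdio.h>
#include "tensorflow/c/c_api.h"

// Imports `graph_def` into `graph`, prefixing every imported node with "sub/".
static void ImportWithPrefix(TF_Graph* graph, const TF_Buffer* graph_def) {
  TF_Status* status = TF_NewStatus();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "sub");
  TF_GraphImportGraphDef(graph, graph_def, opts, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "import failed: %s\n", TF_Message(status));
  }
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteStatus(status);
}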
- std::vector> input_pairs(ninputs); - if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; - for (int i = 0; i < ninputs; ++i) { - input_pairs[i].first = OutputName(inputs[i]); - } - - // Convert from TF_Output to string names. - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = OutputName(outputs[i]); - } - - // Convert from TF_Operation* to string names. - std::vector target_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_names[i] = target_opers[i]->node.name(); - } - - // Actually run. - TF_Run_Helper(session->session, nullptr, run_options, input_pairs, - output_names, output_values, target_names, run_metadata, - status); -} - -void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, - int ninputs, const TF_Output* outputs, int noutputs, - const TF_Operation* const* target_opers, int ntargets, - const char** handle, TF_Status* status) { - *handle = nullptr; - - if (session->extend_before_run && - !ExtendSessionGraphHelper(session, status)) { - return; - } - - std::vector input_names(ninputs); - for (int i = 0; i < ninputs; ++i) { - input_names[i] = OutputName(inputs[i]); - } - - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = OutputName(outputs[i]); - } - - std::vector target_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_names[i] = target_opers[i]->node.name(); - } - - string new_handle; - status->status = session->session->PRunSetup(input_names, output_names, - target_names, &new_handle); - if (status->status.ok()) { - char* buf = new char[new_handle.size() + 1]; - memcpy(buf, new_handle.c_str(), new_handle.size() + 1); - *handle = buf; - } -} - -void TF_DeletePRunHandle(const char* handle) { - delete[] handle; - // TODO(suharshs): Free up any resources held by the partial run state. -} - -void TF_SessionPRun(TF_Session* session, const char* handle, - const TF_Output* inputs, TF_Tensor* const* input_values, - int ninputs, const TF_Output* outputs, - TF_Tensor** output_values, int noutputs, - const TF_Operation* const* target_opers, int ntargets, - TF_Status* status) { - // TODO(josh11b,mrry): Change Session to be able to use a Graph* - // directly, instead of requiring us to serialize to a GraphDef and - // call Session::Extend(). - if (session->extend_before_run && - !ExtendSessionGraphHelper(session, status)) { - return; - } - - TF_Run_Setup(noutputs, output_values, status); - - // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector> input_pairs(ninputs); - if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; - for (int i = 0; i < ninputs; ++i) { - input_pairs[i].first = OutputName(inputs[i]); - } - - // Convert from TF_Output to string names. - std::vector output_names(noutputs); - for (int i = 0; i < noutputs; ++i) { - output_names[i] = OutputName(outputs[i]); - } - - // Convert from TF_Operation* to string names. 
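The name/tensor conversions above feed Session::Run; a minimal caller-side sketch of TF_SessionRun (assuming `session` and a fetchable operation already exist) looks roughly like this:

#include <stdio.h>
#include "tensorflow/c/c_api.h"

// Fetches output 0 of `fetch_op` with no feeds and no explicit targets.
static TF_Tensor* RunSingleFetch(TF_Session* session, TF_Operation* fetch_op) {
  TF_Status* status = TF_NewStatus();
  TF_Output fetch = {fetch_op, 0};
  TF_Tensor* out = nullptr;
  TF_SessionRun(session, /*run_options=*/nullptr,
                /*inputs=*/nullptr, /*input_values=*/nullptr, 0,
                &fetch, &out, 1,
                /*target_opers=*/nullptr, 0,
                /*run_metadata=*/nullptr, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "run failed: %s\n", TF_Message(status));
  }
  TF_DeleteStatus(status);
  return out;  // nullptr on failure; caller must TF_DeleteTensor on success.
}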
- std::vector target_names(ntargets); - for (int i = 0; i < ntargets; ++i) { - target_names[i] = target_opers[i]->node.name(); - } - - TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names, - output_values, target_names, nullptr, status); -} - -unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, - TF_Tensor** result, TF_Status* status) { - *result = nullptr; - mutex_lock l(graph->mu); - OutputTensor tensor(&output.oper->node, output.index); - bool evaluated; - Tensor result_tensor; - status->status = EvaluateConstantTensor( - tensor, graph->refiner, *graph->graph.op_registry(), - graph->graph.versions().producer(), &evaluated, &result_tensor); - if (evaluated) { - DCHECK(status->status.ok()); - *result = TF_TensorFromTensor(result_tensor, &status->status); - if (!status->status.ok()) evaluated = false; - } - return evaluated; -} - -TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { - tensorflow::OpList op_list; - if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { - status->status = InvalidArgument("Unparseable OpList"); - return nullptr; - } - status->status = Status::OK(); - return new TF_ApiDefMap(op_list); -} - -void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } - -void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, - size_t text_len, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "ApiDefMap is not supported on mobile."); -#else - mutex_lock l(api_def_map->lock); - if (api_def_map->update_docs_called) { - status->status = FailedPrecondition( - "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " - "called."); - return; - } - string api_def_text(text, text_len); - status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, - size_t name_len, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "ApiDefMap is not supported on mobile."); - return nullptr; -#else - mutex_lock l(api_def_map->lock); - if (!api_def_map->update_docs_called) { - api_def_map->api_def_map.UpdateDocs(); - api_def_map->update_docs_called = true; - } - string name_str(name, name_len); - const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); - if (api_def == nullptr) { - return nullptr; - } - - TF_Buffer* ret = TF_NewBuffer(); - status->status = MessageToBuffer(*api_def, ret); - if (!status->status.ok()) { - TF_DeleteBuffer(ret); - return nullptr; - } - return ret; -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { - tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); - TF_Buffer* ret = TF_NewBuffer(); - status->status = MessageToBuffer(kernel_list, ret); - if (!status->status.ok()) { - TF_DeleteBuffer(ret); - return nullptr; - } - return ret; -} - -TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { - tensorflow::KernelList kernel_list = - tensorflow::GetRegisteredKernelsForOp(name); - TF_Buffer* ret = TF_NewBuffer(); - status->status = MessageToBuffer(kernel_list, ret); - if (!status->status.ok()) { - TF_DeleteBuffer(ret); - return nullptr; - } - return ret; -} - -// TF_Server functions 
---------------------------------------------- - -#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -TF_Server::TF_Server(std::unique_ptr server) - : target(server->target()), server(std::move(server)) {} -#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) - -TF_Server* TF_NewServer(const void* proto, size_t proto_len, - TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported on mobile"); - return nullptr; -#else - tensorflow::ServerDef server_def; - if (!server_def.ParseFromArray(proto, static_cast(proto_len))) { - status->status = InvalidArgument( - "Could not parse provided bytes into a ServerDef protocol buffer"); - return nullptr; - } - - std::unique_ptr out_server; - status->status = tensorflow::NewServer(server_def, &out_server); - if (!status->status.ok()) return nullptr; - - return new TF_Server(std::move(out_server)); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -void TF_ServerStart(TF_Server* server, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported on mobile"); -#else - status->status = server->server->Start(); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -void TF_ServerStop(TF_Server* server, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported on mobile"); -#else - status->status = server->server->Stop(); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -void TF_ServerJoin(TF_Server* server, TF_Status* status) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported on mobile"); -#else - status->status = server->server->Join(); -#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) -} - -const char* TF_ServerTarget(TF_Server* server) { -#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) - return nullptr; -#else - return server->target.c_str(); -#endif -} - -void TF_DeleteServer(TF_Server* server) { -#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) - delete server; -#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -} - -void TF_RegisterLogListener(void (*listener)(const char*)) { -#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) - tensorflow::logging::RegisterListener(listener); -#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -} - -} // end extern "C" diff --git a/tensorflow/c/c_core_api.h b/tensorflow/c/c_core_api.h deleted file mode 100644 index d3b5447b717..00000000000 --- a/tensorflow/c/c_core_api.h +++ /dev/null @@ -1,1456 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_C_C_CORE_API_H_ -#define TENSORFLOW_C_C_CORE_API_H_ - -#include -#include - -#include "tensorflow/c/tf_attrtype.h" -#include "tensorflow/c/tf_datatype.h" -#include "tensorflow/c/tf_status.h" -#include "tensorflow/c/tf_tensor.h" - -// -------------------------------------------------------------------------- -// C API for TensorFlow. -// -// The API leans towards simplicity and uniformity instead of convenience -// since most usage will be by language specific wrappers. -// -// Conventions: -// * We use the prefix TF_ for everything in the API. -// * Objects are always passed around as pointers to opaque structs -// and these structs are allocated/deallocated via the API. -// * TF_Status holds error information. It is an object type -// and therefore is passed around as a pointer to an opaque -// struct as mentioned above. -// * Every call that has a TF_Status* argument clears it on success -// and fills it with error info on failure. -// * unsigned char is used for booleans (instead of the 'bool' type). -// In C++ bool is a keyword while in C99 bool is a macro defined -// in stdbool.h. It is possible for the two to be inconsistent. -// For example, neither the C99 nor the C++11 standard force a byte -// size on the bool type, so the macro defined in stdbool.h could -// be inconsistent with the bool keyword in C++. Thus, the use -// of stdbool.h is avoided and unsigned char is used instead. -// * size_t is used to represent byte sizes of objects that are -// materialized in the address space of the calling process. -// * int is used as an index into arrays. -// * Deletion functions are safe to call on nullptr. -// -// Questions left to address: -// * Might at some point need a way for callers to provide their own Env. -// * Maybe add TF_TensorShape that encapsulates dimension info. -// -// Design decisions made: -// * Backing store for tensor memory has an associated deallocation -// function. This deallocation function will point to client code -// for tensors populated by the client. So the client can do things -// like shadowing a numpy array. -// * We do not provide TF_OK since it is not strictly necessary and we -// are not optimizing for convenience. -// * We make assumption that one session has one graph. This should be -// fine since we have the ability to run sub-graphs. -// * We could allow NULL for some arguments (e.g., NULL options arg). -// However since convenience is not a primary goal, we don't do this. -// * Devices are not in this API. Instead, they are created/used internally -// and the API just provides high level controls over the number of -// devices of each type. - -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - -#ifdef __cplusplus -extern "C" { -#endif - -// -------------------------------------------------------------------------- -// TF_Version returns a string describing version information of the -// TensorFlow library. 
TensorFlow using semantic versioning. -TF_CAPI_EXPORT extern const char* TF_Version(void); - -// -------------------------------------------------------------------------- -// TF_Buffer holds a pointer to a block of data and its associated length. -// Typically, the data consists of a serialized protocol buffer, but other data -// may also be held in a buffer. -// -// By default, TF_Buffer itself does not do any memory management of the -// pointed-to block. If need be, users of this struct should specify how to -// deallocate the block by setting the `data_deallocator` function pointer. -typedef struct TF_Buffer { - const void* data; - size_t length; - void (*data_deallocator)(void* data, size_t length); -} TF_Buffer; - -// Makes a copy of the input and sets an appropriate deallocator. Useful for -// passing in read-only, input protobufs. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, - size_t proto_len); - -// Useful for passing *out* a protobuf. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); - -TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); - -TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); - -// -------------------------------------------------------------------------- -// TF_SessionOptions holds options that can be passed during session creation. -typedef struct TF_SessionOptions TF_SessionOptions; - -// Return a new options object. -TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); - -// Set the target in TF_SessionOptions.options. -// target can be empty, a single entry, or a comma separated list of entries. -// Each entry is in one of the following formats : -// "local" -// ip:port -// host:port -TF_CAPI_EXPORT extern void TF_SetTarget(TF_SessionOptions* options, - const char* target); - -// Set the config in TF_SessionOptions.options. -// config should be a serialized tensorflow.ConfigProto proto. -// If config was not parsed successfully as a ConfigProto, record the -// error information in *status. -TF_CAPI_EXPORT extern void TF_SetConfig(TF_SessionOptions* options, - const void* proto, size_t proto_len, - TF_Status* status); - -// Destroy an options object. -TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); - -// TODO(jeff,sanjay): -// - export functions to set Config fields - -// -------------------------------------------------------------------------- -// The new graph construction API, still under development. - -// Represents a computation graph. Graphs may be shared between sessions. -// Graphs are thread-safe when used as directed below. -typedef struct TF_Graph TF_Graph; - -// Return a new graph object. -TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); - -// Destroy an options object. Graph will be deleted once no more -// TFSession's are referencing it. -TF_CAPI_EXPORT extern void TF_DeleteGraph(TF_Graph*); - -// Operation being built. The underlying graph must outlive this. -typedef struct TF_OperationDescription TF_OperationDescription; - -// Operation that has been added to the graph. Valid until the graph is -// deleted -- in particular adding a new operation to the graph does not -// invalidate old TF_Operation* pointers. -typedef struct TF_Operation TF_Operation; - -// Represents a specific input of an operation. -typedef struct TF_Input { - TF_Operation* oper; - int index; // The index of the input within oper. -} TF_Input; - -// Represents a specific output of an operation. 
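As a usage sketch for the option objects declared above (the "local" target follows the formats listed in the TF_SetTarget comment; the rest is an assumption about typical session setup, not taken from this patch):

#include <stdio.h>
#include "tensorflow/c/c_api.h"

int main() {
  printf("TensorFlow C library version: %s\n", TF_Version());

  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_SetTarget(opts, "local");

  TF_Session* session = TF_NewSession(graph, opts, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_CloseSession(session, status);
    TF_DeleteSession(session, status);
  } else {
    fprintf(stderr, "NewSession failed: %s\n", TF_Message(status));
  }

  TF_DeleteSessionOptions(opts);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
  return 0;
}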
-typedef struct TF_Output { - TF_Operation* oper; - int index; // The index of the output within oper. -} TF_Output; - -// TF_Function is a grouping of operations with defined inputs and outputs. -// Once created and added to graphs, functions can be invoked by creating an -// operation whose operation type matches the function name. -typedef struct TF_Function TF_Function; - -// Function definition options. TODO(iga): Define and implement -typedef struct TF_FunctionOptions TF_FunctionOptions; - -// Sets the shape of the Tensor referenced by `output` in `graph` to -// the shape described by `dims` and `num_dims`. -// -// If the number of dimensions is unknown, `num_dims` must be set to -// -1 and `dims` can be null. If a dimension is unknown, the -// corresponding entry in the `dims` array must be -1. -// -// This does not overwrite the existing shape associated with `output`, -// but merges the input shape with the existing shape. For example, -// setting a shape of [-1, 2] with an existing shape [2, -1] would set -// a final shape of [2, 2] based on shape merging semantics. -// -// Returns an error into `status` if: -// * `output` is not in `graph`. -// * An invalid shape is being set (e.g., the shape being set -// is incompatible with the existing shape). -TF_CAPI_EXPORT extern void TF_GraphSetTensorShape(TF_Graph* graph, - TF_Output output, - const int64_t* dims, - const int num_dims, - TF_Status* status); - -// Returns the number of dimensions of the Tensor referenced by `output` -// in `graph`. -// -// If the number of dimensions in the shape is unknown, returns -1. -// -// Returns an error into `status` if: -// * `output` is not in `graph`. -TF_CAPI_EXPORT extern int TF_GraphGetTensorNumDims(TF_Graph* graph, - TF_Output output, - TF_Status* status); - -// Returns the shape of the Tensor referenced by `output` in `graph` -// into `dims`. `dims` must be an array large enough to hold `num_dims` -// entries (e.g., the return value of TF_GraphGetTensorNumDims). -// -// If the number of dimensions in the shape is unknown or the shape is -// a scalar, `dims` will remain untouched. Otherwise, each element of -// `dims` will be set corresponding to the size of the dimension. An -// unknown dimension is represented by `-1`. -// -// Returns an error into `status` if: -// * `output` is not in `graph`. -// * `num_dims` does not match the actual number of dimensions. -TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, - TF_Output output, - int64_t* dims, int num_dims, - TF_Status* status); - -// Operation will only be added to *graph when TF_FinishOperation() is -// called (assuming TF_FinishOperation() does not return an error). -// *graph must not be deleted until after TF_FinishOperation() is -// called. -TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperation( - TF_Graph* graph, const char* op_type, const char* oper_name); - -// Specify the device for `desc`. Defaults to empty, meaning unconstrained. -TF_CAPI_EXPORT extern void TF_SetDevice(TF_OperationDescription* desc, - const char* device); - -// The calls to TF_AddInput and TF_AddInputList must match (in number, -// order, and type) the op declaration. For example, the "Concat" op -// has registration: -// REGISTER_OP("Concat") -// .Input("concat_dim: int32") -// .Input("values: N * T") -// .Output("output: T") -// .Attr("N: int >= 2") -// .Attr("T: type"); -// that defines two inputs, "concat_dim" and "values" (in that order). 
-// You must use TF_AddInput() for the first input (since it takes a -// single tensor), and TF_AddInputList() for the second input (since -// it takes a list, even if you were to pass a list with a single -// tensor), as in: -// TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "c"); -// TF_Output concat_dim_input = {...}; -// TF_AddInput(desc, concat_dim_input); -// TF_Output values_inputs[5] = {{...}, ..., {...}}; -// TF_AddInputList(desc, values_inputs, 5); - -// For inputs that take a single tensor. -TF_CAPI_EXPORT extern void TF_AddInput(TF_OperationDescription* desc, - TF_Output input); - -// For inputs that take a list of tensors. -// inputs must point to TF_Output[num_inputs]. -TF_CAPI_EXPORT extern void TF_AddInputList(TF_OperationDescription* desc, - const TF_Output* inputs, - int num_inputs); - -// Call once per control input to `desc`. -TF_CAPI_EXPORT extern void TF_AddControlInput(TF_OperationDescription* desc, - TF_Operation* input); - -// Request that `desc` be co-located on the device where `op` -// is placed. -// -// Use of this is discouraged since the implementation of device placement is -// subject to change. Primarily intended for internal libraries -TF_CAPI_EXPORT extern void TF_ColocateWith(TF_OperationDescription* desc, - TF_Operation* op); - -// Call some TF_SetAttr*() function for every attr that is not -// inferred from an input and doesn't have a default value you wish to -// keep. - -// `value` must point to a string of length `length` bytes. -TF_CAPI_EXPORT extern void TF_SetAttrString(TF_OperationDescription* desc, - const char* attr_name, - const void* value, size_t length); -// `values` and `lengths` each must have lengths `num_values`. -// `values[i]` must point to a string of length `lengths[i]` bytes. -TF_CAPI_EXPORT extern void TF_SetAttrStringList(TF_OperationDescription* desc, - const char* attr_name, - const void* const* values, - const size_t* lengths, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrInt(TF_OperationDescription* desc, - const char* attr_name, int64_t value); -TF_CAPI_EXPORT extern void TF_SetAttrIntList(TF_OperationDescription* desc, - const char* attr_name, - const int64_t* values, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrFloat(TF_OperationDescription* desc, - const char* attr_name, float value); -TF_CAPI_EXPORT extern void TF_SetAttrFloatList(TF_OperationDescription* desc, - const char* attr_name, - const float* values, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrBool(TF_OperationDescription* desc, - const char* attr_name, - unsigned char value); -TF_CAPI_EXPORT extern void TF_SetAttrBoolList(TF_OperationDescription* desc, - const char* attr_name, - const unsigned char* values, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrType(TF_OperationDescription* desc, - const char* attr_name, - TF_DataType value); -TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, - const char* attr_name, - const TF_DataType* values, - int num_values); -TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc, - const char* attr_name, - const char* placeholder); - -// Set a 'func' attribute to the specified name. -// `value` must point to a string of length `length` bytes. -TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, - const char* attr_name, - const char* value, size_t length); - -// Set `num_dims` to -1 to represent "unknown rank". Otherwise, -// `dims` points to an array of length `num_dims`. 
`dims[i]` must be -// >= -1, with -1 meaning "unknown dimension". -TF_CAPI_EXPORT extern void TF_SetAttrShape(TF_OperationDescription* desc, - const char* attr_name, - const int64_t* dims, int num_dims); -// `dims` and `num_dims` must point to arrays of length `num_shapes`. -// Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise, -// `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]` -// must be >= -1, with -1 meaning "unknown dimension". -TF_CAPI_EXPORT extern void TF_SetAttrShapeList(TF_OperationDescription* desc, - const char* attr_name, - const int64_t* const* dims, - const int* num_dims, - int num_shapes); -// `proto` must point to an array of `proto_len` bytes representing a -// binary-serialized TensorShapeProto. -TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProto( - TF_OperationDescription* desc, const char* attr_name, const void* proto, - size_t proto_len, TF_Status* status); -// `protos` and `proto_lens` must point to arrays of length `num_shapes`. -// `protos[i]` must point to an array of `proto_lens[i]` bytes -// representing a binary-serialized TensorShapeProto. -TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProtoList( - TF_OperationDescription* desc, const char* attr_name, - const void* const* protos, const size_t* proto_lens, int num_shapes, - TF_Status* status); - -TF_CAPI_EXPORT extern void TF_SetAttrTensor(TF_OperationDescription* desc, - const char* attr_name, - TF_Tensor* value, - TF_Status* status); -TF_CAPI_EXPORT extern void TF_SetAttrTensorList(TF_OperationDescription* desc, - const char* attr_name, - TF_Tensor* const* values, - int num_values, - TF_Status* status); - -// `proto` should point to a sequence of bytes of length `proto_len` -// representing a binary serialization of an AttrValue protocol -// buffer. -TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, - const char* attr_name, - const void* proto, - size_t proto_len, - TF_Status* status); - -// If this function succeeds: -// * *status is set to an OK value, -// * a TF_Operation is added to the graph, -// * a non-null value pointing to the added operation is returned -- -// this value is valid until the underlying graph is deleted. -// Otherwise: -// * *status is set to a non-OK value, -// * the graph is not modified, -// * a null value is returned. -// In either case, it deletes `desc`. -TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperation( - TF_OperationDescription* desc, TF_Status* status); - -// TF_Operation functions. Operations are immutable once created, so -// these are all query functions. 
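The Concat example in the comments above can be spelled out as code; this is a hedged sketch in which `concat_dim` and `values` are outputs of operations assumed to already exist in the same graph, and the node name "concat" is illustrative.

#include "tensorflow/c/c_api.h"

// Wires a "Concat" node: one single-tensor input followed by a list input.
static TF_Operation* AddConcat(TF_Graph* graph, TF_Output concat_dim,
                               const TF_Output* values, int num_values,
                               TF_Status* status) {
  TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "concat");
  TF_AddInput(desc, concat_dim);              // "concat_dim: int32"
  TF_AddInputList(desc, values, num_values);  // "values: N * T"
  return TF_FinishOperation(desc, status);    // nullptr plus a bad status on error
}

The attrs "N" and "T" are left unset here because they are inferred from the supplied inputs, which is why the comment above says to call TF_SetAttr*() only for attrs that are not inferred.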
- -TF_CAPI_EXPORT extern const char* TF_OperationName(TF_Operation* oper); -TF_CAPI_EXPORT extern const char* TF_OperationOpType(TF_Operation* oper); -TF_CAPI_EXPORT extern const char* TF_OperationDevice(TF_Operation* oper); - -TF_CAPI_EXPORT extern int TF_OperationNumOutputs(TF_Operation* oper); -TF_CAPI_EXPORT extern TF_DataType TF_OperationOutputType(TF_Output oper_out); -TF_CAPI_EXPORT extern int TF_OperationOutputListLength(TF_Operation* oper, - const char* arg_name, - TF_Status* status); - -TF_CAPI_EXPORT extern int TF_OperationNumInputs(TF_Operation* oper); -TF_CAPI_EXPORT extern TF_DataType TF_OperationInputType(TF_Input oper_in); -TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper, - const char* arg_name, - TF_Status* status); - -// In this code: -// TF_Output producer = TF_OperationInput(consumer); -// There is an edge from producer.oper's output (given by -// producer.index) to consumer.oper's input (given by consumer.index). -TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in); - -// Get list of all inputs of a specific operation. `inputs` must point to -// an array of length at least `max_inputs` (ideally set to -// TF_OperationNumInputs(oper)). Beware that a concurrent -// modification of the graph can increase the number of inputs of -// an operation. -TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper, - TF_Output* inputs, - int max_inputs); - -// Get the number of current consumers of a specific output of an -// operation. Note that this number can change when new operations -// are added to the graph. -TF_CAPI_EXPORT extern int TF_OperationOutputNumConsumers(TF_Output oper_out); - -// Get list of all current consumers of a specific output of an -// operation. `consumers` must point to an array of length at least -// `max_consumers` (ideally set to -// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent -// modification of the graph can increase the number of consumers of -// an operation. Returns the number of output consumers (should match -// TF_OperationOutputNumConsumers(oper_out)). -TF_CAPI_EXPORT extern int TF_OperationOutputConsumers(TF_Output oper_out, - TF_Input* consumers, - int max_consumers); - -// Get the number of control inputs to an operation. -TF_CAPI_EXPORT extern int TF_OperationNumControlInputs(TF_Operation* oper); - -// Get list of all control inputs to an operation. `control_inputs` must -// point to an array of length `max_control_inputs` (ideally set to -// TF_OperationNumControlInputs(oper)). Returns the number of control -// inputs (should match TF_OperationNumControlInputs(oper)). -TF_CAPI_EXPORT extern int TF_OperationGetControlInputs( - TF_Operation* oper, TF_Operation** control_inputs, int max_control_inputs); - -// Get the number of operations that have `*oper` as a control input. -// Note that this number can change when new operations are added to -// the graph. -TF_CAPI_EXPORT extern int TF_OperationNumControlOutputs(TF_Operation* oper); - -// Get the list of operations that have `*oper` as a control input. -// `control_outputs` must point to an array of length at least -// `max_control_outputs` (ideally set to -// TF_OperationNumControlOutputs(oper)). Beware that a concurrent -// modification of the graph can increase the number of control -// outputs. Returns the number of control outputs (should match -// TF_OperationNumControlOutputs(oper)). 
-TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( - TF_Operation* oper, TF_Operation** control_outputs, - int max_control_outputs); - -// TF_AttrMetadata describes the value of an attribute on an operation. -typedef struct TF_AttrMetadata { - // A boolean: 1 if the attribute value is a list, 0 otherwise. - unsigned char is_list; - - // Length of the list if is_list is true. Undefined otherwise. - int64_t list_size; - - // Type of elements of the list if is_list != 0. - // Type of the single value stored in the attribute if is_list == 0. - TF_AttrType type; - - // Total size the attribute value. - // The units of total_size depend on is_list and type. - // (1) If type == TF_ATTR_STRING and is_list == 0 - // then total_size is the byte size of the string - // valued attribute. - // (2) If type == TF_ATTR_STRING and is_list == 1 - // then total_size is the cumulative byte size - // of all the strings in the list. - // (3) If type == TF_ATTR_SHAPE and is_list == 0 - // then total_size is the number of dimensions - // of the shape valued attribute, or -1 - // if its rank is unknown. - // (4) If type == TF_ATTR_SHAPE and is_list == 1 - // then total_size is the cumulative number - // of dimensions of all shapes in the list. - // (5) Otherwise, total_size is undefined. - int64_t total_size; -} TF_AttrMetadata; - -// Returns metadata about the value of the attribute `attr_name` of `oper`. -TF_CAPI_EXPORT extern TF_AttrMetadata TF_OperationGetAttrMetadata( - TF_Operation* oper, const char* attr_name, TF_Status* status); - -// Fills in `value` with the value of the attribute `attr_name`. `value` must -// point to an array of length at least `max_length` (ideally set to -// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrString(TF_Operation* oper, - const char* attr_name, - void* value, - size_t max_length, - TF_Status* status); - -// Get the list of strings in the value of the attribute `attr_name`. Fills in -// `values` and `lengths`, each of which must point to an array of length at -// least `max_values`. -// -// The elements of values will point to addresses in `storage` which must be at -// least `storage_size` bytes in length. Ideally, max_values would be set to -// TF_AttrMetadata.list_size and `storage` would be at least -// TF_AttrMetadata.total_size, obtained from TF_OperationGetAttrMetadata(oper, -// attr_name). -// -// Fails if storage_size is too small to hold the requested number of strings. -TF_CAPI_EXPORT extern void TF_OperationGetAttrStringList( - TF_Operation* oper, const char* attr_name, void** values, size_t* lengths, - int max_values, void* storage, size_t storage_size, TF_Status* status); - -TF_CAPI_EXPORT extern void TF_OperationGetAttrInt(TF_Operation* oper, - const char* attr_name, - int64_t* value, - TF_Status* status); - -// Fills in `values` with the value of the attribute `attr_name` of `oper`. -// `values` must point to an array of length at least `max_values` (ideally set -// TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrIntList(TF_Operation* oper, - const char* attr_name, - int64_t* values, - int max_values, - TF_Status* status); - -TF_CAPI_EXPORT extern void TF_OperationGetAttrFloat(TF_Operation* oper, - const char* attr_name, - float* value, - TF_Status* status); - -// Fills in `values` with the value of the attribute `attr_name` of `oper`. 
-// `values` must point to an array of length at least `max_values` (ideally set -// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrFloatList(TF_Operation* oper, - const char* attr_name, - float* values, - int max_values, - TF_Status* status); - -TF_CAPI_EXPORT extern void TF_OperationGetAttrBool(TF_Operation* oper, - const char* attr_name, - unsigned char* value, - TF_Status* status); - -// Fills in `values` with the value of the attribute `attr_name` of `oper`. -// `values` must point to an array of length at least `max_values` (ideally set -// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrBoolList(TF_Operation* oper, - const char* attr_name, - unsigned char* values, - int max_values, - TF_Status* status); - -TF_CAPI_EXPORT extern void TF_OperationGetAttrType(TF_Operation* oper, - const char* attr_name, - TF_DataType* value, - TF_Status* status); - -// Fills in `values` with the value of the attribute `attr_name` of `oper`. -// `values` must point to an array of length at least `max_values` (ideally set -// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, -// attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrTypeList(TF_Operation* oper, - const char* attr_name, - TF_DataType* values, - int max_values, - TF_Status* status); - -// Fills in `value` with the value of the attribute `attr_name` of `oper`. -// `values` must point to an array of length at least `num_dims` (ideally set to -// TF_Attr_Meta.size from TF_OperationGetAttrMetadata(oper, attr_name)). -TF_CAPI_EXPORT extern void TF_OperationGetAttrShape(TF_Operation* oper, - const char* attr_name, - int64_t* value, - int num_dims, - TF_Status* status); - -// Fills in `dims` with the list of shapes in the attribute `attr_name` of -// `oper` and `num_dims` with the corresponding number of dimensions. On return, -// for every i where `num_dims[i]` > 0, `dims[i]` will be an array of -// `num_dims[i]` elements. A value of -1 for `num_dims[i]` indicates that the -// i-th shape in the list is unknown. -// -// The elements of `dims` will point to addresses in `storage` which must be -// large enough to hold at least `storage_size` int64_ts. Ideally, `num_shapes` -// would be set to TF_AttrMetadata.list_size and `storage_size` would be set to -// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, -// attr_name). -// -// Fails if storage_size is insufficient to hold the requested shapes. -TF_CAPI_EXPORT extern void TF_OperationGetAttrShapeList( - TF_Operation* oper, const char* attr_name, int64_t** dims, int* num_dims, - int num_shapes, int64_t* storage, int storage_size, TF_Status* status); - -// Sets `value` to the binary-serialized TensorShapeProto of the value of -// `attr_name` attribute of `oper`'. -TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProto( - TF_Operation* oper, const char* attr_name, TF_Buffer* value, - TF_Status* status); - -// Fills in `values` with binary-serialized TensorShapeProto values of the -// attribute `attr_name` of `oper`. `values` must point to an array of length at -// least `num_values` (ideally set to TF_AttrMetadata.list_size from -// TF_OperationGetAttrMetadata(oper, attr_name)). 
-TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProtoList( - TF_Operation* oper, const char* attr_name, TF_Buffer** values, - int max_values, TF_Status* status); - -// Gets the TF_Tensor valued attribute of `attr_name` of `oper`. -// -// Allocates a new TF_Tensor which the caller is expected to take -// ownership of (and can deallocate using TF_DeleteTensor). -TF_CAPI_EXPORT extern void TF_OperationGetAttrTensor(TF_Operation* oper, - const char* attr_name, - TF_Tensor** value, - TF_Status* status); - -// Fills in `values` with the TF_Tensor values of the attribute `attr_name` of -// `oper`. `values` must point to an array of TF_Tensor* of length at least -// `max_values` (ideally set to TF_AttrMetadata.list_size from -// TF_OperationGetAttrMetadata(oper, attr_name)). -// -// The caller takes ownership of all the non-null TF_Tensor* entries in `values` -// (which can be deleted using TF_DeleteTensor(values[i])). -TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorList(TF_Operation* oper, - const char* attr_name, - TF_Tensor** values, - int max_values, - TF_Status* status); - -// Sets `output_attr_value` to the binary-serialized AttrValue proto -// representation of the value of the `attr_name` attr of `oper`. -TF_CAPI_EXPORT extern void TF_OperationGetAttrValueProto( - TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, - TF_Status* status); - -// Returns the operation in the graph with `oper_name`. Returns nullptr if -// no operation found. -TF_CAPI_EXPORT extern TF_Operation* TF_GraphOperationByName( - TF_Graph* graph, const char* oper_name); - -// Iterate through the operations of a graph. To use: -// size_t pos = 0; -// TF_Operation* oper; -// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { -// DoSomethingWithOperation(oper); -// } -TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, - size_t* pos); - -// Write out a serialized representation of `graph` (as a GraphDef protocol -// message) to `output_graph_def` (allocated by TF_NewBuffer()). -// `output_graph_def`'s underlying buffer will be freed when TF_DeleteBuffer() -// is called. -// -// May fail on very large graphs in the future. -TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, - TF_Buffer* output_graph_def, - TF_Status* status); - -// Returns the serialized OpDef proto with name `op_name`, or a bad status if no -// such op exists. This can return OpDefs of functions copied into the graph. -TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph, - const char* op_name, - TF_Buffer* output_op_def, - TF_Status* status); - -// Returns the serialized VersionDef proto for this graph. -TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, - TF_Buffer* output_version_def, - TF_Status* status); - -// TF_ImportGraphDefOptions holds options that can be passed to -// TF_GraphImportGraphDef. -typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; - -TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( - void); -TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( - TF_ImportGraphDefOptions* opts); - -// Set the prefix to be prepended to the names of nodes in `graph_def` that will -// be imported into `graph`. `prefix` is copied and has no lifetime -// requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( - TF_ImportGraphDefOptions* opts, const char* prefix); - -// Set the execution device for nodes in `graph_def`. 
-// Only applies to nodes where a device was not already explicitly specified. -// `device` is copied and has no lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetDefaultDevice( - TF_ImportGraphDefOptions* opts, const char* device); - -// Set whether to uniquify imported operation names. If true, imported operation -// names will be modified if their name already exists in the graph. If false, -// conflicting names will be treated as an error. Note that this option has no -// effect if a prefix is set, since the prefix will guarantee all names are -// unique. Defaults to false. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( - TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); - -// If true, the specified prefix will be modified if it already exists as an -// operation name or prefix in the graph. If false, a conflicting prefix will be -// treated as an error. This option has no effect if no prefix is specified. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( - TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); - -// Set any imported nodes with input `src_name:src_index` to have that input -// replaced with `dst`. `src_name` refers to a node in the graph to be imported, -// `dst` references a node already existing in the graph being imported into. -// `src_name` is copied and has no lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( - TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, - TF_Output dst); - -// Set any imported nodes with control input `src_name` to have that input -// replaced with `dst`. `src_name` refers to a node in the graph to be imported, -// `dst` references an operation already existing in the graph being imported -// into. `src_name` is copied and has no lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( - TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); - -// Cause the imported graph to have a control dependency on `oper`. `oper` -// should exist in the graph being imported into. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( - TF_ImportGraphDefOptions* opts, TF_Operation* oper); - -// Add an output in `graph_def` to be returned via the `return_outputs` output -// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input -// mapping, the corresponding existing tensor in `graph` will be returned. -// `oper_name` is copied and has no lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( - TF_ImportGraphDefOptions* opts, const char* oper_name, int index); - -// Returns the number of return outputs added via -// TF_ImportGraphDefOptionsAddReturnOutput(). -TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( - const TF_ImportGraphDefOptions* opts); - -// Add an operation in `graph_def` to be returned via the `return_opers` output -// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no -// lifetime requirements. -TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( - TF_ImportGraphDefOptions* opts, const char* oper_name); - -// Returns the number of return operations added via -// TF_ImportGraphDefOptionsAddReturnOperation(). 
-TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations( - const TF_ImportGraphDefOptions* opts); - -// TF_ImportGraphDefResults holds results that are generated by -// TF_GraphImportGraphDefWithResults(). -typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults; - -// Fetches the return outputs requested via -// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is -// returned in `num_outputs`. The array of return outputs is returned in -// `outputs`. `*outputs` is owned by and has the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs( - TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs); - -// Fetches the return operations requested via -// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched -// operations is returned in `num_opers`. The array of return operations is -// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( - TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); - -// Fetches any input mappings requested via -// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef -// and weren't used as input to any node in the imported graph def. The number -// of fetched mappings is returned in `num_missing_unused_input_mappings`. The -// array of each mapping's source node name is returned in `src_names`, and the -// array of each mapping's source index is returned in `src_indexes`. -// -// `*src_names`, `*src_indexes`, and the memory backing each string in -// `src_names` are owned by and have the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, - const char*** src_names, int** src_indexes); - -// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). -TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults( - TF_ImportGraphDefResults* results); - -// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and -// a bad status on error. Otherwise, returns a populated -// TF_ImportGraphDefResults instance. The returned instance must be deleted via -// TF_DeleteImportGraphDefResults(). -TF_CAPI_EXPORT extern TF_ImportGraphDefResults* -TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, - TF_Status* status); - -// Import the graph serialized in `graph_def` into `graph`. -// Convenience function for when only return outputs are needed. -// -// `num_return_outputs` must be the number of return outputs added (i.e. the -// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If -// `num_return_outputs` is non-zero, `return_outputs` must be of length -// `num_return_outputs`. Otherwise it can be null. -TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, - int num_return_outputs, TF_Status* status); - -// Import the graph serialized in `graph_def` into `graph`. -// Convenience function for when no results are needed. -TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, TF_Status* status); - -// Adds a copy of function `func` and optionally its gradient function `grad` -// to `g`. 
Once `func`/`grad` is added to `g`, it can be called by creating -// an operation using the function's name. -// Any changes to `func`/`grad` (including deleting it) done after this method -// returns, won't affect the copy of `func`/`grad` in `g`. -// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no -// effect on them, but can establish the function->gradient relationship -// between them if `func` does not already have a gradient. If `func` already -// has a gradient different from `grad`, an error is returned. -// -// `func` must not be null. -// If `grad` is null and `func` is not in `g`, `func` is added without a -// gradient. -// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop. -// `grad` must have appropriate signature as described in the doc of -// GradientDef in tensorflow/core/framework/function.proto. -// -// If successful, status is set to OK and `func` and `grad` are added to `g`. -// Otherwise, status is set to the encountered error and `g` is unmodified. -TF_CAPI_EXPORT extern void TF_GraphCopyFunction(TF_Graph* g, - const TF_Function* func, - const TF_Function* grad, - TF_Status* status); - -// Returns the number of TF_Functions registered in `g`. -TF_CAPI_EXPORT extern int TF_GraphNumFunctions(TF_Graph* g); - -// Fills in `funcs` with the TF_Function* registered in `g`. -// `funcs` must point to an array of TF_Function* of length at least -// `max_func`. In usual usage, max_func should be set to the result of -// TF_GraphNumFunctions(g). In this case, all the functions registered in -// `g` will be returned. Else, an unspecified subset. -// -// If successful, returns the number of TF_Function* successfully set in -// `funcs` and sets status to OK. The caller takes ownership of -// all the returned TF_Functions. They must be deleted with TF_DeleteFunction. -// On error, returns 0, sets status to the encountered error, and the contents -// of funcs will be undefined. -TF_CAPI_EXPORT extern int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, - int max_func, TF_Status* status); - -// Note: The following function may fail on very large protos in the future. - -TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, - TF_Buffer* output_node_def, - TF_Status* status); - -// Create a TF_Function from a TF_Graph -// -// Params: -// fn_body - the graph whose operations (or subset of whose operations) will be -// converted to TF_Function. -// fn_name - the name of the new TF_Function. Should match the operation -// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. -// If `append_hash_to_fn_name` is false, `fn_name` must be distinct -// from other function and operation names (at least those -// registered in graphs where this function will be used). -// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name -// of the function will be `fn_name` appended with -// '_'. -// If set to 0, the function's name will be `fn_name`. -// num_opers - `num_opers` contains the number of elements in the `opers` array -// or a special value of -1 meaning that no array is given. -// The distinction between an empty array of operations and no -// array of operations is necessary to distinguish the case of -// creating a function with no body (e.g. identity or permutation) -// and the case of creating a function whose body contains all -// the nodes in the graph (except for the automatic skipping, see -// below). -// opers - Array of operations to become the body of the function or null. 
-// - If no array is given (`num_opers` = -1), all the -// operations in `fn_body` will become part of the function -// except operations referenced in `inputs`. These operations -// must have a single output (these operations are typically -// placeholders created for the sole purpose of representing -// an input. We can relax this constraint if there are -// compelling use cases). -// - If an array is given (`num_opers` >= 0), all operations -// in it will become part of the function. In particular, no -// automatic skipping of dummy input operations is performed. -// ninputs - number of elements in `inputs` array -// inputs - array of TF_Outputs that specify the inputs to the function. -// If `ninputs` is zero (the function takes no inputs), `inputs` -// can be null. The names used for function inputs are normalized -// names of the operations (usually placeholders) pointed to by -// `inputs`. These operation names should start with a letter. -// Normalization will convert all letters to lowercase and -// non-alphanumeric characters to '_' to make resulting names match -// the "[a-z][a-z0-9_]*" pattern for operation argument names. -// `inputs` cannot contain the same tensor twice. -// noutputs - number of elements in `outputs` array -// outputs - array of TF_Outputs that specify the outputs of the function. -// If `noutputs` is zero (the function returns no outputs), `outputs` -// can be null. `outputs` can contain the same tensor more than once. -// output_names - The names of the function's outputs. `output_names` array -// must either have the same length as `outputs` -// (i.e. `noutputs`) or be null. In the former case, -// the names should match the regular expression for ArgDef -// names - "[a-z][a-z0-9_]*". In the latter case, -// names for outputs will be generated automatically. -// opts - various options for the function, e.g. XLA's inlining control. -// description - optional human-readable description of this function. -// status - Set to OK on success and an appropriate error on failure. -// -// Note that when the same TF_Output is listed as both an input and an output, -// the corresponding function's output will equal to this input, -// instead of the original node's output. -// -// Callers must also satisfy the following constraints: -// - `inputs` cannot refer to TF_Outputs within a control flow context. For -// example, one cannot use the output of "switch" node as input. -// - `inputs` and `outputs` cannot have reference types. Reference types are -// not exposed through C API and are being replaced with Resources. We support -// reference types inside function's body to support legacy code. Do not -// use them in new code. -// - Every node in the function's body must have all of its inputs (including -// control inputs). In other words, for every node in the body, each input -// must be either listed in `inputs` or must come from another node in -// the body. In particular, it is an error to have a control edge going from -// a node outside of the body into a node in the body. This applies to control -// edges going from nodes referenced in `inputs` to nodes in the body when -// the former nodes are not in the body (automatically skipped or not -// included in explicitly specified body). -// -// Returns: -// On success, a newly created TF_Function instance. It must be deleted by -// calling TF_DeleteFunction. -// -// On failure, null. 
-TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( - const TF_Graph* fn_body, const char* fn_name, - unsigned char append_hash_to_fn_name, int num_opers, - const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, - int noutputs, const TF_Output* outputs, const char* const* output_names, - const TF_FunctionOptions* opts, const char* description, TF_Status* status); - -// Similar to TF_GraphToFunction but allows specifying control outputs of the -// function. -// -// The arguments of TF_GraphToFunction have the same meaning, but the new -// arguments are as follows: -// -// ncontrol_outputs: Number of control outputs of the function. -// control_outputs: vector of TF_Operation objects to be marked as control -// outputs of the function. Operations marked as control outputs are -// guaranteed to execute. -// control_output_names: Optional. If not nullptr, vector of strings, one -// per control output, with their names to be added to the function's -// OpDef. -TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( - const TF_Graph* fn_body, const char* fn_name, - unsigned char append_hash_to_fn_name, int num_opers, - const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, - int noutputs, const TF_Output* outputs, const char* const* output_names, - int ncontrol_outputs, const TF_Operation* const* control_outputs, - const char* const* control_output_names, const TF_FunctionOptions* opts, - const char* description, TF_Status* status); - -// Returns the name of the graph function. -// The return value points to memory that is only usable until the next -// mutation to *func. -TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); - -// Write out a serialized representation of `func` (as a FunctionDef protocol -// message) to `output_func_def` (allocated by TF_NewBuffer()). -// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() -// is called. -// -// May fail on very large graphs in the future. -TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, - TF_Buffer* output_func_def, - TF_Status* status); - -// Construct and return the function whose FunctionDef representation is -// serialized in `proto`. `proto_len` must equal the number of bytes -// pointed to by `proto`. -// Returns: -// On success, a newly created TF_Function instance. It must be deleted by -// calling TF_DeleteFunction. -// -// On failure, null. -TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( - const void* proto, size_t proto_len, TF_Status* status); - -// Sets function attribute named `attr_name` to value stored in `proto`. -// If this attribute is already set to another value, it is overridden. -// `proto` should point to a sequence of bytes of length `proto_len` -// representing a binary serialization of an AttrValue protocol -// buffer. -TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func, - const char* attr_name, - const void* proto, - size_t proto_len, - TF_Status* status); - -// Sets `output_attr_value` to the binary-serialized AttrValue proto -// representation of the value of the `attr_name` attr of `func`. -// If `attr_name` attribute is not present, status is set to an error. -TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( - TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value, - TF_Status* status); - -// Frees the memory used by the `func` struct. -// TF_DeleteFunction is a noop if `func` is null. 
-// Deleting a function does not remove it from any graphs it was copied to. -TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); - -// Attempts to evaluate `output`. This will only be possible if `output` doesn't -// depend on any graph inputs (this function is safe to call if this isn't the -// case though). -// -// If the evaluation is successful, this function returns true and `output`s -// value is returned in `result`. Otherwise returns false. An error status is -// returned if something is wrong with the graph or input. Note that this may -// return false even if no error status is set. -TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph, - TF_Output output, - TF_Tensor** result, - TF_Status* status); - -// TODO(josh11b): Register OpDef, available to all operations added -// to this graph. - -// -------------------------------------------------------------------------- -// API for driving Graph execution. - -typedef struct TF_Session TF_Session; - -// Return a new execution session with the associated graph, or NULL on -// error. Does not take ownership of any input parameters. -// -// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be -// kept alive for the lifetime of the returned TF_Session. New nodes can still -// be added to `graph` after this call. -TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, - const TF_SessionOptions* opts, - TF_Status* status); - -// This function creates a new TF_Session (which is created on success) using -// `session_options`, and then initializes state (restoring tensors and other -// assets) using `run_options`. -// -// Any NULL and non-NULL value combinations for (`run_options, `meta_graph_def`) -// are valid. -// -// - `export_dir` must be set to the path of the exported SavedModel. -// - `tags` must include the set of tags used to identify one MetaGraphDef in -// the SavedModel. -// - `graph` must be a graph newly allocated with TF_NewGraph(). -// -// If successful, populates `graph` with the contents of the Graph and -// `meta_graph_def` with the MetaGraphDef of the loaded model. -TF_CAPI_EXPORT extern TF_Session* TF_LoadSessionFromSavedModel( - const TF_SessionOptions* session_options, const TF_Buffer* run_options, - const char* export_dir, const char* const* tags, int tags_len, - TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status); - -// Close a session. -// -// Contacts any other processes associated with the session, if applicable. -// May not be called after TF_DeleteSession(). -TF_CAPI_EXPORT extern void TF_CloseSession(TF_Session*, TF_Status* status); - -// Destroy a session object. -// -// Even if error information is recorded in *status, this call discards all -// local resources associated with the session. The session may not be used -// during or after this call (and the session drops its reference to the -// corresponding graph). -TF_CAPI_EXPORT extern void TF_DeleteSession(TF_Session*, TF_Status* status); - -// Run the graph associated with the session starting with the supplied inputs -// (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). -// -// Any NULL and non-NULL value combinations for (`run_options`, -// `run_metadata`) are valid. -// -// - `run_options` may be NULL, in which case it will be ignored; or -// non-NULL, in which case it must point to a `TF_Buffer` containing the -// serialized representation of a `RunOptions` protocol buffer. 
-// - `run_metadata` may be NULL, in which case it will be ignored; or -// non-NULL, in which case it must point to an empty, freshly allocated -// `TF_Buffer` that may be updated to contain the serialized representation -// of a `RunMetadata` protocol buffer. -// -// The caller retains ownership of `input_values` (which can be deleted using -// TF_DeleteTensor). The caller also retains ownership of `run_options` and/or -// `run_metadata` (when not NULL) and should manually call TF_DeleteBuffer on -// them. -// -// On success, the tensors corresponding to outputs[0,noutputs-1] are placed in -// output_values[]. Ownership of the elements of output_values[] is transferred -// to the caller, which must eventually call TF_DeleteTensor on them. -// -// On failure, output_values[] contains NULLs. -TF_CAPI_EXPORT extern void TF_SessionRun( - TF_Session* session, - // RunOptions - const TF_Buffer* run_options, - // Input tensors - const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, - // Output tensors - const TF_Output* outputs, TF_Tensor** output_values, int noutputs, - // Target operations - const TF_Operation* const* target_opers, int ntargets, - // RunMetadata - TF_Buffer* run_metadata, - // Output status - TF_Status*); - -// Set up the graph with the intended feeds (inputs) and fetches (outputs) for a -// sequence of partial run calls. -// -// On success, returns a handle that is used for subsequent PRun calls. The -// handle should be deleted with TF_DeletePRunHandle when it is no longer -// needed. -// -// On failure, out_status contains a tensorflow::Status with an error -// message. *handle is set to nullptr. -TF_CAPI_EXPORT extern void TF_SessionPRunSetup( - TF_Session*, - // Input names - const TF_Output* inputs, int ninputs, - // Output names - const TF_Output* outputs, int noutputs, - // Target operations - const TF_Operation* const* target_opers, int ntargets, - // Output handle - const char** handle, - // Output status - TF_Status*); - -// Continue to run the graph with additional feeds and fetches. The -// execution state is uniquely identified by the handle. -TF_CAPI_EXPORT extern void TF_SessionPRun( - TF_Session*, const char* handle, - // Input tensors - const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, - // Output tensors - const TF_Output* outputs, TF_Tensor** output_values, int noutputs, - // Target operations - const TF_Operation* const* target_opers, int ntargets, - // Output status - TF_Status*); - -// Deletes a handle allocated by TF_SessionPRunSetup. -// Once called, no more calls to TF_SessionPRun should be made. -TF_CAPI_EXPORT extern void TF_DeletePRunHandle(const char* handle); - -// -------------------------------------------------------------------------- -// The deprecated session API. Please switch to the above instead of -// TF_ExtendGraph(). This deprecated API can be removed at any time without -// notice. 
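Before the deprecated-session declarations resume below, a minimal sketch of the TF_NewSession / TF_SessionRun / TF_DeleteSession flow documented above (illustrative only, not part of the patch; `graph` is assumed to be a graph built as in the earlier sketch, the operation names "input" and "output" are assumed to exist in it, and TF_NewSessionOptions, TF_DeleteSessionOptions, and TF_DeleteTensor are declared elsewhere in this header):

  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* session = TF_NewSession(graph, opts, status);
  TF_Tensor* in_tensor = NULL;   /* assumed: replaced with a real tensor built via TF_NewTensor */
  TF_Output feed  = {TF_GraphOperationByName(graph, "input"), 0};
  TF_Output fetch = {TF_GraphOperationByName(graph, "output"), 0};
  TF_Tensor* out_tensor = NULL;
  TF_SessionRun(session, /*run_options=*/NULL,
                &feed, &in_tensor, 1,            /* feeds */
                &fetch, &out_tensor, 1,          /* fetches */
                /*target_opers=*/NULL, 0,
                /*run_metadata=*/NULL, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_DeleteTensor(out_tensor);                 /* ownership was transferred to the caller */
  }
  TF_CloseSession(session, status);
  TF_DeleteSession(session, status);
  TF_DeleteSessionOptions(opts);

The same feed/fetch arrays can also be passed to TF_SessionPRunSetup and TF_SessionPRun for partial runs.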
- -typedef struct TF_DeprecatedSession TF_DeprecatedSession; - -TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession( - const TF_SessionOptions*, TF_Status* status); -TF_CAPI_EXPORT extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, - TF_Status* status); -TF_CAPI_EXPORT extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, - TF_Status* status); -TF_CAPI_EXPORT extern void TF_Reset(const TF_SessionOptions* opt, - const char** containers, int ncontainers, - TF_Status* status); -// Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and -// add the nodes in that GraphDef to the graph for the session. -// -// Prefer use of TF_Session and TF_GraphImportGraphDef over this. -TF_CAPI_EXPORT extern void TF_ExtendGraph(TF_DeprecatedSession*, - const void* proto, size_t proto_len, - TF_Status*); - -// See TF_SessionRun() above. -TF_CAPI_EXPORT extern void TF_Run(TF_DeprecatedSession*, - const TF_Buffer* run_options, - const char** input_names, TF_Tensor** inputs, - int ninputs, const char** output_names, - TF_Tensor** outputs, int noutputs, - const char** target_oper_names, int ntargets, - TF_Buffer* run_metadata, TF_Status*); - -// See TF_SessionPRunSetup() above. -TF_CAPI_EXPORT extern void TF_PRunSetup(TF_DeprecatedSession*, - const char** input_names, int ninputs, - const char** output_names, int noutputs, - const char** target_oper_names, - int ntargets, const char** handle, - TF_Status*); - -// See TF_SessionPRun above. -TF_CAPI_EXPORT extern void TF_PRun(TF_DeprecatedSession*, const char* handle, - const char** input_names, TF_Tensor** inputs, - int ninputs, const char** output_names, - TF_Tensor** outputs, int noutputs, - const char** target_oper_names, int ntargets, - TF_Status*); - -typedef struct TF_DeviceList TF_DeviceList; - -// Lists all devices in a TF_Session. -// -// Caller takes ownership of the returned TF_DeviceList* which must eventually -// be freed with a call to TF_DeleteDeviceList. -TF_CAPI_EXPORT extern TF_DeviceList* TF_SessionListDevices(TF_Session* session, - TF_Status* status); - -// Lists all devices in a TF_Session. -// -// Caller takes ownership of the returned TF_DeviceList* which must eventually -// be freed with a call to TF_DeleteDeviceList. -TF_CAPI_EXPORT extern TF_DeviceList* TF_DeprecatedSessionListDevices( - TF_DeprecatedSession* session, TF_Status* status); - -// Deallocates the device list. -TF_CAPI_EXPORT extern void TF_DeleteDeviceList(TF_DeviceList* list); - -// Counts the number of elements in the device list. -TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); - -// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) -// The return value will be a pointer to a null terminated string. The caller -// must not modify or delete the string. It will be deallocated upon a call to -// TF_DeleteDeviceList. -// -// If index is out of bounds, an error code will be set in the status object, -// and a null pointer will be returned. -TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, - int index, - TF_Status* status); - -// Retrieves the type of the device at the given index. -// -// The caller must not modify or delete the string. It will be deallocated upon -// a call to TF_DeleteDeviceList. -// -// If index is out of bounds, an error code will be set in the status object, -// and a null pointer will be returned. 
-TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, - int index, - TF_Status* status); - -// Retrieve the amount of memory associated with a given device. -// -// If index is out of bounds, an error code will be set in the status object, -// and -1 will be returned. -TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( - const TF_DeviceList* list, int index, TF_Status* status); - -// Retrieve the incarnation number of a given device. -// -// If index is out of bounds, an error code will be set in the status object, -// and 0 will be returned. -TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation( - const TF_DeviceList* list, int index, TF_Status* status); - -// -------------------------------------------------------------------------- -// Load plugins containing custom ops and kernels - -// TF_Library holds information about dynamically loaded TensorFlow plugins. -typedef struct TF_Library TF_Library; - -// Load the library specified by library_filename and register the ops and -// kernels present in that library. -// -// Pass "library_filename" to a platform-specific mechanism for dynamically -// loading a library. The rules for determining the exact location of the -// library are platform-specific and are not documented here. -// -// On success, place OK in status and return the newly created library handle. -// The caller owns the library handle. -// -// On failure, place an error status in status and return NULL. -TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename, - TF_Status* status); - -// Get the OpList of OpDefs defined in the library pointed by lib_handle. -// -// Returns a TF_Buffer. The memory pointed to by the result is owned by -// lib_handle. The data in the buffer will be the serialized OpList proto for -// ops defined in the library. -TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); - -// Frees the memory associated with the library handle. -// Does NOT unload the library. -TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); - -// Get the OpList of all OpDefs defined in this address space. -// Returns a TF_Buffer, ownership of which is transferred to the caller -// (and can be freed using TF_DeleteBuffer). -// -// The data in the buffer will be the serialized OpList proto for ops registered -// in this address space. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); - -// TF_ApiDefMap encapsulates a collection of API definitions for an operation. -// -// This object maps the name of a TensorFlow operation to a description of the -// API to generate for it, as defined by the ApiDef protocol buffer ( -// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) -// -// The ApiDef messages are typically used to generate convenience wrapper -// functions for TensorFlow operations in various language bindings. -typedef struct TF_ApiDefMap TF_ApiDefMap; - -// Creates a new TF_ApiDefMap instance. -// -// Params: -// op_list_buffer - TF_Buffer instance containing serialized OpList -// protocol buffer. (See -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto -// for the OpList proto definition). -// status - Set to OK on success and an appropriate error on failure. -TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, - TF_Status* status); - -// Deallocates a TF_ApiDefMap. -TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap); - -// Add ApiDefs to the map. 
-// -// `text` corresponds to a text representation of an ApiDefs protocol message. -// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). -// -// The provided ApiDefs will be merged with existing ones in the map, with -// precedence given to the newly added version in case of conflicts with -// previous calls to TF_ApiDefMapPut. -TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, - const char* text, size_t text_len, - TF_Status* status); - -// Returns a serialized ApiDef protocol buffer for the TensorFlow operation -// named `name`. -TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, - const char* name, - size_t name_len, - TF_Status* status); - -// -------------------------------------------------------------------------- -// Kernel definition information. - -// Returns a serialized KernelList protocol buffer containing KernelDefs for all -// registered kernels. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); - -// Returns a serialized KernelList protocol buffer containing KernelDefs for all -// kernels registered for the operation named `name`. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( - const char* name, TF_Status* status); - -// -------------------------------------------------------------------------- -// In-process TensorFlow server functionality, for use in distributed training. -// A Server instance encapsulates a set of devices and a Session target that -// can participate in distributed training. A server belongs to a cluster -// (specified by a ClusterSpec), and corresponds to a particular task in a -// named job. The server can communicate with any other server in the same -// cluster. - -// In-process TensorFlow server. -typedef struct TF_Server TF_Server; - -// Creates a new in-process TensorFlow server configured using a serialized -// ServerDef protocol buffer provided via `proto` and `proto_len`. -// -// The server will not serve any requests until TF_ServerStart is invoked. -// The server will stop serving requests once TF_ServerStop or -// TF_DeleteServer is invoked. -TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, - size_t proto_len, - TF_Status* status); - -// Starts an in-process TensorFlow server. -TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); - -// Stops an in-process TensorFlow server. -TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); - -// Blocks until the server has been successfully stopped (via TF_ServerStop or -// TF_ServerClose). -TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); - -// Returns the target string that can be provided to TF_SetTarget() to connect -// a TF_Session to `server`. -// -// The returned string is valid only until TF_DeleteServer is invoked. -TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); - -// Destroy an in-process TensorFlow server, frees memory. If server is running -// it will be stopped and joined. -TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); - -// Register a listener method that processes printed messages. -// -// If any listeners are registered, the print operator will call all listeners -// with the printed messages and immediately return without writing to the -// logs. 
-TF_CAPI_EXPORT extern void TF_RegisterLogListener( - void (*listener)(const char*)); - -#ifdef __cplusplus -} /* end extern "C" */ -#endif - -#endif // TENSORFLOW_C_C_CORE_API_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 2ec1f442780..c25cb264ce7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -42,7 +42,7 @@ tf_cuda_library( "//conditions:default": [ "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:fixed_array", - "//tensorflow/c:c_core_api", + "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/c:tf_tensor_internal", "//tensorflow/core:core_cpu", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 67324a441f9..96dc288f213 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -28,7 +28,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" -#include "tensorflow/c/c_core_api.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/tf_tensor_internal.h" diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index b951f45d0e1..070b3a9bb60 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -20,7 +20,7 @@ limitations under the License. // WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be // stable and can change without notice. -#include "tensorflow/c/c_core_api.h" +#include "tensorflow/c/c_api.h" // Macro to control visibility of exported symbols in the shared library (.so, // .dylib, .dll). diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index f69f79eed7a..c38d7b84a74 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -36,7 +36,6 @@ transitive_hdrs( "//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:reader", "//tensorflow/cc/saved_model:bundle_v2", - "//tensorflow/c:c_core_api_no_xla", # WARNING: None of the C/C++ code under python/ has any API guarantees, and TF team # reserves the right to change APIs and other header-level interfaces. If your custom # op uses these headers, it may break when users upgrade their version of tensorflow. 
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 64a4469e0da..4dfe616263b 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -246,7 +246,6 @@ headers = ( list(find_files('*.proto', 'tensorflow/compiler')) + list(find_files('*.proto', 'tensorflow/core')) + list(find_files('*.proto', 'tensorflow/python')) + - list(find_files('*.h', 'tensorflow/c')) + list(find_files('*.h', 'tensorflow/cc')) + list(find_files('*.h', 'tensorflow/compiler')) + list(find_files('*.h', 'tensorflow/core')) + From a9b14560fafc8b5b50ae9547603ee39073191c3a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Mar 2020 01:25:39 +0900 Subject: [PATCH 144/492] fix typo --- tensorflow/lite/kernels/fully_connected.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index e5a92c33f84..fc6f1991fd3 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -251,7 +251,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_size_array = nullptr; if (params->keep_num_dims) { // When number of dimensions are kept the filter operates along the last - // dimentions. In other words, for an input tensor with shape + // dimensions. In other words, for an input tensor with shape // [batch_size, ..., n_inputs] and a filter of shape [n_inputs, n_units] // this Op produces an output of shape [batch_size, ..., n_units]. TF_LITE_ENSURE_EQ(context, input->dims->data[input->dims->size - 1], From a25935196439d8a23885c108e4d21f6ce62697d5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 09:32:53 -0700 Subject: [PATCH 145/492] add a micro benchmarks for OpKernel::TraceString PiperOrigin-RevId: 301605173 Change-Id: I9003d08220448f79a0144d637cd0fb00ad54f7b7 --- tensorflow/core/framework/op_kernel_test.cc | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 94b502f3f71..3c915d13fdc 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -1026,6 +1026,7 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def, REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("MatMul").Device(DEVICE_CPU), DummyKernel); void BM_ConcatInputRange(int iters) { testing::StopTiming(); @@ -1067,8 +1068,51 @@ void BM_SelectInputRange(int iters) { BM_InputRangeHelper(iters, node_def, "condition", 0, 1); } +void BM_TraceString(const int iters, const int verbose) { + testing::StopTiming(); + + // Create a MatMul NodeDef with 2 inputs. 
+ NodeDef node_def; + node_def.set_name("gradient_tape/model_1/dense_1/MatMul_1"); + node_def.set_op("MatMul"); + AttrValue transpose_a, transpose_b, attr_t; + attr_t.set_type(DT_FLOAT); + node_def.mutable_attr()->insert({"T", attr_t}); + transpose_a.set_b(true); + node_def.mutable_attr()->insert({"transpose_a", transpose_a}); + transpose_b.set_b(true); + node_def.mutable_attr()->insert({"transpose_b", transpose_b}); + for (size_t i = 0; i < 2; ++i) { + node_def.add_input(strings::StrCat("a:", i)); + } + + // Build OpKernel and OpKernelContext + Status status; + auto device = absl::make_unique(Env::Default()); + std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), + cpu_allocator(), node_def, + TF_GRAPH_DEF_VERSION, &status)); + TF_CHECK_OK(status); + + OpKernelContext::Params params; + params.device = device.get(); + params.op_kernel = op.get(); + Tensor a(DT_FLOAT, TensorShape({99000, 256})); + Tensor b(DT_FLOAT, TensorShape({256, 256})); + gtl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; + params.inputs = &inputs; + auto ctx = absl::make_unique(¶ms); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + auto trace = op->TraceString(ctx.get(), verbose); + } + testing::StopTiming(); +} + BENCHMARK(BM_ConcatInputRange); BENCHMARK(BM_SelectInputRange); +BENCHMARK(BM_TraceString)->Arg(1)->Arg(0); TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) { auto kernel_list = GetAllRegisteredKernels(); From 200be84a740f88dfd0ab59ea4ec831ec8659108b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 09:35:27 -0700 Subject: [PATCH 146/492] Fix spelling error in comment. PiperOrigin-RevId: 301605675 Change-Id: I46a839e16e9cc87297588025d778d25ed39cc39f --- tensorflow/lite/g3doc/guide/codegen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/g3doc/guide/codegen.md b/tensorflow/lite/g3doc/guide/codegen.md index 0d19ba6dbc0..4cf8a677b98 100644 --- a/tensorflow/lite/g3doc/guide/codegen.md +++ b/tensorflow/lite/g3doc/guide/codegen.md @@ -81,7 +81,7 @@ implementation project(":classify_wrapper") ### Step 3: Using the model ```java -// 1. Initiatize the Model +// 1. Initialize the Model MyClassifierModel myImageClassifier = null; try { From 91f90991272e4f6958a343622c18b0909650aefa Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 18 Mar 2020 09:41:42 -0700 Subject: [PATCH 147/492] [NFC] Replace usage of PatternMatchResult with LogicalResult PiperOrigin-RevId: 301606642 Change-Id: If0884654dc556dcf0060c1e92e2ee0e0b99d0473 --- .../compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 17ba38e8c40..b09a0159bcb 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -95,7 +95,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { // here is that are broadcasts have been made explicit. unsigned nloops = argType.getRank(); - if (isLHLO && !nloops) ConversionPattern::matchFailure(); + if (isLHLO && !nloops) return failure(); int operandCount = (isLHLO ? args.size() - 1 : args.size()); auto verifyArgOrResultType = [&](Value val) -> ShapedType { From 54b07c9014a024a9602987730c491759374d1975 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 18 Mar 2020 09:47:13 -0700 Subject: [PATCH 148/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301607619 Change-Id: I18afef03257c452e29f77757c75c06b8881c7d28 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 6456f104ad3..52a9bf9551b 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. 
-// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From 70efc8c05c46f323da63465e66ecb9e577f9cab6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Mar 2020 01:52:50 +0900 Subject: [PATCH 149/492] minor spelling tweaks --- tensorflow/lite/toco/dump_graphviz.cc | 2 +- .../convert_trivial_transpose_to_reshape.cc | 2 +- .../toco/graph_transformations/graph_transformations.h | 2 +- .../graph_transformations/identify_dilated_conv.cc | 4 ++-- .../propagate_fake_quant_num_bits.cc | 2 +- .../graph_transformations/propagate_fixed_sizes.cc | 10 +++++----- .../toco/graph_transformations/quantization_util.cc | 2 +- .../remove_successive_transpose.cc | 6 +++--- .../tests/remove_successive_transpose_test.cc | 6 +++--- .../toco/graph_transformations/unroll_batch_matmul.cc | 2 +- tensorflow/lite/toco/logging/gen_html.py | 2 +- tensorflow/lite/toco/model.h | 4 ++-- tensorflow/lite/toco/python/toco_from_protos.py | 2 +- tensorflow/lite/toco/python/toco_from_protos_test.py | 2 +- tensorflow/lite/toco/tflite/export.cc | 5 +++-- tensorflow/lite/toco/tflite/operator_test.cc | 2 +- tensorflow/lite/toco/toco_cmdline_flags.cc | 2 +- tensorflow/lite/toco/toco_convert_test.cc | 2 +- tensorflow/lite/toco/toco_tooling.cc | 6 +++--- tensorflow/lite/toco/tooling_util.cc | 2 +- 20 files changed, 34 insertions(+), 33 deletions(-) diff --git a/tensorflow/lite/toco/dump_graphviz.cc b/tensorflow/lite/toco/dump_graphviz.cc index 95a34a7e4fb..68d3b957129 100644 --- a/tensorflow/lite/toco/dump_graphviz.cc +++ b/tensorflow/lite/toco/dump_graphviz.cc @@ -647,7 +647,7 @@ void DumpNode(const Model& model, string* output_file, const string& node_name, for (const auto& child : node.children) { if (!child.second->array_id.empty()) { - // Dump array if this node posesses one. + // Dump array if this node possesses one. 
DumpArray(model, output_file, child.second->array_id); } // Note that it is always possible to have children. Unlike a filesystem, diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index 3ff878c506a..2b5aaea2b23 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -67,7 +67,7 @@ bool TransposeAffectsMemoryOrder(std::vector perm, } // Note: We can assume we have error checked inputs in PropagateFixedSizes. - // Check that the permutation has propogated. + // Check that the permutation has propagated. std::vector const& perm = transpose_op->perm; if (perm.empty()) { return ::tensorflow::Status::OK(); diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h index 0b765b1f507..07b9fd4c5cf 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -159,7 +159,7 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) DECLARE_GRAPH_TRANSFORMATION(Quantize) DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp) -DECLARE_GRAPH_TRANSFORMATION(RemoveSuccesiveTranspose) +DECLARE_GRAPH_TRANSFORMATION(RemoveSuccessiveTranspose) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator) diff --git a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc index ab86f5d07c9..1940068d32a 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -225,8 +225,8 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op, dilation_factor); if (changed) { LOG(INFO) - << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \"" - << conv_base_op->outputs[0] << "\"."; + << "Replaced sub-network with Dilated DepthwiseConv2D op outputting " + << "\"" << conv_base_op->outputs[0] << "\"."; } } diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 602ff4b7b3d..1ed618879c1 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -245,7 +245,7 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation, // This can be thought of as a bidirectional flood-fill of the num_bits implied // final_data_type that terminates at other FakeQuant ops (and a few others as // determined by DoesOpBlockBackwardPropagation/DoesOpBlockForwardPropagation). -// Once all FakeQuant ops have been visted the arrays should all have +// Once all FakeQuant ops have been visited the arrays should all have // appropriate final_data_types if the source graph was annotated with the // proper FakeQuant ops. 
// diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index dfb0143a05a..34813bcc0de 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -671,7 +671,7 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { break; } } - // Determine the concat size, and enfore that all inputs have + // Determine the concat size, and enforce that all inputs have // the same dimensions count. int concat_size = 0; for (const auto& input_name : op->inputs) { @@ -1098,7 +1098,7 @@ void ProcessUnidirectionalSequenceLstmOperator( constexpr int kInputActivationStateTensor = 18; constexpr int kInputCellStateTensor = 19; - // TFlite intepreter does not support array which is variable and contains a + // TFlite interpreter does not support array which is variable and contains a // buffer (see b/115961645 for more discussion). // The follow block remove buffer from the array to work around the // restriction, as a consequence, downstream applications should not @@ -1142,7 +1142,7 @@ void ProcessUnidirectionalSequenceRnnOperator( } constexpr int kHiddenStateTensor = 4; - // TFlite intepreter does not support array which is variable and contains a + // TFlite interpreter does not support array which is variable and contains a // buffer (see b/115961645 for more discussion). // The follow block remove buffer from the array to work around the // restriction, as a consequence, downstream applications should not @@ -1658,7 +1658,7 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { } if (op->ellipsis_mask != 0) { - // Something like LOG_FIRST_N(WARNING, 10) would be prefferable to reduce + // Something like LOG_FIRST_N(WARNING, 10) would be preferable to reduce // log noise. However, the TensorFlow logging library does not appear to // support this. LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0] @@ -2434,7 +2434,7 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) { break; case OperatorType::kCTCBeamSearchDecoder: // The sizes of the outputs are only known in runtime based on the input. - // Ignore shape progapation here and defer that to the interpreter. + // Ignore shape propagation here and defer that to the interpreter. 
break; case OperatorType::kMatrixSetDiagV2: // MatrixSetDiagV2 operators are converted to MatrixSetDiag, diff --git a/tensorflow/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/lite/toco/graph_transformations/quantization_util.cc index 56f83c9793f..23749abf0b1 100644 --- a/tensorflow/lite/toco/graph_transformations/quantization_util.cc +++ b/tensorflow/lite/toco/graph_transformations/quantization_util.cc @@ -229,7 +229,7 @@ bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation, ChooseQuantizationParamsForArrayAndQuantizedDataType( array, quantized_data_type, &quantization_params); transformation->AddMessageF( - "No quantization params - infering from data type %s with minmax " + "No quantization params - inferring from data type %s with minmax " "%g,%g as zero_point=%g, scale=%g", ArrayDataTypeName(quantized_data_type), array.minmax->min, array.minmax->max, quantization_params.zero_point, diff --git a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc index 1f0fdf88108..6eccda04c18 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc @@ -56,9 +56,9 @@ void ReplaceOpInputsWith(Model* model, const string& lookfor, } // namespace -::tensorflow::Status RemoveSuccesiveTranspose::Run(Model* model, - std::size_t op_index, - bool* modified) { +::tensorflow::Status RemoveSuccessiveTranspose::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; auto op = model->operators.begin() + op_index; if (op->get()->type != OperatorType::kTranspose) { diff --git a/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc b/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc index a5a0afbe8d1..218cb558e56 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc @@ -94,7 +94,7 @@ TEST_F(RemoveSuccessiveTransposeTest, RemoveTranspose) { // Creating a model. CreateGraph({1, 0}, {1, 0}); - toco::RemoveSuccesiveTranspose transformation; + toco::RemoveSuccessiveTranspose transformation; bool modified; ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok()); EXPECT_TRUE(modified); @@ -109,7 +109,7 @@ TEST_F(RemoveSuccessiveTransposeTest, DontRemoveNotIdentityTranspose) { // Creating a model. 
CreateGraph({0, 2, 1}, {1, 0, 2}); - toco::RemoveSuccesiveTranspose transformation; + toco::RemoveSuccessiveTranspose transformation; bool modified; ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok()); EXPECT_FALSE(modified); @@ -139,7 +139,7 @@ TEST_F(RemoveSuccessiveTransposeTest, DontRemoveTransposeOutputUnused) { transpose2_op->outputs = {"InputTransposeTranspose"}; model_->operators.push_back(std::unique_ptr(transpose2_op)); - toco::RemoveSuccesiveTranspose transformation; + toco::RemoveSuccessiveTranspose transformation; bool modified; ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok()); EXPECT_FALSE(modified); diff --git a/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc index 994c2bc77b8..16dfaf7fc80 100644 --- a/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -177,7 +177,7 @@ TransposeOperator* TransposeInput(const string& input, Model* model) { CHECK_EQ(input_array_a.shape().dims(dims_a - 1), input_array_b.shape().dims(dims_b - 2)) - << "Input dimensions must be compatible for multipication. shape a = [" + << "Input dimensions must be compatible for multiplication. shape a = [" << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = [" << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]"; diff --git a/tensorflow/lite/toco/logging/gen_html.py b/tensorflow/lite/toco/logging/gen_html.py index 95ad53ad407..c8afcd8ee17 100644 --- a/tensorflow/lite/toco/logging/gen_html.py +++ b/tensorflow/lite/toco/logging/gen_html.py @@ -136,7 +136,7 @@ class HTMLGenerator(object): dot_after: A string, the dot representation of the model after the conversion. toco_err_log: A string, the logs emitted by TOCO during conversion. Caller - need to ensure that this string is properly anoynimized (any kind of + need to ensure that this string is properly anonymized (any kind of user data should be eliminated). tflite_graph_path: A string, the filepath to the converted TFLite model. diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 050366826a1..9c669c2760f 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -490,7 +490,7 @@ struct ConvOperator : Operator { // inputs[4]: optional: merge repeated. // // Outputs: -// outputs[0]: deocoded. +// outputs[0]: decoded. // outputs[1]: log probability. // // TensorFlow equivalent: CTCBeamSearchDecoder @@ -1258,7 +1258,7 @@ struct ExpandDimsOperator : Operator { ExpandDimsOperator() : Operator(OperatorType::kExpandDims) {} }; -// Ceates a tensor of shape dims and fills it with the given scalar value. +// Creates a tensor of shape dims and fills it with the given scalar value. // Output type will be the same as the given scalar value. // // Inputs: diff --git a/tensorflow/lite/toco/python/toco_from_protos.py b/tensorflow/lite/toco/python/toco_from_protos.py index e24af2dc115..0f458416e31 100644 --- a/tensorflow/lite/toco/python/toco_from_protos.py +++ b/tensorflow/lite/toco/python/toco_from_protos.py @@ -21,7 +21,7 @@ import argparse import sys # We need to import pywrap_tensorflow prior to the toco wrapper. 
-# pylint: disable=invalud-import-order,g-bad-import-order +# pylint: disable=invalid-import-order,g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python import _pywrap_toco_api from tensorflow.python.platform import app diff --git a/tensorflow/lite/toco/python/toco_from_protos_test.py b/tensorflow/lite/toco/python/toco_from_protos_test.py index 511f714dfe1..b1414c4bfcd 100644 --- a/tensorflow/lite/toco/python/toco_from_protos_test.py +++ b/tensorflow/lite/toco/python/toco_from_protos_test.py @@ -85,7 +85,7 @@ class TocoFromProtosTest(googletest.TestCase): val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) out = tf.identity(val, name="out") out2 = tf.sin(val, name="out2") - # This is a valid mdoel + # This is a valid model self._run(sess, img, out, True) # This uses an invalid function. # TODO(aselle): Check to make sure a warning is included. diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index 07f24afb8ec..876973f62e4 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -53,7 +53,7 @@ namespace { // Check if a TensorFlow Op is a control flow op by its name. bool IsControlFlowOp(const string& tensorflow_op) { - // Technically this is equalivent to `::tensorflow::Node::IsControlFlow()`. + // Technically this is equivalent to `::tensorflow::Node::IsControlFlow()`. // It requires to construct a `::tensorflow::Graph` to use that helper // function, so we simply hardcode the list of control flow ops here. if (tensorflow_op == "Switch" || tensorflow_op == "RefSwitch" || @@ -477,7 +477,8 @@ tensorflow::Status Export( for (const string& input_array : model.GetInvalidInputArrays()) { if (model.HasArray(input_array)) { return tensorflow::errors::InvalidArgument(absl::StrCat( - "Placeholder ", input_array, " should be specied by input_arrays.")); + "Placeholder ", input_array, " should be specified by " + "input_arrays.")); } } diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index 6e7f021a727..cff1ea7e7c0 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -37,7 +37,7 @@ class OperatorTest : public ::testing::Test { static auto* by_name = new OpsByName(BuildOperatorByNameMap()); static auto* by_type = new OpsByType(BuildOperatorByTypeMap()); - // Make sure the two maps were consitently built. + // Make sure the two maps were consistently built. CHECK(by_name->count(name)) << "No operator for '" << name << "'."; BaseOperator* op1 = by_name->at(name).get(); CHECK(op1->type() == type) << "while verifying '" << name << "'."; diff --git a/tensorflow/lite/toco/toco_cmdline_flags.cc b/tensorflow/lite/toco/toco_cmdline_flags.cc index 25a286ee76d..c133db8f2a4 100644 --- a/tensorflow/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/lite/toco/toco_cmdline_flags.cc @@ -171,7 +171,7 @@ bool ParseTocoFlagsFromCommandLineFlags( "Ignored if the output format is not TFLite."), Flag("quantize_to_float16", parsed_flags.quantize_to_float16.bind(), parsed_flags.quantize_to_float16.default_value(), - "Used in conjuction with post_training_quantize. Specifies that " + "Used in conjunction with post_training_quantize. 
Specifies that " "the weights should be quantized to fp16 instead of the default " "(int8)"), Flag("quantize_weights", parsed_flags.quantize_weights.bind(), diff --git a/tensorflow/lite/toco/toco_convert_test.cc b/tensorflow/lite/toco/toco_convert_test.cc index 270c7aadcab..b02c1043f2b 100644 --- a/tensorflow/lite/toco/toco_convert_test.cc +++ b/tensorflow/lite/toco/toco_convert_test.cc @@ -39,7 +39,7 @@ TEST(TocoTest, BadInputFormat) { "Unhandled input_format='FILE_FORMAT_UNKNOWN'"); } -TEST(TocoTest, MissingOuputArrays) { +TEST(TocoTest, MissingOutputArrays) { TocoFlags toco_flags; ModelFlags model_flags; diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index f96c6b83025..da0915f9739 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -67,7 +67,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new PropagateActivationFunctionIntoConstants); transformations->Add(new PropagateArrayDataTypes); transformations->Add(new PropagateFixedSizes); - transformations->Add(new RemoveSuccesiveTranspose); + transformations->Add(new RemoveSuccessiveTranspose); transformations->Add(new RemoveTensorFlowAssert); transformations->Add(new RemoveTensorFlowIdentity); transformations->Add(new RemoveTrivialConcatenation); @@ -415,10 +415,10 @@ tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags, // is: // Input [1, 20, 1, 20, 1, 64] * ones [1, 3, 1, 3, 1, 1] // The problem is if the input is quantized, then the quantization parameters - // will be slightly different for the input and the output. (althought the + // will be slightly different for the input and the output. (although the // difference is really small). // But, since we're changing this pattern to be pack-based which enforce - // the quantization paramters to be exactly the same. + // the quantization parameters to be exactly the same. // So we have to wait for all quantization parameters being resolved and // propagated and create our own. // We may need to revisit this logic later. diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index fc666f1c789..55b98972da6 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -929,7 +929,7 @@ void CheckNonExistentIOArrays(const Model& model) { } static constexpr char general_comment[] = "Is it a typo? This should not happen. If you trigger this error " - "please send a bug report (with code to reporduce this error), to the " + "please send a bug report (with code to reproduce this error), to the " "TensorFlow Lite team."; for (const string& output_array : model.flags.output_arrays()) { if (IsConstantParameterArray(model, output_array)) { From c36851671d46b027febb6e8f980693de53b8af44 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Wed, 18 Mar 2020 09:50:21 -0700 Subject: [PATCH 150/492] Mark XlaBuilder ConstantLiteral method virtual This will overridden in MlirHloBuilder in a later change. 
PiperOrigin-RevId: 301608176 Change-Id: I275b276dc8091118edb06350352a5e219a4a866c --- tensorflow/compiler/xla/client/xla_builder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 9d03141715f..75975baba91 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -380,7 +380,7 @@ class XlaBuilder { return Parameter(parameter_number, shape, name, empty_bools); } - XlaOp ConstantLiteral(const LiteralSlice& literal); + virtual XlaOp ConstantLiteral(const LiteralSlice& literal); XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); From b16bed28bf5f0a15dce7009c130d9a084d31021c Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 18 Mar 2020 10:03:30 -0700 Subject: [PATCH 151/492] [XLA:Python] Make layouts in ComputationBuilder.CustomCall optional. This allows us to use a single API endpoint for custom calls with or without specified layouts. PiperOrigin-RevId: 301610929 Change-Id: I24c43b2cfe6620b95940dd425c1a606fc6563c30 --- tensorflow/compiler/xla/python/xla.cc | 4 +-- tensorflow/compiler/xla/python/xla_client.py | 36 +++++++++++++++---- .../compiler/xla/python/xla_client_test.py | 3 +- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index b42202ca838..ceea02f2374 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -344,9 +344,7 @@ void BuildOpsSubmodule(py::module* m) { py::arg("precision_config") = nullptr); ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), py::arg("new_element_type")); - // TODO(phawkins): remove CustomCall after callers are updated to use - // CustomCallWithLayout. - ops.def("CustomCall", &CustomCallWithLayout); + ops.def("CustomCall", &CustomCall); ops.def("CustomCallWithLayout", &CustomCallWithLayout); ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), py::arg("precision_config") = nullptr); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index f1f31a5eb89..b6948b6d84d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1192,6 +1192,8 @@ class ComputationBuilder(object): return ops.Call(self._builder, computation_to_apply.computation, list(operands)) + # TODO(skyewm): remove CustomCallWithLayout after callers are updated to use + # CustomCall. def CustomCallWithLayout(self, call_target_name, operands, @@ -1213,13 +1215,35 @@ class ComputationBuilder(object): An XlaOp representing the added custom call op. """ opaque = opaque or b'' - return ops.CustomCall(self._builder, call_target_name, - list(operands), shape_with_layout, - list(operand_shapes_with_layout), opaque) + return ops.CustomCallWithLayout( + self._builder, call_target_name, list(operands), shape_with_layout, + list(operand_shapes_with_layout), opaque) - # TODO(phawkins): remove CustomCall after callers are updated to use - # CustomCallWithLayout. - CustomCall = CustomCallWithLayout + def CustomCall(self, call_target_name, operands, shape, + operand_shapes_with_layout=None, opaque=None): + """Enqueues a custom call operation onto the computation. + + Args: + call_target_name: the name of the function to call. + operands: an iterable of XlaOp. The number and types of operands must + match the arity of `operand_shapes_with_layout`. 
+ shape: the shape of the operator's output. Must have layout if + `operand_shapes_with_layout` is provided. + operand_shapes_with_layout: optional, the shapes of `operands` including + the expected layouts. + opaque: an opaque string passed to the backend. + + Returns: + An XlaOp representing the added custom call op. + """ + opaque = opaque or b'' + if operand_shapes_with_layout is None: + return ops.CustomCall(self._builder, call_target_name, list(operands), + shape, opaque) + else: + return ops.CustomCallWithLayout( + self._builder, call_target_name, list(operands), shape, + list(operand_shapes_with_layout), opaque) def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index de5ae258976..72b536ade68 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -371,8 +371,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.CustomCall( b"test_subtract_f32", operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), - shape_with_layout=xla_client.Shape.array_shape( - np.dtype(np.float32), (), ()), + shape=xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), operand_shapes_with_layout=( xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), From 78c0e5b189246c65271c69670e6b5a037f41f0eb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 10:07:46 -0700 Subject: [PATCH 152/492] don't trace arguments (include tensor shapes and op attributes used for cost analysis) for send/recv ops. PiperOrigin-RevId: 301612015 Change-Id: I1c761c25afaa12435a4e98d847a6ad44b6f2c25f --- tensorflow/core/kernels/sendrecv_ops.cc | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 12456037415..93830515040 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -115,14 +115,8 @@ string SendOp::TraceString(OpKernelContext* ctx, bool verbose) { auto dst_it = attr.find("_dst"); const string& src = src_it != attr.end() ? src_it->second.s() : ""; const string& dst = dst_it != attr.end() ? dst_it->second.s() : ""; - if (!verbose) { - return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, - ",to=", dst, "#"); - } else { - string trace_args = GetTraceArgument(ctx); - return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, - ",to=", dst, ",", trace_args, "#"); - } + return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, + ",to=", dst, "#"); } REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp); @@ -163,14 +157,8 @@ string RecvOp::TraceString(OpKernelContext* ctx, bool verbose) { auto dst_it = attr.find("_dst"); const string& src = src_it != attr.end() ? src_it->second.s() : ""; const string& dst = dst_it != attr.end() ? 
dst_it->second.s() : ""; - if (!verbose) { - return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, - ",to=", dst, "#"); - } else { - string trace_args = GetTraceArgument(ctx); - return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, - ",to=", dst, ",", trace_args, "#"); - } + return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, + ",to=", dst, "#"); } namespace { From b7e5c22c0bedcbb9251f9e3eee5799247c0a39c1 Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Wed, 18 Mar 2020 10:13:20 -0700 Subject: [PATCH 153/492] Minor changes to BUILD file related to code coverage. PiperOrigin-RevId: 301613214 Change-Id: I31ed5b19a96f23de1ddad9dd07af3857737634d0 --- tensorflow/compiler/xla/tests/BUILD | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d4fba7d28ac..3255aa84685 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -807,6 +807,8 @@ xla_test( # This is a big test that we skip for capacity reasons in OSS testing. "no_oss", "no_pip", + # TODO(b/151340488): Timed out on 2020-03-18. + "nozapfhahn", ], deps = [ ":client_library_test_base", @@ -1099,10 +1101,6 @@ xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], - backend_tags = { - # TODO(b/151340488): Timed out on 2020-03-12. - "interpreter": ["nozapfhahn"], - }, shard_count = 40, tags = [ "no_rocm", @@ -1143,8 +1141,6 @@ xla_test( shard_count = 25, tags = [ "no_rocm", - # TODO(b/151340488): Timed out on 2020-03-12. - "nozapfhahn", ], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", @@ -1500,10 +1496,6 @@ xla_test( name = "select_and_scatter_test", timeout = "long", srcs = ["select_and_scatter_test.cc"], - backend_tags = { - # TODO(b/151340488): Timed out on 2020-03-12. - "interpreter": ["nozapfhahn"], - }, tags = [ "no_rocm", "optonly", From 4eb0f8f752702555284e5b4e0beb063abe0f55a8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 10:18:38 -0700 Subject: [PATCH 154/492] Fix bug in SensitivitySpecificityBase derived metrics. Sub-classes of `SensitivitySpecificityBase` compute the value of one statistic, given a constraint on another statistic (e.g. compute recall at specified precision). Previously the documentation stated "Computes at a given .", there was no guaranteed and consistent behaviour in case the specified cannot be assume value on the provided sample of scores and labels (e.g. required recall of 0.7, but only either 0.6 or 0.8 can be reached). This change refines the function behaviour to "Computes best where >= ". This caters to common use cases of operating binary classifiers, with a requirement to e.g. maintain a minimal precision and maximise the recall - it is important not to report a recall from an operating point that undershoots the required precision (previously the closest precision would be selected, even it if is smaller). Because this changes (refines) the semantics of the metrics, some expected values in unittests etc. must be adapted. 
PiperOrigin-RevId: 301614619 Change-Id: I3b691cece8d7b0b922eadb6e97c721936057de97 --- tensorflow/python/keras/metrics.py | 130 ++++++++---------- .../keras/metrics_confusion_matrix_test.py | 28 ++-- 2 files changed, 76 insertions(+), 82 deletions(-) diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index e053191d9c4..4333ff784f8 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -1495,10 +1495,35 @@ class SensitivitySpecificityBase(Metric): K.batch_set_value( [(v, np.zeros((num_thresholds,))) for v in self.variables]) + def _find_max_under_constraint(self, constrained, dependent, predicate): + """Returns the maximum of dependent_statistic that satisfies the constraint. + + Args: + constrained: Over these values the constraint + is specified. A rank-1 tensor. + dependent: From these values the maximum that satiesfies the + constraint is selected. Values in this tensor and in + `constrained` are linked by having the same threshold at each + position, hence this tensor must have the same shape. + predicate: A binary boolean functor to be applied to arguments + `constrained` and `self.value`, e.g. `tf.greater`. + + Returns maximal dependent value, if no value satiesfies the constraint 0.0. + """ + feasible = array_ops.where(predicate(constrained, self.value)) + feasible_exists = math_ops.greater(array_ops.size(feasible), 0) + + def get_max(): + return math_ops.reduce_max(array_ops.gather(dependent, feasible)) + + return control_flow_ops.cond(feasible_exists, get_max, lambda: 0.0) + @keras_export('keras.metrics.SensitivityAtSpecificity') class SensitivityAtSpecificity(SensitivitySpecificityBase): - """Computes the sensitivity at a given specificity. + """Computes best sensitivity where specificity is >= specified value. + + the sensitivity at a given specificity. `Sensitivity` measures the proportion of actual positives that are correctly identified as such (tp / (tp + fn)). @@ -1518,16 +1543,16 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): Usage: - >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.4, num_thresholds=1) - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) + >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) + >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], - ... sample_weight=[1, 0, 0, 1]) + >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... sample_weight=[1, 1, 2, 2, 1]) >>> m.result().numpy() - 1.0 + 0.333333 Usage with tf.keras API: @@ -1558,20 +1583,12 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): specificity, num_thresholds=num_thresholds, name=name, dtype=dtype) def result(self): - # Calculate specificities at all the thresholds. specificities = math_ops.div_no_nan( self.true_negatives, self.true_negatives + self.false_positives) - - # Find the index of the threshold where the specificity is closest to the - # given specificity. - min_index = math_ops.argmin( - math_ops.abs(specificities - self.value), axis=0) - min_index = math_ops.cast(min_index, dtypes.int32) - - # Compute sensitivity at that index. 
- return math_ops.div_no_nan( - self.true_positives[min_index], - self.true_positives[min_index] + self.false_negatives[min_index]) + sensitivities = math_ops.div_no_nan( + self.true_positives, self.true_positives + self.false_negatives) + return self._find_max_under_constraint( + specificities, sensitivities, math_ops.greater_equal) def get_config(self): config = { @@ -1584,7 +1601,7 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): @keras_export('keras.metrics.SpecificityAtSensitivity') class SpecificityAtSensitivity(SensitivitySpecificityBase): - """Computes the specificity at a given sensitivity. + """Computes best specificity where sensitivity is >= specified value. `Sensitivity` measures the proportion of actual positives that are correctly identified as such (tp / (tp + fn)). @@ -1604,16 +1621,16 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): Usage: - >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.8, num_thresholds=1) - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) + >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) + >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> m.result().numpy() - 1.0 + 0.66666667 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], - ... sample_weight=[1, 0, 0, 1]) + >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... sample_weight=[1, 1, 2, 2, 2]) >>> m.result().numpy() - 1.0 + 0.5 Usage with tf.keras API: @@ -1644,20 +1661,12 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype) def result(self): - # Calculate sensitivities at all the thresholds. sensitivities = math_ops.div_no_nan( self.true_positives, self.true_positives + self.false_negatives) - - # Find the index of the threshold where the sensitivity is closest to the - # requested value. - min_index = math_ops.argmin( - math_ops.abs(sensitivities - self.value), axis=0) - min_index = math_ops.cast(min_index, dtypes.int32) - - # Compute specificity at that index. - return math_ops.div_no_nan( - self.true_negatives[min_index], - self.true_negatives[min_index] + self.false_positives[min_index]) + specificities = math_ops.div_no_nan( + self.true_negatives, self.true_negatives + self.false_positives) + return self._find_max_under_constraint( + sensitivities, specificities, math_ops.greater_equal) def get_config(self): config = { @@ -1670,7 +1679,7 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): @keras_export('keras.metrics.PrecisionAtRecall') class PrecisionAtRecall(SensitivitySpecificityBase): - """Computes the precision at a given recall. + """Computes best precision where recall is >= specified value. This metric creates four local variables, `true_positives`, `true_negatives`, `false_positives` and `false_negatives` that are used to compute the @@ -1682,16 +1691,16 @@ class PrecisionAtRecall(SensitivitySpecificityBase): Usage: - >>> m = tf.keras.metrics.PrecisionAtRecall(0.8, num_thresholds=1) - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) + >>> m = tf.keras.metrics.PrecisionAtRecall(0.5) + >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> m.result().numpy() - 1.0 + 0.5 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], - ... sample_weight=[1, 0, 0, 1]) + >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... 
sample_weight=[2, 2, 2, 1, 1]) >>> m.result().numpy() - 1.0 + 0.33333333 Usage with tf.keras API: @@ -1725,20 +1734,12 @@ class PrecisionAtRecall(SensitivitySpecificityBase): dtype=dtype) def result(self): - # Calculate recall at all the thresholds. recalls = math_ops.div_no_nan( self.true_positives, self.true_positives + self.false_negatives) - - # Find the index of the threshold where the recall is closest to the - # requested value. - min_index = math_ops.argmin( - math_ops.abs(recalls - self.value), axis=0) - min_index = math_ops.cast(min_index, dtypes.int32) - - # Compute precision at that index. - return math_ops.div_no_nan( - self.true_positives[min_index], - self.true_positives[min_index] + self.false_positives[min_index]) + precisions = math_ops.div_no_nan( + self.true_positives, self.true_positives + self.false_positives) + return self._find_max_under_constraint( + recalls, precisions, math_ops.greater_equal) def get_config(self): config = {'num_thresholds': self.num_thresholds, 'recall': self.recall} @@ -1748,7 +1749,7 @@ class PrecisionAtRecall(SensitivitySpecificityBase): @keras_export('keras.metrics.RecallAtPrecision') class RecallAtPrecision(SensitivitySpecificityBase): - """Computes the maximally achievable recall at a required precision. + """Computes best recall where precision is >= specified value. For a given score-label-distribution the required precision might not be achievable, in this case 0.0 is returned as recall. @@ -1763,7 +1764,7 @@ class RecallAtPrecision(SensitivitySpecificityBase): Usage: - >>> m = tf.keras.metrics.RecallAtPrecision(0.8, num_thresholds=1) + >>> m = tf.keras.metrics.RecallAtPrecision(0.8) >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) >>> m.result().numpy() 0.5 @@ -1806,21 +1807,12 @@ class RecallAtPrecision(SensitivitySpecificityBase): dtype=dtype) def result(self): - # Calculate precision and recall at all the thresholds. - # All recalls are computed, because they are not a monotoneous function of - # precision and we want to search for the highest feasible recall. precisions = math_ops.div_no_nan( self.true_positives, self.true_positives + self.false_positives) recalls = math_ops.div_no_nan( self.true_positives, self.true_positives + self.false_negatives) - # Find best recall where the precision is as good as required. 
- feasible = array_ops.where(math_ops.greater_equal(precisions, self.value)) - feasible_exists = math_ops.greater(array_ops.size(feasible), 0) - best_recall = control_flow_ops.cond( - feasible_exists, - lambda: math_ops.reduce_max(array_ops.gather(recalls, feasible)), - lambda: 0.0) - return best_recall + return self._find_max_under_constraint( + precisions, recalls, math_ops.greater_equal) def get_config(self): config = {'num_thresholds': self.num_thresholds, diff --git a/tensorflow/python/keras/metrics_confusion_matrix_test.py b/tensorflow/python/keras/metrics_confusion_matrix_test.py index 2ea6282cb27..186c3f0328f 100644 --- a/tensorflow/python/keras/metrics_confusion_matrix_test.py +++ b/tensorflow/python/keras/metrics_confusion_matrix_test.py @@ -877,15 +877,15 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase): self.assertAlmostEqual(1, self.evaluate(result)) def test_unweighted_high_sensitivity(self): - s_obj = metrics.SpecificityAtSensitivity(0.8) - pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9] + s_obj = metrics.SpecificityAtSensitivity(1.0) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_true = constant_op.constant(label_values) self.evaluate(variables.variables_initializer(s_obj.variables)) result = s_obj(y_true, y_pred) - self.assertAlmostEqual(0.4, self.evaluate(result)) + self.assertAlmostEqual(0.2, self.evaluate(result)) def test_unweighted_low_sensitivity(self): s_obj = metrics.SpecificityAtSensitivity(0.4) @@ -974,40 +974,42 @@ class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase): def test_unweighted_high_recall(self): s_obj = metrics.PrecisionAtRecall(0.8) - pred_values = [0.0, 0.1, 0.2, 0.3, 0.5, 0.4, 0.5, 0.6, 0.8, 0.9] + pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] - # For a score between 0.4 and 0.5, we expect 0.8 precision, 0.8 recall. y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_true = constant_op.constant(label_values) self.evaluate(variables.variables_initializer(s_obj.variables)) result = s_obj(y_true, y_pred) - self.assertAlmostEqual(0.8, self.evaluate(result)) + # For 0.5 < decision threshold < 0.6. + self.assertAlmostEqual(2.0/3, self.evaluate(result)) def test_unweighted_low_recall(self): - s_obj = metrics.PrecisionAtRecall(0.4) - pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.15, 0.25, 0.26, 0.26] + s_obj = metrics.PrecisionAtRecall(0.6) + pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_true = constant_op.constant(label_values) self.evaluate(variables.variables_initializer(s_obj.variables)) result = s_obj(y_true, y_pred) - self.assertAlmostEqual(0.5, self.evaluate(result)) + # For 0.2 < decision threshold < 0.5. 
+ self.assertAlmostEqual(0.75, self.evaluate(result)) @parameterized.parameters([dtypes.bool, dtypes.int32, dtypes.float32]) def test_weighted(self, label_dtype): - s_obj = metrics.PrecisionAtRecall(0.4) - pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + s_obj = metrics.PrecisionAtRecall(7.0/8) + pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] - weight_values = [2, 2, 1, 1, 1, 1, 1, 2, 2, 2] + weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2] y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_true = math_ops.cast(label_values, dtype=label_dtype) weights = constant_op.constant(weight_values) self.evaluate(variables.variables_initializer(s_obj.variables)) result = s_obj(y_true, y_pred, sample_weight=weights) - self.assertAlmostEqual(2./3., self.evaluate(result)) + # For 0.0 < decision threshold < 0.2. + self.assertAlmostEqual(0.7, self.evaluate(result)) def test_invalid_sensitivity(self): with self.assertRaisesRegexp( From c4a4aab60571ad0d018014f1b125c630f91fb5ea Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Mar 2020 02:24:31 +0900 Subject: [PATCH 155/492] minor spelling tweaks --- .../ilsvrc/generate_validation_labels.py | 2 +- .../lite/tools/benchmark/benchmark_test.cc | 8 +++---- .../tools/benchmark/benchmark_tflite_model.cc | 2 +- .../lite/tools/evaluation/evaluation_stage.h | 4 ++-- .../stages/topk_accuracy_eval_stage.h | 2 +- .../stages/utils/image_metrics_test.cc | 2 +- .../lite/tools/gen_op_registration_test.cc | 2 +- .../calibration/builtin_logging_ops/lstm.cc | 24 +++++++++---------- .../optimize/calibration/calibrator_test.cc | 2 +- .../calibration/logging_op_resolver.cc | 12 +++++----- .../lite/tools/optimize/operator_property.cc | 24 +++++++++---------- .../lite/tools/optimize/quantize_model.cc | 8 +++---- 12 files changed, 46 insertions(+), 46 deletions(-) diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/generate_validation_labels.py b/tensorflow/lite/tools/accuracy/ilsvrc/generate_validation_labels.py index c32a41e50d3..e91a9a9e7d0 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/generate_validation_labels.py +++ b/tensorflow/lite/tools/accuracy/ilsvrc/generate_validation_labels.py @@ -88,7 +88,7 @@ def main(): parser.add_argument( '--ilsvrc_devkit_dir', type=str, - help='Full path to ILSVRC 2012 devikit directory.') + help='Full path to ILSVRC 2012 devkit directory.') args = parser.parse_args() try: _check_arguments(args) diff --git a/tensorflow/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc index 663a8187148..f8d8a3dcd81 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc @@ -325,10 +325,10 @@ TEST(BenchmarkTest, DoesntCrashWithExplicitInputValueFilesStringModel) { class MaxDurationWorksTestListener : public BenchmarkListener { void OnBenchmarkEnd(const BenchmarkResults& results) override { - const int64_t num_actul_runs = results.inference_time_us().count(); - TFLITE_LOG(INFO) << "number of actual runs: " << num_actul_runs; - EXPECT_GE(num_actul_runs, 1); - EXPECT_LT(num_actul_runs, 100000000); + const int64_t num_actual_runs = results.inference_time_us().count(); + TFLITE_LOG(INFO) << "number of actual runs: " << num_actual_runs; + EXPECT_GE(num_actual_runs, 1); + EXPECT_LT(num_actual_runs, 100000000); } }; diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 
35a5f6f16ca..cd00a196337 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -488,7 +488,7 @@ BenchmarkTfLiteModel::CreateRandomTensorData(const TfLiteTensor& t, #else // You need to build with -DTFLITE_ENABLE_FP16_CPU_BENCHMARKS=1 using a // compiler that supports __fp16 type. Note: when using Clang and *not* - // linking with compiler-rt, a defintion of __gnu_h2f_ieee and + // linking with compiler-rt, a definition of __gnu_h2f_ieee and // __gnu_f2h_ieee must be supplied. TFLITE_LOG(FATAL) << "Populating the tensor " << t.name << " of type FLOAT16 is disabled."; diff --git a/tensorflow/lite/tools/evaluation/evaluation_stage.h b/tensorflow/lite/tools/evaluation/evaluation_stage.h index 5a4f546e84d..36203a69804 100644 --- a/tensorflow/lite/tools/evaluation/evaluation_stage.h +++ b/tensorflow/lite/tools/evaluation/evaluation_stage.h @@ -24,8 +24,8 @@ namespace evaluation { // Superclass for a single stage of an EvaluationPipeline. // Defines basic skeleton for sub-classes to implement. // -// Ideally EvaluationStages should obtain access to initilizer/input objects via -// Get/Set methods on pointers, and not take ownership unless necessary. +// Ideally EvaluationStages should obtain access to initializer/input objects +// via Get/Set methods on pointers, and not take ownership unless necessary. class EvaluationStage { public: // Initializes an EvaluationStage, including verifying the diff --git a/tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.h b/tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.h index a5631824d0f..b81d7e905fd 100644 --- a/tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.h @@ -62,7 +62,7 @@ class TopkAccuracyEvalStage : public EvaluationStage { private: // Updates accuracy_counts_ based on comparing top k labels and the - // groundtruth one. Using string comparision since there are some duplicate + // groundtruth one. Using string comparison since there are some duplicate // labels in the imagenet dataset. void UpdateCounts(const std::vector& topk_indices); diff --git a/tensorflow/lite/tools/evaluation/stages/utils/image_metrics_test.cc b/tensorflow/lite/tools/evaluation/stages/utils/image_metrics_test.cc index c8ee5cdca1f..d34cc0d2964 100644 --- a/tensorflow/lite/tools/evaluation/stages/utils/image_metrics_test.cc +++ b/tensorflow/lite/tools/evaluation/stages/utils/image_metrics_test.cc @@ -128,7 +128,7 @@ TEST(ImageMetricsTest, BBoxAPwithIgnoredGroundTruth) { pd.push_back({false, 100, 0.95, {{0.9, 1.9}, {0.9, 1.9}}}); - // Two gt and three pd, one pair get ignored. So it's actuallly one gt with + // Two gt and three pd, one pair get ignored. So it's actually one gt with // two pd. 
EXPECT_NEAR(0.5, AveragePrecision().FromBoxes(gt, pd), 1e-6); gt[0].ignore = kIgnoreAllMatches; diff --git a/tensorflow/lite/tools/gen_op_registration_test.cc b/tensorflow/lite/tools/gen_op_registration_test.cc index 3037963264e..fd3598fb60e 100644 --- a/tensorflow/lite/tools/gen_op_registration_test.cc +++ b/tensorflow/lite/tools/gen_op_registration_test.cc @@ -36,7 +36,7 @@ class GenOpRegistrationTest : public ::testing::Test { std::map> custom_ops_; }; -TEST_F(GenOpRegistrationTest, TestNonExistantFiles) { +TEST_F(GenOpRegistrationTest, TestNonExistentFiles) { ReadOps("/tmp/tflite_model_1234"); EXPECT_EQ(builtin_ops_.size(), 0); EXPECT_EQ(custom_ops_.size(), 0); diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc index 094ae889d70..41a03f16d63 100644 --- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc +++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc @@ -65,7 +65,7 @@ inline void LstmStepWithAuxInput( float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr, Logger* logger, - const std::vector& intemediate_tensor_indexes, + const std::vector& intermediate_tensor_indexes, ErrorReporter* error_reporter) { // Since we have already checked that weights are all there or none, we can // check the existence of only one to the get the condition. @@ -155,7 +155,7 @@ inline void LstmStepWithAuxInput( input_gate_scratch); } if (use_layer_norm) { - logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch, + logger->LogTensorValue(intermediate_tensor_indexes[0], input_gate_scratch, n_cell * n_batch, error_reporter); tensor_utils::MeanStddevNormalization( input_gate_scratch, input_gate_scratch, n_cell, n_batch); @@ -176,7 +176,7 @@ inline void LstmStepWithAuxInput( forget_gate_scratch); } if (use_layer_norm) { - logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch, + logger->LogTensorValue(intermediate_tensor_indexes[1], forget_gate_scratch, n_cell * n_batch, error_reporter); tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch, n_cell, n_batch); @@ -193,7 +193,7 @@ inline void LstmStepWithAuxInput( tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, n_batch * n_cell, cell_state_ptr); if (use_layer_norm) { - logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch, + logger->LogTensorValue(intermediate_tensor_indexes[2], cell_scratch, n_cell * n_batch, error_reporter); tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch); @@ -226,7 +226,7 @@ inline void LstmStepWithAuxInput( output_gate_scratch); } if (use_layer_norm) { - logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch, + logger->LogTensorValue(intermediate_tensor_indexes[3], output_gate_scratch, n_cell * n_batch, error_reporter); tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch, n_cell, n_batch); @@ -243,7 +243,7 @@ inline void LstmStepWithAuxInput( tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, n_batch * n_cell, output_gate_scratch); - logger->LogTensorValue(intemediate_tensor_indexes[4], output_gate_scratch, + logger->LogTensorValue(intermediate_tensor_indexes[4], output_gate_scratch, n_cell * n_batch, error_reporter); const bool use_projection_weight = (projection_weights_ptr 
!= nullptr); @@ -314,7 +314,7 @@ TfLiteStatus EvalFloat( int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, TfLiteTensor* cell_state, TfLiteTensor* output, Logger* logger, - const std::vector& intemediate_tensor_indexes, + const std::vector& intermediate_tensor_indexes, ErrorReporter* error_reporter) { TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3); int max_time, n_batch; @@ -402,7 +402,7 @@ TfLiteStatus EvalFloat( GetTensorData(activation_state), GetTensorData(cell_state), input_gate_scratch, forget_gate_scratch, cell_scratch, output_gate_scratch, - output_ptr_time, logger, intemediate_tensor_indexes, error_reporter); + output_ptr_time, logger, intermediate_tensor_indexes, error_reporter); } } else { for (int b = 0; b < n_batch; b++) { @@ -463,7 +463,7 @@ TfLiteStatus EvalFloat( n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim, activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr, forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr, - output_ptr, logger, intemediate_tensor_indexes, error_reporter); + output_ptr, logger, intermediate_tensor_indexes, error_reporter); } } } @@ -559,9 +559,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger, TfLiteTensor* output = GetOutput(context, node, ops::builtin::lstm::full::kOutputTensor); - std::vector intemediate_tensor_indexes(node->intermediates->size); + std::vector intermediate_tensor_indexes(node->intermediates->size); for (int i = 0; i < node->intermediates->size; ++i) { - intemediate_tensor_indexes[i] = node->intermediates->data[i]; + intermediate_tensor_indexes[i] = node->intermediates->data[i]; } switch (input_to_output_weights->type) { @@ -583,7 +583,7 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger, projection_bias, params, /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, scratch_buffer, activation_state, cell_state, - output, logger, intemediate_tensor_indexes, error_reporter); + output, logger, intermediate_tensor_indexes, error_reporter); } case kTfLiteUInt8: case kTfLiteInt8: diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc b/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc index 8fd1ca4fe36..e95178a13ce 100644 --- a/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc +++ b/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc @@ -165,7 +165,7 @@ TEST(CalibratorTest, MultipleInvokes) { EXPECT_NEAR(stats.at(tensor_idx).max, expected_values[tensor_idx], eps); } // Set input[0][0] = 1.5 and input[0][1] = 0.5 this should change the values - // only for input[0] and tensor 4 and ouputs 5, 6. + // only for input[0] and tensor 4 and outputs 5, 6. 
TfLiteTensor* input0 = interpreter->tensor(0); input0->data.f[0] = 1.5f; input0->data.f[1] = 0.5f; diff --git a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc index fcb48013ef0..634b2a76a3a 100644 --- a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc +++ b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc @@ -29,20 +29,20 @@ LoggingOpResolver::LoggingOpResolver( base_resolver.FindOp(op_and_version.first, op_and_version.second); BuiltinOperatorKey key = op_and_version; builtin_op_evalfn_map_[key] = base_registration->invoke; - auto logging_registation = + auto logging_registration = absl::make_unique(*base_registration); - logging_registation->invoke = logging_eval_fn; - builtin_op_registration_map_[key] = std::move(logging_registation); + logging_registration->invoke = logging_eval_fn; + builtin_op_registration_map_[key] = std::move(logging_registration); } for (const auto& op_and_version : custom_ops_to_replace) { const TfLiteRegistration* base_registration = base_resolver.FindOp( op_and_version.first.c_str(), op_and_version.second); CustomOperatorKey key = op_and_version; custom_op_evalfn_map_[key] = base_registration->invoke; - auto logging_registation = + auto logging_registration = absl::make_unique(*base_registration); - logging_registation->invoke = logging_eval_fn; - custom_op_registration_map_[key] = std::move(logging_registation); + logging_registration->invoke = logging_eval_fn; + custom_op_registration_map_[key] = std::move(logging_registration); } } diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index 0a8e52b94c0..88bf67a7a0a 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -496,8 +496,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, tensor_property_9.number_of_bits = 16; tensor_property_9.symmetric = true; // Without layer norm, we choose to quantize bias with the scale of - // input and its correpsonding weight. The other choice will - // be to ues the scale of recurrent and its correpsonding weight but we + // input and its corresponding weight. The other choice will + // be to ues the scale of recurrent and its corresponding weight but we // choose to use the smaller scale, which means higher resolution. TensorProperty tensor_property_12; tensor_property_12.use_derived_scale = true; @@ -548,7 +548,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, }; property.outputs = {{0, {}}}; property.intermediates = { - // Without layer normliazation, intermediate tensors 0, 1, 2, 3 are + // Without layer normalization, intermediate tensors 0, 1, 2, 3 are // not used and and their quantization parameters are ignored. {0, {}}, {1, {}}, @@ -563,8 +563,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, if (!op_variant.use_layer_norm && op_variant.use_projection && !op_variant.use_peephole) { // Without layer norm, we choose to quantize bias with the scale of - // input and its correpsonding weight. The other choice will - // be to ues the scale of recurrent and its correpsonding weight but we + // input and its corresponding weight. The other choice will + // be to ues the scale of recurrent and its corresponding weight but we // choose to use the smaller scale, which means higher resolution. 
TensorProperty tensor_property_12; tensor_property_12.use_derived_scale = true; @@ -612,7 +612,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, }; property.outputs = {{0, {}}}; property.intermediates = { - // Without layer normliazation, intermediate tensors 0, 1, 2, 3 are + // Without layer normalization, intermediate tensors 0, 1, 2, 3 are // not used and their quantization parameters are ignored. {0, {}}, {1, {}}, @@ -630,8 +630,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, tensor_property_9.number_of_bits = 16; tensor_property_9.symmetric = true; // Without layer norm, we choose to quantize bias with the scale of - // input and its correpsonding weight. The other choice will - // be to ues the scale of recurrent and its correpsonding weight but we + // input and its corresponding weight. The other choice will + // be to ues the scale of recurrent and its corresponding weight but we // choose to use the smaller scale, which means higher resolution. TensorProperty tensor_property_12; tensor_property_12.use_derived_scale = true; @@ -676,7 +676,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, }; property.outputs = {{0, {}}}; property.intermediates = { - // Without layer normliazation, intermediate tensors 0, 1, 2, 3 are + // Without layer normalization, intermediate tensors 0, 1, 2, 3 are // not used and their quantization parameters are ignored. {0, {}}, {1, {}}, @@ -696,8 +696,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, if (!op_variant.use_layer_norm && !op_variant.use_projection && !op_variant.use_peephole) { // Without layer norm, we choose to quantize bias with the scale of - // input and its correpsonding weight. The other choice will - // be to ues the scale of recurrent and its correpsonding weight but we + // input and its corresponding weight. The other choice will + // be to ues the scale of recurrent and its corresponding weight but we // choose to use the smaller scale, which means higher resolution. TensorProperty tensor_property_12; tensor_property_12.use_derived_scale = true; @@ -739,7 +739,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, }; property.outputs = {{0, {}}}; property.intermediates = { - // Without layer normliazation, intermediate tensors 0, 1, 2, 3 are + // Without layer normalization, intermediate tensors 0, 1, 2, 3 are // not used and their quantization parameters are ignored. {0, {}}, {1, {}}, diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc index 6cc00dc2f26..5ebc513f5bc 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.cc +++ b/tensorflow/lite/tools/optimize/quantize_model.cc @@ -347,7 +347,7 @@ TfLiteStatus ApplyConstraints(ModelT* model, // Add requant op before this input. // There are better ways to handle this, which is to try to push the - // rescale upwards recurrsively and hope all upstream ops can absort + // rescale upwards recursively and hope all upstream ops can absort // this rescale.and only add requant when there is no other way. std::unique_ptr requant_op; utils::MakeQuantizeOperator(model, &requant_op, op->inputs[input_idx], @@ -747,9 +747,9 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model, // Quantize tensros that have shared range. For example, in LSTM, the output // tensor and input state tensor should share the same range because they are // using the same scale and zero point. 
-// We have to model this explicitely because the output is modeled as an extra +// We have to model this explicitly because the output is modeled as an extra // tensor in LSTM. In calibrator, state tensors are logged both before and after -// the inferece so the range is fully captured. But output, although it is +// the inference so the range is fully captured. But output, although it is // identical to activation, is not a state tensor the input value (range) of the // very first inference is not captured. TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) { @@ -1073,7 +1073,7 @@ TfLiteStatus EnsureBiasScaleCompatibility( return kTfLiteError; } - // Get input scale for assymmetric quantization. + // Get input scale for asymmetric quantization. QuantizationParametersT temp_quant_params = QuantizationParametersT(); utils::GetAsymmetricQuantizationParams( input_tensor->quantization->min[0], From 1f28da1bd5f48cf56dbddb69ca38c878d4c47776 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Wed, 18 Mar 2020 10:24:23 -0700 Subject: [PATCH 156/492] Utilize registered delegate providers to populate default parameters in benchmark_test. PiperOrigin-RevId: 301615877 Change-Id: Ibd3d32346c0bd090a339b14206fa5c672df4e447 --- tensorflow/lite/tools/benchmark/BUILD | 1 + tensorflow/lite/tools/benchmark/benchmark_test.cc | 14 +++++--------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index f9e33c74aa2..b8b79cc34df 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -106,6 +106,7 @@ cc_test( deps = [ ":benchmark_performance_options", ":benchmark_tflite_model_lib", + ":delegate_provider_hdr", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", "//tensorflow/lite/testing:util", diff --git a/tensorflow/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc index 663a8187148..1f28c6a6a24 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/tools/benchmark/benchmark_performance_options.h" #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/lite/tools/benchmark/delegate_provider.h" #include "tensorflow/lite/tools/command_line_flags.h" namespace { @@ -70,28 +71,23 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs, BenchmarkParam::Create("")); params.AddParam("input_layer_value_files", BenchmarkParam::Create("")); - params.AddParam("use_hexagon", BenchmarkParam::Create(false)); - params.AddParam("use_xnnpack", BenchmarkParam::Create(false)); - params.AddParam("use_nnapi", BenchmarkParam::Create(false)); params.AddParam("allow_fp16", BenchmarkParam::Create(false)); params.AddParam("require_full_delegation", BenchmarkParam::Create(false)); params.AddParam("warmup_min_secs", BenchmarkParam::Create(0.5f)); params.AddParam("use_legacy_nnapi", BenchmarkParam::Create(false)); - params.AddParam("use_gpu", BenchmarkParam::Create(false)); params.AddParam("enable_op_profiling", BenchmarkParam::Create(false)); params.AddParam("max_profiling_buffer_entries", BenchmarkParam::Create(1024)); - params.AddParam("nnapi_accelerator_name", - BenchmarkParam::Create("")); - params.AddParam("nnapi_execution_preference", - BenchmarkParam::Create("")); - params.AddParam("disable_nnapi_cpu", BenchmarkParam::Create(false)); params.AddParam("max_delegated_partitions", BenchmarkParam::Create(0)); params.AddParam("profiling_output_csv_file", BenchmarkParam::Create("")); params.AddParam("enable_platform_tracing", BenchmarkParam::Create(false)); + + for (const auto& delegate_provider : GetRegisteredDelegateProviders()) { + delegate_provider->AddParams(¶ms); + } return params; } From 50ae62b65a614414e9ce81f64f4daf0d590a27c6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Mar 2020 02:33:15 +0900 Subject: [PATCH 157/492] minor spelling tweaks --- tensorflow/lite/examples/label_image/label_image.cc | 2 +- tensorflow/lite/external_cpu_backend_context.h | 4 ++-- tensorflow/lite/graph_info.h | 2 +- tensorflow/lite/interpreter.cc | 2 +- tensorflow/lite/model_test.cc | 2 +- tensorflow/lite/util.h | 6 +++--- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc index 2501c85bb81..b493fafa839 100644 --- a/tensorflow/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -132,7 +132,7 @@ void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t subgraph_index, uint32_t op_index, TfLiteRegistration registration) { // output something like - // time (ms) , Node xxx, OpCode xxx, symblic name + // time (ms) , Node xxx, OpCode xxx, symbolic name // 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3) diff --git a/tensorflow/lite/external_cpu_backend_context.h b/tensorflow/lite/external_cpu_backend_context.h index 3348f677413..c667057a48c 100644 --- a/tensorflow/lite/external_cpu_backend_context.h +++ b/tensorflow/lite/external_cpu_backend_context.h @@ -25,7 +25,7 @@ namespace tflite { // This is the base class for TF Lite internal backend contexts (like a // RUY-based cpu backend context class). A derived internal backend context is // generally a collection of utilities (i.e. a thread pool etc.) 
for TF Lite to -// use certain keneral libraries, such as Gemmlowp, RUY, etc., to implement TF +// use certain kernel libraries, such as Gemmlowp, RUY, etc., to implement TF // Lite operators. class TfLiteInternalBackendContext { public: @@ -68,7 +68,7 @@ class TfLiteInternalBackendContext { // the #thread info in the global cpu backend context (i.e. 'global_ctxt' above) // that affects how much parallelism an interpreter invocation will use. // Therefore, if different number of threads are used among different -// interpreters, don't call 'SetNumThreads' consectutively but call it +// interpreters, don't call 'SetNumThreads' consecutively but call it // separately between each interpreter's invocation as illustrated above. // // Note: it is the responsibility of the user of this context (i.e. a diff --git a/tensorflow/lite/graph_info.h b/tensorflow/lite/graph_info.h index 91533be8a22..cf84c1466af 100644 --- a/tensorflow/lite/graph_info.h +++ b/tensorflow/lite/graph_info.h @@ -41,7 +41,7 @@ class GraphInfo { // num_nodes(). virtual const TfLiteNode& node(size_t index) const = 0; - // Returns an implementation-speicfic node index which may be different from + // Returns an implementation-specific node index which may be different from // index. virtual size_t node_index(size_t index) const = 0; diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index d333fa736e3..db47af8014c 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -136,7 +136,7 @@ void Interpreter::SetExternalContext(TfLiteExternalContextType type, // If it's overwritten here, we will release the resource of the internally // owned external context. // Note: the 'max thread count' info associated with the overwritten context - // will be lost here, and such info is now detemined by the new context, thus + // will be lost here, and such info is now determined by the new context, thus // affecting how much parallelism a TFLite op would have. if (kTfLiteCpuBackendContext == type && external_contexts_[kTfLiteCpuBackendContext] == diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc index b9efdf676a8..8e87ada3faf 100644 --- a/tensorflow/lite/model_test.cc +++ b/tensorflow/lite/model_test.cc @@ -71,7 +71,7 @@ class TrivialResolver : public OpResolver { TfLiteRegistration* constant_return_; }; -TEST(BasicFlatBufferModel, TestNonExistantFiles) { +TEST(BasicFlatBufferModel, TestNonExistentFiles) { ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234")); } diff --git a/tensorflow/lite/util.h b/tensorflow/lite/util.h index 3b042eb5986..1b68f699662 100644 --- a/tensorflow/lite/util.h +++ b/tensorflow/lite/util.h @@ -44,7 +44,7 @@ TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector& input); // Converts an array (of the given size) to a `TfLiteIntArray`. The caller // takes ownership of the returned pointer, and must make sure 'dims' has at -// least 'rank' elemnts. +// least 'rank' elements. TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims); // Checks whether a `TfLiteIntArray` and an int array have matching elements. @@ -66,8 +66,8 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, size_t* bytes); // Creates a stub TfLiteRegistration instance with the provided -// `custom_op_name`. The op will fail if invoked, and is useful as a placeholde -// to defer op resolution. +// `custom_op_name`. The op will fail if invoked, and is useful as a +// placeholder to defer op resolution. 
// Note that `custom_op_name` must remain valid for the returned op's lifetime.. TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name); From c74e9eca851b03970616417aa13b0f9eb95f7219 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Wed, 18 Mar 2020 10:30:09 -0700 Subject: [PATCH 158/492] Apply name change(experimental_run_v2 -> run) for all callers in Tensorflow. PiperOrigin-RevId: 301617216 Change-Id: I17624641a96dac369b52833d9c72c3d4d2172596 --- tensorflow/python/distribute/tpu_strategy_test.py | 4 ++-- tensorflow/python/tpu/tpu.py | 2 +- .../training/experimental/loss_scaling_gradient_tape.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index bec96e4eece..4b88ae7134a 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -137,7 +137,7 @@ class TPUStrategyTest(test.TestCase): def computation(): return random_ops.random_gamma([10], [0.5, 1.5]) - return strategy.experimental_run_v2(computation) + return strategy.run(computation) with self.assertRaisesRegexp(errors.InvalidArgumentError, "TPU compilation failed"): @@ -149,7 +149,7 @@ class TPUStrategyTest(test.TestCase): def computation(): return random_ops.random_normal([10]) - return strategy.experimental_run_v2(computation) + return strategy.run(computation) good_run() diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index 5c3a61d5d8d..fe8fac794db 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -223,7 +223,7 @@ def tpu_replicated_input_resolver(op, resource_reads, resource_writes): return False # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput # with the actual replicated inputs. This allows ACD to correct add control - # deps when there are multiple calls to `experimental_run_v2` in a + # deps when there are multiple calls to `run` in a # `tf.function`. def replace_with_unreplicated_resources(resource_inputs): """Replaces handles in `resource_inputs` with their unreplicated inputs.""" diff --git a/tensorflow/python/training/experimental/loss_scaling_gradient_tape.py b/tensorflow/python/training/experimental/loss_scaling_gradient_tape.py index a2502b8a43f..730f3bef9bc 100644 --- a/tensorflow/python/training/experimental/loss_scaling_gradient_tape.py +++ b/tensorflow/python/training/experimental/loss_scaling_gradient_tape.py @@ -40,7 +40,7 @@ def _convert_to_per_replicas(distribution, values): Returns: `values`, but each element has been converted to a PerReplica value. """ - return distribution.experimental_run_v2( + return distribution.run( lambda values: [array_ops.identity(v) for v in values], args=(values,) ) From 4679feb3ce8e888917562357c547047c24b21df0 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Wed, 18 Mar 2020 10:31:53 -0700 Subject: [PATCH 159/492] Use same var key in _create_slots/get_slot in V1 optimizer We have special handling for distributed variable in get_slot, but not create_slot. This happens to work before but upcoming change in distributed library will break it. 
PiperOrigin-RevId: 301617592 Change-Id: I3324e926b7695f8ad0de696eecaff20df6c62ea7 --- tensorflow/python/BUILD | 4 +++ tensorflow/python/training/optimizer.py | 38 ++++++++++---------- tensorflow/python/training/optimizer_test.py | 26 ++++++++++++++ 3 files changed, 48 insertions(+), 20 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index d932899ab0d..94d52a8ab06 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5220,6 +5220,7 @@ py_library( "//tensorflow/python/distribute:distribute_coordinator_context", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/keras/optimizer_v2:learning_rate_schedule", @@ -6590,6 +6591,9 @@ cuda_py_tests( ":variable_scope", ":variables", "//tensorflow/core:protos_all_py", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:values", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index f1a31d01dd4..f89dc362cf8 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -27,6 +27,7 @@ import six from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util +from tensorflow.python.distribute import values as ds_values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -81,10 +82,17 @@ def _deduplicate_indexed_slices(values, indices): def _var_key(var): - # TODO(ashankar): Consolidate handling for eager and graph - if hasattr(var, "op"): + """Returns slot key for `var`.""" + # pylint: disable=protected-access + if hasattr(var, "_distributed_container"): + var = var._distributed_container() + if ops.executing_eagerly_outside_functions(): + return var._unique_id + if ds_values.is_distributed_variable(var): + return (var.graph, var._shared_name) + else: return (var.op.graph, var.op.name) - return var._unique_id # pylint: disable=protected-access + # pylint: enable=protected-access @six.add_metaclass(abc.ABCMeta) @@ -751,26 +759,16 @@ class Optimizer( Returns: The `Variable` for the slot if it was created, `None` otherwise. """ - # pylint: disable=protected-access named_slots = self._slots.get(name, None) if not named_slots: return None - - if hasattr(var, "_distributed_container"): - # NOTE: If this isn't patched, then there is no `handle` in - # `_resource_apply_dense`. 
- distributed_container = var._distributed_container() - assert distributed_container is not None - if ops.executing_eagerly_outside_functions(): - key = distributed_container._unique_id - else: - key = (distributed_container.graph, distributed_container._shared_name) - # pylint: enable=protected-access - mirrored_slot = named_slots.get(key, None) - if mirrored_slot is None: return None - return mirrored_slot._get_closest() # pylint: disable=protected-access - - return named_slots.get(_var_key(var), None) + slot = named_slots.get(_var_key(var), None) + if (ds_values.is_distributed_variable(slot) and + not ds_values.is_distributed_variable(var)): + # Make sure var and slot are either both DistributedVariable, or both + # per replica variables. + slot = slot._get_closest() # pylint: disable=protected-access + return slot def get_slot_names(self): """Return a list of the names of slots created by the `Optimizer`. diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index 5775d0b8091..30fa5cd0388 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import cross_device_ops +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import values as ds_values from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -29,6 +32,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent @@ -269,6 +273,28 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-0.1, -0.1], self.evaluate(var0)) self.assertAllClose([0., 0.], self.evaluate(var1)) + @test_util.run_deprecated_v1 + def testGetSlotUnderDistributedStrategy(self): + # Only run this test in graph mode so we don't need actual GPU. + ds = mirrored_strategy.MirroredStrategy( + ['CPU:0', 'GPU:0'], + cross_device_ops=cross_device_ops.HierarchicalCopyAllReduce()) + # We need an optimizer that creates slots. + optimizer = adam.AdamOptimizer() + + def f(): + v = variables.Variable([1.0]) + self.assertTrue(ds_values.is_distributed_variable(v)) + # Slot variables are created in the first call to apply_gradients. + optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)]) + self.assertTrue(optimizer.get_slot_names()) + for name in optimizer.get_slot_names(): + slot = optimizer.get_slot(v, name) + self.assertIsNotNone(slot) + self.assertTrue(ds_values.is_distributed_variable(slot)) + + ds.experimental_run_v2(f) + if __name__ == '__main__': test.main() From 85f7677b4ac3ebbe444711adfd7a45f18b1b6b2b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Mar 2020 10:56:23 -0700 Subject: [PATCH 160/492] [XLA:Python] Update tpu_driver to add the same automatic tupling of arguments and untupling of results present in the local client. Update tests to use the automatic untupling support. 
PiperOrigin-RevId: 301623333 Change-Id: I1233e6a63eaea2bfef2ac7a85bf1b55b820361d1 --- .../python/tpu_driver/client/tpu_client.cc | 58 ++- .../xla/python/tpu_driver/client/tpu_client.h | 12 +- .../tpu_driver/client/tpu_client_extension.cc | 6 +- tensorflow/compiler/xla/python/xla.cc | 35 -- tensorflow/compiler/xla/python/xla_client.py | 28 +- .../compiler/xla/python/xla_client_test.py | 387 +++++++++--------- 6 files changed, 275 insertions(+), 251 deletions(-) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 706db57c4ac..56ac640cb6c 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -227,8 +227,8 @@ StatusOr> PyTpuBuffer::FromLiterals( /* static */ StatusOr> PyTpuBuffer::MakeTuple( - const std::vector buffers, - std::shared_ptr client, std::shared_ptr device) { + absl::Span buffers, std::shared_ptr client, + std::shared_ptr device) { std::vector child_shapes; std::vector> child_device_buffers; std::vector child_handle_ptrs; @@ -611,8 +611,8 @@ Status WaitForExecuteEvent(tpu_driver::Event* event) { return opt_status.value(); } -StatusOr> PyTpuExecutable::Execute( - absl::Span argument_handles) { +StatusOr>> PyTpuExecutable::Execute( + absl::Span argument_handles, bool tuple_arguments) { if (num_replicas() != 1) { return InvalidArgument( "Attempted to execute computation with %d replicas using Execute().", @@ -624,9 +624,18 @@ StatusOr> PyTpuExecutable::Execute( num_partitions()); } - std::vector all_core_arguments(argument_handles.begin(), - argument_handles.end()); + std::vector all_core_arguments; + std::unique_ptr tupled_arguments; + if (tuple_arguments) { + TF_ASSIGN_OR_RETURN(tupled_arguments, + PyTpuBuffer::MakeTuple(argument_handles, client_, + local_devices_.front())); + all_core_arguments = {tupled_arguments.get()}; + } else { + all_core_arguments = std::vector(argument_handles.begin(), + argument_handles.end()); + } ExecuteResult result = ExecuteHelper(absl::MakeSpan(&all_core_arguments, 1), argument_handles, /*replica=*/0, /*partition=*/0, RunId()); @@ -638,12 +647,19 @@ StatusOr> PyTpuExecutable::Execute( return status; } - return std::move(result.buffer); + if (result.buffer->on_host_shape().IsTuple()) { + return result.buffer->DestructureTuple(); + } else { + std::vector> outputs; + outputs.push_back(std::move(result.buffer)); + return outputs; + } } -StatusOr>> +StatusOr>>> PyTpuExecutable::ExecuteOnLocalDevices( - absl::Span> argument_handles) { + absl::Span> argument_handles, + bool tuple_arguments) { tensorflow::profiler::TraceMe traceme( "PyTpuExecutable::ExecuteOnLocalDevices"); @@ -661,6 +677,20 @@ PyTpuExecutable::ExecuteOnLocalDevices( << " num_partitions=" << num_partitions() << " num_local_devices=" << num_local_devices; + std::vector> tupled_arguments; + std::vector> tupled_argument_pointers; + if (tuple_arguments) { + tupled_arguments.resize(argument_handles.size()); + tupled_argument_pointers.resize(argument_handles.size()); + for (int i = 0; i < num_local_devices; ++i) { + TF_ASSIGN_OR_RETURN(tupled_arguments[i], + PyTpuBuffer::MakeTuple(argument_handles[i], client_, + local_devices_.at(i))); + tupled_argument_pointers[i] = {tupled_arguments[i].get()}; + } + argument_handles = tupled_argument_pointers; + } + absl::Mutex results_lock; std::vector results(num_local_devices); @@ -702,9 +732,15 @@ PyTpuExecutable::ExecuteOnLocalDevices( } VLOG(1) << "Replicated execution 
complete."; - std::vector> wrapped_results(num_local_devices); + std::vector>> wrapped_results( + num_local_devices); for (int i = 0; i < num_local_devices; ++i) { - wrapped_results[i] = std::move(results[i].buffer); + if (results[i].buffer->on_host_shape().IsTuple()) { + TF_ASSIGN_OR_RETURN(wrapped_results[i], + results[i].buffer->DestructureTuple()); + } else { + wrapped_results[i].push_back(std::move(results[i].buffer)); + } } return wrapped_results; } diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 4b7670707fb..2b1ac4a3044 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -166,7 +166,7 @@ class PyTpuBuffer { // Supports nested tuple creation. static StatusOr> MakeTuple( - const std::vector buffers, + absl::Span buffers, std::shared_ptr client, std::shared_ptr device); PyTpuBuffer() = delete; @@ -308,15 +308,17 @@ class PyTpuExecutable { // TODO(power): Both Execute and ExecutePerOnLocalDevices block and wait // inside for computation to finish. Coordinate with JAX code change to see if // we can make both Execute and ExecutePerReplica non-blocking. - StatusOr> Execute( - absl::Span argument_handles); + StatusOr>> Execute( + absl::Span argument_handles, bool tuple_arguments); // Execute on local devices. Takes a sequence of argument lists (one argument // list per local device) and returns a tuple of results (one result per local // device). The number of argument lists must be equal to the local device // count. - StatusOr>> ExecuteOnLocalDevices( - absl::Span> argument_handles); + StatusOr>>> + ExecuteOnLocalDevices( + absl::Span> argument_handles, + bool tuple_arguments); void Delete() { executables_.clear(); } diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 0dcb9dc4c84..b4e8afb5853 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -203,9 +203,11 @@ PYBIND11_MODULE(tpu_client_extension, m) { &PyTpuExecutable::SizeOfGeneratedCodeInBytes) .def("Delete", &PyTpuExecutable::Delete) .def("Execute", &PyTpuExecutable::Execute, - py::call_guard(), py::arg("arguments")) + py::call_guard(), py::arg("arguments"), + py::arg("tuple_arguments")) .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices, - py::call_guard(), py::arg("arguments")); + py::call_guard(), py::arg("arguments"), + py::arg("tuple_arguments")); py::class_>(m, "TpuDevice") .def_property_readonly("coords", &TpuDevice::coords) diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index ceea02f2374..d42636cde79 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -1133,20 +1133,6 @@ PYBIND11_MODULE(xla_extension, m) { .def("SizeOfGeneratedCodeInBytes", &PyLocalExecutable::SizeOfGeneratedCodeInBytes) .def("Delete", &PyLocalExecutable::Delete) - .def( - "Execute", - [](const PyLocalExecutable& executable, - absl::Span args) - -> StatusOr> { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN( - std::vector> output, - executable.Execute(args, ExecuteOptions())); - return WrapWithClient(executable.client()->shared_from_this(), - std::move(output.front())); - }, - py::arg("arguments")) - // TODO(phawkins): remove 
in favor of overload that returns a vector. .def( "Execute", [](const PyLocalExecutable& executable, @@ -1168,27 +1154,6 @@ PYBIND11_MODULE(xla_extension, m) { return outputs; }, py::arg("arguments"), py::arg("tuple_arguments")) - // TODO(phawkins): remove in favor of overload that returns a vector. - .def( - "ExecuteOnLocalDevices", - [](const PyLocalExecutable& executable, - absl::Span> args) - -> StatusOr>> { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN( - std::vector>> - output_buffers, - executable.ExecuteOnLocalDevices(args, ExecuteOptions())); - std::vector> outputs; - outputs.reserve(output_buffers.size()); - for (auto& buffers : output_buffers) { - outputs.push_back( - WrapWithClient(executable.client()->shared_from_this(), - std::move(buffers.front()))); - } - return outputs; - }, - py::arg("arguments")) .def( "ExecuteOnLocalDevices", [](const PyLocalExecutable& executable, diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index b6948b6d84d..d4df503677c 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -42,7 +42,6 @@ from tensorflow.compiler.xla.python.xla_extension import ops # consistency with XLA. # pylint: disable=invalid-name - profiler = _xla.profiler @@ -454,8 +453,8 @@ def transfer_to_infeed(value, device=None): Args: value: the value that the caller would like to enqueue into the XLA infeed queue - device: the device to infeed the value to. Each device has a - distinct infeed queue. + device: the device to infeed the value to. Each device has a distinct infeed + queue. """ # TODO(phawkins): support non-default backends. backend = get_local_backend() @@ -501,7 +500,6 @@ def computation_count(): '''Returns the number of computations per replica.''' """ - Device = _xla.Device @@ -633,7 +631,8 @@ def execute_with_python_values(executable, arguments=(), backend=None): arg, device=executable.local_devices()[0], backend=backend) arguments = [put(arg) for arg in arguments] - return executable.Execute(arguments).to_py() + outputs = executable.Execute(arguments, tuple_arguments=False) + return [x.to_py() for x in outputs] def execute_with_python_values_replicated(executable, arguments, backend=None): @@ -641,8 +640,8 @@ def execute_with_python_values_replicated(executable, arguments, backend=None): Arguments: executable: the program to run. - arguments: a list of lists of Python values indexed by - `[replica][arg_num]` to pass as inputs. + arguments: a list of lists of Python values indexed by `[replica][arg_num]` + to pass as inputs. backend: the backend we are targeting. Returns: @@ -661,7 +660,8 @@ def execute_with_python_values_replicated(executable, arguments, backend=None): for replica_args in arguments: arg_buffers.append(flat_arg_buffers[:len(replica_args)]) flat_arg_buffers = flat_arg_buffers[len(replica_args):] - return [out.to_py() for out in executable.ExecuteOnLocalDevices(arg_buffers)] + return [[x.to_py() for x in xs] for xs in executable.ExecuteOnLocalDevices( + arg_buffers, tuple_arguments=False)] class PaddingType(enum.Enum): @@ -787,6 +787,7 @@ class ComputationBuilder(object): shape: a `Shape` describing the shape of the infed value. token: an optional `XlaOp` representing a token after which the infeed effect should be sequenced. + Returns: An XlaOp, representing a (value, token) pair. """ @@ -805,6 +806,7 @@ class ComputationBuilder(object): operand: an `XlaOp` representing the data to outfeed. 
token: an `XlaOp` representing a token after which the outfeed should be sequenced. + Returns: An `XlaOp` representing a token. """ @@ -880,7 +882,10 @@ class ComputationBuilder(object): """ return self.Constant(np.array(value, dtype=np.bool)) - def ParameterWithShape(self, shape, name=None, parameter_num=None, + def ParameterWithShape(self, + shape, + name=None, + parameter_num=None, replicated=False): """Enqueues a Parameter op onto the computation, given a shape. @@ -891,8 +896,8 @@ class ComputationBuilder(object): next linear parameter number is used. The default value capability can be used for auto-numbering. If you're using auto-numbering for some parameters, use it for *all* parameters to avoid clashes. - replicated: whether to mark the parameter's leaves as replicated. May be - a bool, in which case it applies to all leaves, or an iterable of bools. + replicated: whether to mark the parameter's leaves as replicated. May be a + bool, in which case it applies to all leaves, or an iterable of bools. Returns: An XlaOp. @@ -1791,6 +1796,7 @@ def register_custom_call_target(name, fn, platform='cpu'): """ _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform]) + # Deprecated. Use register_custom_call_target instead. register_cpu_custom_call_target = register_custom_call_target diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 72b536ade68..b28a97837fe 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -55,12 +55,14 @@ class ComputationTest(absltest.TestCase): def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): assert expected is not None - result = self._Execute(c, arguments) - # Numpy's comparison methods are a bit too lenient by treating inputs as - # "array-like", meaning that scalar 4 will be happily compared equal to - # [[4]]. We'd like to be more strict so assert shapes as well. - self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) - assert_func(result, expected) + results = self._Execute(c, arguments) + self.assertLen(results, len(expected)) + for result, e in zip(results, expected): + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict so assert shapes as well. 
+ self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) + assert_func(result, e) def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) @@ -166,32 +168,32 @@ class ComputationsWithConstantsTest(ComputationTest): def testConstantScalarSumS8(self): c = self._NewComputation() c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) - self._ExecuteAndCompareExact(c, expected=np.int8(3)) + self._ExecuteAndCompareExact(c, expected=[np.int8(3)]) def testConstantScalarSumBF16(self): c = self._NewComputation() c.Add(c.Constant(bfloat16(1.11)), c.Constant(bfloat16(3.14))) - self._ExecuteAndCompareClose(c, expected=bfloat16(4.25)) + self._ExecuteAndCompareClose(c, expected=[bfloat16(4.25)]) def testConstantScalarSumF32(self): c = self._NewComputation() c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=4.25) + self._ExecuteAndCompareClose(c, expected=[4.25]) def testConstantScalarSumF64(self): c = self._NewComputation() c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=4.25) + self._ExecuteAndCompareClose(c, expected=[4.25]) def testConstantScalarSumS32(self): c = self._NewComputation() c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) - self._ExecuteAndCompareClose(c, expected=3) + self._ExecuteAndCompareClose(c, expected=[3]) def testConstantScalarSumS64(self): c = self._NewComputation() c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) - self._ExecuteAndCompareClose(c, expected=3) + self._ExecuteAndCompareClose(c, expected=[3]) def testConstantVectorMulF16(self): c = self._NewComputation() @@ -199,108 +201,108 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(np.array([2.5, 3.3, -1.2, 0.7], np.float16)), c.Constant(np.array([-1.2, 2, -2, -3], np.float16))) self._ExecuteAndCompareClose( - c, expected=np.array([-3, 6.6, 2.4, -2.1], np.float16), rtol=2e-3) + c, expected=[np.array([-3, 6.6, 2.4, -2.1], np.float16)], rtol=2e-3) def testConstantVectorMulF32(self): c = self._NewComputation() c.Mul( c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) - self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + self._ExecuteAndCompareClose(c, expected=[[-3, 6.6, 2.4, -2.1]]) def testConstantVectorMulF64(self): c = self._NewComputation() c.Mul( c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) - self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + self._ExecuteAndCompareClose(c, expected=[[-3, 6.6, 2.4, -2.1]]) def testConstantVectorScalarDivF32(self): c = self._NewComputation() c.Div( c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), c.ConstantF32Scalar(2.0)) - self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + self._ExecuteAndCompareClose(c, expected=[[0.75, 1.25, 1.5, -5.4]]) def testConstantVectorScalarDivF64(self): c = self._NewComputation() c.Div( c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), c.ConstantF64Scalar(2.0)) - self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + self._ExecuteAndCompareClose(c, expected=[[0.75, 1.25, 1.5, -5.4]]) def testConstantVectorScalarPowF32(self): c = self._NewComputation() c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) - self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) def 
testConstantVectorScalarPowF64(self): c = self._NewComputation() c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) - self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) def testIota(self): c = self._NewComputation() c.Iota(np.float32, 10) - self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) + self._ExecuteAndCompareExact(c, expected=[np.arange(10, dtype=np.float32)]) def testBroadcastedIota(self): c = self._NewComputation() c.BroadcastedIota(np.int64, (2, 3), 1) expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64) - self._ExecuteAndCompareExact(c, expected=expected) + self._ExecuteAndCompareExact(c, expected=[expected]) def testBooleanAnd(self): c = self._NewComputation() c.And( c.Constant(NumpyArrayBool([True, False, True, False])), c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) def testBooleanOr(self): c = self._NewComputation() c.Or( c.Constant(NumpyArrayBool([True, False, True, False])), c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) def testBooleanXor(self): c = self._NewComputation() c.Xor( c.Constant(NumpyArrayBool([True, False, True, False])), c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) def testSum2DF32(self): c = self._NewComputation() c.Add( c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) - self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) def testShiftLeft(self): c = self._NewComputation() c.ShiftLeft(c.Constant(NumpyArrayS32([3])), c.Constant(NumpyArrayS32([2]))) - self._ExecuteAndCompareClose(c, expected=[12]) + self._ExecuteAndCompareClose(c, expected=[[12]]) def testShiftRightArithmetic(self): c = self._NewComputation() c.ShiftRightArithmetic( c.Constant(NumpyArrayS32([-2])), c.Constant(NumpyArrayS32([1]))) - self._ExecuteAndCompareClose(c, expected=[-1]) + self._ExecuteAndCompareClose(c, expected=[[-1]]) def testShiftRightLogical(self): c = self._NewComputation() c.ShiftRightLogical( c.Constant(NumpyArrayS32([-1])), c.Constant(NumpyArrayS32([1]))) - self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) + self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) def testSum2DF64(self): c = self._NewComputation() c.Add( c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) - self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) def testSum2DWith1DBroadcastDim0F32(self): # sum of a 2D array with a 1D array where the latter is replicated across @@ -311,7 +313,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(NumpyArrayF32([10, 20, 30])), broadcast_dimensions=(0,)) self._ExecuteAndCompareClose( - c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) def testSum2DWith1DBroadcastDim0F64(self): # sum of a 2D array with a 1D array where the latter is replicated 
across @@ -322,7 +324,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(NumpyArrayF64([10, 20, 30])), broadcast_dimensions=(0,)) self._ExecuteAndCompareClose( - c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) def testSum2DWith1DBroadcastDim1F32(self): # sum of a 2D array with a 1D array where the latter is replicated across @@ -333,7 +335,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(NumpyArrayF32([10, 20, 30])), broadcast_dimensions=(1,)) self._ExecuteAndCompareClose( - c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) def testSum2DWith1DBroadcastDim1F64(self): # sum of a 2D array with a 1D array where the latter is replicated across @@ -344,7 +346,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(NumpyArrayF64([10, 20, 30])), broadcast_dimensions=(1,)) self._ExecuteAndCompareClose( - c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) def testConstantAxpyF32(self): c = self._NewComputation() @@ -353,7 +355,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.ConstantF32Scalar(2), c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), c.Constant(NumpyArrayF32([100, -100, 200, -200]))) - self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + self._ExecuteAndCompareClose(c, expected=[[104.4, -93.4, 208.8, -189]]) def testConstantAxpyF64(self): c = self._NewComputation() @@ -362,7 +364,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.ConstantF64Scalar(2), c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), c.Constant(NumpyArrayF64([100, -100, 200, -200]))) - self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + self._ExecuteAndCompareClose(c, expected=[[104.4, -93.4, 208.8, -189]]) def testCustomCall(self): c = self._NewComputation() @@ -376,7 +378,7 @@ class ComputationsWithConstantsTest(ComputationTest): xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), )) - self._ExecuteAndCompareClose(c, expected=0.75) + self._ExecuteAndCompareClose(c, expected=[0.75]) class ParametersTest(ComputationTest): @@ -400,7 +402,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareClose( c, arguments=[self.f32_scalar_2, self.f32_4vector], - expected=[-4.6, 6.6, -8.6, 10.6]) + expected=[[-4.6, 6.6, -8.6, 10.6]]) def testScalarTimesVectorAutonumberF64(self): c = self._NewComputation() @@ -410,7 +412,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareClose( c, arguments=[self.f64_scalar_2, self.f64_4vector], - expected=[-4.6, 6.6, -8.6, 10.6]) + expected=[[-4.6, 6.6, -8.6, 10.6]]) def testScalarTimesVectorS32(self): c = self._NewComputation() @@ -420,7 +422,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareExact( c, arguments=[self.s32_scalar_3, self.s32_4vector], - expected=[30, 45, -6, 21]) + expected=[[30, 45, -6, 21]]) def testScalarTimesVectorS64(self): c = self._NewComputation() @@ -430,7 +432,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareExact( c, arguments=[self.s64_scalar_3, self.s64_4vector], - expected=[30, 45, -6, 21]) + expected=[[30, 45, -6, 21]]) def testScalarMinusVectorExplicitNumberingF32(self): # Use explicit numbering and pass parameter_num first. 
Sub is used since @@ -443,7 +445,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareClose( c, arguments=[self.f32_scalar_2, self.f32_4vector], - expected=[-4.3, 1.3, -6.3, 3.3]) + expected=[[-4.3, 1.3, -6.3, 3.3]]) def testScalarMinusVectorExplicitNumberingF64(self): # Use explicit numbering and pass parameter_num first. Sub is used since @@ -456,28 +458,22 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareClose( c, arguments=[self.f64_scalar_2, self.f64_4vector], - expected=[-4.3, 1.3, -6.3, 3.3]) + expected=[[-4.3, 1.3, -6.3, 3.3]]) class BufferTest(ComputationTest): """Tests focusing on execution with Buffers.""" - def _Execute(self, c, arguments): - compiled_c = c.Build().Compile() - arg_buffers = [xla_client.Buffer.from_pyval(arg) for arg in arguments] - result_buffer = compiled_c.Execute(arg_buffers) - return result_buffer.to_py() - def testConstantSum(self): c = self._NewComputation() c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=4.25) + self._ExecuteAndCompareClose(c, expected=[4.25]) def testOneParameterSum(self): c = self._NewComputation() c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) self._ExecuteAndCompareClose( - c, arguments=[NumpyArrayF32(1.11)], expected=4.25) + c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) def testTwoParameterSum(self): c = self._NewComputation() @@ -485,8 +481,10 @@ class BufferTest(ComputationTest): c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ParameterFromNumpy(NumpyArrayF32(0.))) self._ExecuteAndCompareClose( - c, arguments=[NumpyArrayF32(1.11), - NumpyArrayF32(3.14)], expected=4.25) + c, + arguments=[NumpyArrayF32(1.11), + NumpyArrayF32(3.14)], + expected=[4.25]) def testCannotCallWithDeletedBuffers(self): c = self._NewComputation() @@ -496,7 +494,7 @@ class BufferTest(ComputationTest): arg_buffer = xla_client.Buffer.from_pyval(arg) arg_buffer.delete() with self.assertRaises(RuntimeError): - compiled_c.Execute([arg_buffer]) + compiled_c.Execute([arg_buffer], tuple_arguments=False) def testDestructureTupleEmpty(self): device = xla_client.get_local_backend().devices()[0] @@ -646,7 +644,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])), ) c.Concatenate(args, dimension=0) - self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]) def testConcatenateF64(self): c = self._NewComputation() @@ -655,7 +653,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])), ) c.Concatenate(args, dimension=0) - self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]) def testConvertElementType(self): xla_types = { @@ -672,11 +670,12 @@ class SingleOpTest(ComputationTest): c.ConvertElementType(x, xla_types[dst_dtype]) result = xla_client.execute_with_python_values(c.Build().Compile()) + self.assertLen(result, 1) expected = np.array(template, dtype=dst_dtype) - self.assertEqual(result.shape, expected.shape) - self.assertEqual(result.dtype, expected.dtype) - np.testing.assert_equal(result, expected) + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) x = [0, 1, 0, 0, 1] for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): @@ -699,11 +698,12 @@ class SingleOpTest(ComputationTest): 
c.BitcastConvertType(x, dst_etype) result = xla_client.execute_with_python_values(c.Build().Compile()) + self.assertLen(result, 1) expected = np.array(template, src_dtype).view(dst_dtype) - self.assertEqual(result.shape, expected.shape) - self.assertEqual(result.dtype, expected.dtype) - np.testing.assert_equal(result, expected) + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) x = [0, 1, 0, 0, 1] for xla_types in [xla_x32_types, xla_x64_types]: @@ -720,7 +720,7 @@ class SingleOpTest(ComputationTest): for lhs in samples[:1]: c = self._NewComputation() c.AllToAll(c.Constant(lhs), 0, 0) - self._ExecuteAndCompareExact(c, expected=lhs) + self._ExecuteAndCompareExact(c, expected=[lhs]) def testCrossReplicaSumOneReplica(self): samples = [ @@ -732,12 +732,12 @@ class SingleOpTest(ComputationTest): for lhs in samples: c = self._NewComputation() c.CrossReplicaSum(c.Constant(lhs)) - self._ExecuteAndCompareExact(c, expected=lhs) + self._ExecuteAndCompareExact(c, expected=[lhs]) def testReplicaId(self): c = self._NewComputation() _ = c.ReplicaId() - self._ExecuteAndCompareExact(c, expected=0) + self._ExecuteAndCompareExact(c, expected=[0]) def testCrossReplicaSumOneReplicaWithSingletonGroup(self): samples = [ @@ -749,35 +749,35 @@ class SingleOpTest(ComputationTest): for lhs in samples: c = self._NewComputation() c.CrossReplicaSum(c.Constant(lhs), [[0]]) - self._ExecuteAndCompareExact(c, expected=lhs) + self._ExecuteAndCompareExact(c, expected=[lhs]) def testDotMatrixVectorF32(self): c = self._NewComputation() lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) rhs = NumpyArrayF32([[10.0], [20.0]]) c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) def testDotMatrixVectorF64(self): c = self._NewComputation() lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) rhs = NumpyArrayF64([[10.0], [20.0]]) c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) def testDotMatrixMatrixF32(self): c = self._NewComputation() lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) def testDotMatrixMatrixF64(self): c = self._NewComputation() lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) def testDotGeneral(self): c = self._NewComputation() @@ -786,7 +786,7 @@ class SingleOpTest(ComputationTest): rhs = NumpyArrayF32(rng.randn(10, 4, 5)) dimension_numbers = (([2], [1]), ([0], [0])) c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) def testDotGeneralWithDotDimensionNumbersProto(self): c = self._NewComputation() @@ -801,7 +801,7 @@ class SingleOpTest(ComputationTest): dimension_numbers.rhs_batch_dimensions.append(0) c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, 
rhs), rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) def testDotGeneralWithPrecisionConfig(self): c = self._NewComputation() @@ -817,7 +817,7 @@ class SingleOpTest(ComputationTest): c.Constant(rhs), dimension_numbers, precision_config=config) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) def testConvF32Same(self): c = self._NewComputation() @@ -831,7 +831,7 @@ class SingleOpTest(ComputationTest): [880., 940., 1000., 380.], [1120., 1180., 1240., 460.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvF32Valid(self): c = self._NewComputation() @@ -844,7 +844,7 @@ class SingleOpTest(ComputationTest): [640., 700., 760.], [1120., 1180., 1240.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvWithGeneralPaddingF32(self): c = self._NewComputation() @@ -864,7 +864,7 @@ class SingleOpTest(ComputationTest): [0., 0., 0.], [40., 50., 0.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvGeneralDilatedF32(self): c = self._NewComputation() @@ -885,7 +885,7 @@ class SingleOpTest(ComputationTest): [0., 0., 0.], [40., 50., 0.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvGeneralDilatedF32WithPrecisionConfig(self): c = self._NewComputation() @@ -915,7 +915,7 @@ class SingleOpTest(ComputationTest): [0., 0., 0.], [40., 50., 0.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvGeneralDilatedPermutedF32(self): c = self._NewComputation() @@ -933,7 +933,8 @@ class SingleOpTest(ComputationTest): pads, lhs_dilation, rhs_dilation, dimension_numbers) result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], [40., 50., 0.]]]]) - self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + self._ExecuteAndCompareClose( + c, expected=[np.transpose(result, (1, 3, 0, 2))]) def testConvGeneralDilatedGroupedConvolutionF32(self): c = self._NewComputation() @@ -960,92 +961,92 @@ class SingleOpTest(ComputationTest): [0., 0., 0.], [480., 530., 220.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) c.Not(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=~arr) + self._ExecuteAndCompareClose(c, expected=[~arr]) def testPopulationCount(self): c = self._NewComputation() arr = NumpyArrayS32([3, 0, 1]) c.PopulationCount(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.array([2, 0, 1])) + self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])]) def testCountLeadingZeros(self): c = self._NewComputation() arr = NumpyArrayS32([0x7FFF, 0x12345678]) c.Clz(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[17, 3]) + self._ExecuteAndCompareClose(c, expected=[[17, 3]]) def testExp(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Exp(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) def testExpm1(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Expm1(c.Constant(arr)) - self._ExecuteAndCompareClose(c, 
expected=np.expm1(arr)) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) def testRound(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Round(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.round(arr)) + self._ExecuteAndCompareClose(c, expected=[np.round(arr)]) def testLog(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Log(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.log(arr)) + self._ExecuteAndCompareClose(c, expected=[np.log(arr)]) def testLog1p(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Log1p(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) + self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)]) def testNeg(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Neg(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=-arr) + self._ExecuteAndCompareClose(c, expected=[-arr]) def testFloor(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Floor(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.floor(arr)) + self._ExecuteAndCompareClose(c, expected=[np.floor(arr)]) def testCeil(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Ceil(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) + self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)]) def testAbs(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) c.Abs(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.abs(arr)) + self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) def testTanh(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Tanh(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) def testTrans(self): def _TransposeAndTest(array): c = self._NewComputation() c.Trans(c.Constant(array)) - self._ExecuteAndCompareClose(c, expected=array.T) + self._ExecuteAndCompareClose(c, expected=[array.T]) # Test square and non-square matrices in both default (C) and F orders. 
for array_fun in [NumpyArrayF32, NumpyArrayF64]: @@ -1060,7 +1061,7 @@ class SingleOpTest(ComputationTest): c = self._NewComputation() c.Transpose(c.Constant(array), permutation) expected = np.transpose(array, permutation) - self._ExecuteAndCompareClose(c, expected=expected) + self._ExecuteAndCompareClose(c, expected=[expected]) _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) @@ -1077,14 +1078,14 @@ class SingleOpTest(ComputationTest): c.Eq( c.Constant(NumpyArrayS32([1, 2, 3, 4])), c.Constant(NumpyArrayS32([4, 2, 3, 1]))) - self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) def testNe(self): c = self._NewComputation() c.Ne( c.Constant(NumpyArrayS32([1, 2, 3, 4])), c.Constant(NumpyArrayS32([4, 2, 3, 1]))) - self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]]) c.Ne( c.Constant(NumpyArrayF32([-2.0, 0.0, @@ -1092,42 +1093,44 @@ class SingleOpTest(ComputationTest): float("nan")])), c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) self._ExecuteAndAssertWith( - np.testing.assert_allclose, c, (), expected=[True, False, True, True]) + np.testing.assert_allclose, c, (), expected=[[True, False, True, True]]) def testGt(self): c = self._NewComputation() c.Gt( c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) + self._ExecuteAndCompareExact( + c, expected=[[False, True, True, False, False]]) def testGe(self): c = self._NewComputation() c.Ge( c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) + self._ExecuteAndCompareExact(c, expected=[[True, True, True, False, False]]) def testLt(self): c = self._NewComputation() c.Lt( c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) + self._ExecuteAndCompareExact( + c, expected=[[False, False, False, True, True]]) def testLe(self): c = self._NewComputation() c.Le( c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, True, True]]) def testMax(self): c = self._NewComputation() c.Max( c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) - self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) + self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]]) def testMaxExplicitBroadcastDim0(self): c = self._NewComputation() @@ -1135,7 +1138,8 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), c.Constant(NumpyArrayF32([3, 4, 5])), broadcast_dimensions=(0,)) - self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) + self._ExecuteAndCompareExact( + c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]]) def testMaxExplicitBroadcastDim1(self): c = self._NewComputation() @@ -1143,14 +1147,15 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), c.Constant(NumpyArrayF32([3, 4, 5])), 
broadcast_dimensions=(1,)) - self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) + self._ExecuteAndCompareExact( + c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]]) def testMin(self): c = self._NewComputation() c.Min( c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) - self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) + self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]]) def testPad(self): c = self._NewComputation() @@ -1159,8 +1164,8 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)]) self._ExecuteAndCompareClose( c, - expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], - [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) def testPadWithPaddingConfig(self): c = self._NewComputation() @@ -1176,8 +1181,8 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32(0.0)), padding_config) self._ExecuteAndCompareClose( c, - expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], - [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) def testReshape(self): c = self._NewComputation() @@ -1185,14 +1190,14 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), dimensions=[0, 1], new_sizes=[2, 3]) - self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]]) def testCollapse(self): c = self._NewComputation() c.Collapse( c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), dimensions=[1, 2]) - self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]]) def testRev(self): c = self._NewComputation() @@ -1200,7 +1205,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), dimensions=[0, 2]) self._ExecuteAndCompareExact( - c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) + c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]]) def testReducePrecision(self): c = self._NewComputation() @@ -1208,7 +1213,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), exponent_bits=8, mantissa_bits=7) - self._ExecuteAndCompareClose(c, expected=[float.fromhex("0x1.32p-3")]) + self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]]) def testClampF32(self): c = self._NewComputation() @@ -1216,7 +1221,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32(-1)), c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), c.Constant(NumpyArrayF32(2))) - self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) def testClampS32(self): c = self._NewComputation() @@ -1224,7 +1229,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayS32(-1)), c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), c.Constant(NumpyArrayS32(2))) - self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) def testSelect(self): c = self._NewComputation() @@ -1232,14 +1237,14 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayBool([True, False, False, True, False])), 
c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) - self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) + self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]]) def testSlice(self): c = self._NewComputation() c.Slice( c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], [3, 2]) - self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) def testSliceInDim(self): c = self._NewComputation() @@ -1249,21 +1254,21 @@ class SingleOpTest(ComputationTest): limit_index=2, stride=1, dimno=1) - self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]]) + self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]]) c.SliceInDim( c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), start_index=0, limit_index=3, stride=2, dimno=0) - self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]]) def testDynamicSlice(self): c = self._NewComputation() c.DynamicSlice( c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), c.Constant(NumpyArrayS32([1, 0])), [2, 2]) - self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) def testDynamicUpdateSlice(self): c = self._NewComputation() @@ -1271,7 +1276,8 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), c.Constant(NumpyArrayS32([1, 1]))) - self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) + self._ExecuteAndCompareExact( + c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]]) def testTuple(self): c = self._NewComputation() @@ -1279,7 +1285,7 @@ class SingleOpTest(ComputationTest): c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), c.Constant(NumpyArrayBool([True, False, False, True]))) result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertIsInstance(result, tuple) + self.assertLen(result, 3) np.testing.assert_equal(result[0], 42) np.testing.assert_allclose(result[1], [1.0, 2.0]) np.testing.assert_equal(result[2], [True, False, False, True]) @@ -1290,20 +1296,20 @@ class SingleOpTest(ComputationTest): c.Tuple( c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), c.Constant(NumpyArrayBool([True, False, False, True]))), 1) - self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]]) def testBroadcast(self): c = self._NewComputation() c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) self._ExecuteAndCompareExact( - c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) + c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]]) def testBroadcastInDim(self): c = self._NewComputation() c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0]) - self._ExecuteAndCompareExact(c, expected=[[1, 1], [2, 2]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]]) c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [1]) - self._ExecuteAndCompareExact(c, expected=[[1, 2], [1, 2]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]]) def testRngNormal(self): shape = (2, 3) @@ -1314,8 +1320,9 @@ class SingleOpTest(ComputationTest): dims=shape) result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape and uniqueness - 
self.assertEqual(result.shape, shape) - self.assertLen(np.unique(result), np.prod(shape)) + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) def testRngUniformF32(self): lo, hi = 2., 4. @@ -1327,10 +1334,11 @@ class SingleOpTest(ComputationTest): dims=shape) result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape, uniqueness, and range - self.assertEqual(result.shape, shape) - self.assertLen(np.unique(result), np.prod(shape)) - self.assertTrue(np.all(lo <= result)) - self.assertTrue(np.all(result < hi)) + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) def testRngUniformS32(self): lo, hi = 2, 4 @@ -1342,24 +1350,25 @@ class SingleOpTest(ComputationTest): dims=shape) result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape, integrality, and range - self.assertEqual(result.shape, shape) - self.assertEqual(result.dtype, np.int32) - self.assertTrue(np.all(lo <= result)) - self.assertTrue(np.all(result < hi)) + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertEqual(result[0].dtype, np.int32) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) def testCholesky(self): l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], dtype=np.float32) c = self._NewComputation() c.Cholesky(c.Constant(np.dot(l, l.T))) - self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) + self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4) def testSort(self): keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) c = self._NewComputation() c.Sort(c.Constant(keys)) self._ExecuteAndCompareClose( - c, expected=np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)) + c, expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)]) def testSortKeyVal(self): keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) @@ -1367,7 +1376,7 @@ class SingleOpTest(ComputationTest): c = self._NewComputation() c.Sort((c.Constant(keys), c.Constant(values)), dimension=0) result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertIsInstance(result, tuple) + self.assertLen(result, 2) np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) @@ -1387,7 +1396,7 @@ class SingleOpTest(ComputationTest): dimension=1, comparator=comparator) result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertIsInstance(result, tuple) + self.assertLen(result, 2) np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) @@ -1437,12 +1446,14 @@ class SingleOpTest(ComputationTest): transpose_a=True) self._ExecuteAndCompareClose( c, - expected=np.array([ - [0.5, 0.08333334, 0.04629629, 0.03367003], - [2.5, -0.25, -0.1388889, -0.1010101], - [4.5, -0.58333331, -0.32407406, -0.23569024], + expected=[ + np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], + dtype=np.float32) ], - dtype=np.float32), rtol=1e-4) def testIsConstant(self): @@ -1467,7 +1478,7 @@ class SingleOpTest(ComputationTest): dnums.index_vector_dim = 2 c 
= self._NewComputation() c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) - g = self._Execute(c, ()) + g, = self._Execute(c, ()) expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) np.testing.assert_allclose(g, expected, rtol=1e-4) @@ -1480,30 +1491,30 @@ class SingleOpTest(ComputationTest): c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:]) self._ExecuteAndCompareClose( - c, expected=np.fft.fftn(a, axes=(1, 2, 3)), rtol=1e-4) + c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) # IFFT c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:]) self._ExecuteAndCompareClose( - c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), rtol=1e-4) + c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) # RFFT b = rng.randn(*shape).astype(np.float32) c = self._NewComputation() c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:]) self._ExecuteAndCompareClose( - c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), rtol=1e-4) + c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) # IRFFT c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8]) self._ExecuteAndCompareClose( - c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), rtol=1e-4) + c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4) def testNextAfter(self): c = self._NewComputation() c.NextAfter( c.Constant(np.array([1, 2], dtype=np.float32)), c.Constant(np.array([2, 1], dtype=np.float32))) - out = self._Execute(c, ()) + out, = self._Execute(c, ()) eps = np.finfo(np.float32).eps np.testing.assert_equal(np.array([eps + 1, 2 - eps], dtype=np.float32), out) @@ -1515,7 +1526,7 @@ class SingleOpTest(ComputationTest): c.RegularizedIncompleteBeta(c.Constant(a), c.Constant(b), c.Constant(x)) expected = np.array( [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) - self._ExecuteAndCompareClose(c, expected=expected, rtol=1e-4) + self._ExecuteAndCompareClose(c, expected=[expected], rtol=1e-4) class EmbeddedComputationsTest(ComputationTest): @@ -1656,38 +1667,38 @@ class EmbeddedComputationsTest(ComputationTest): c.Call( self._CreateMulF32By2Computation(), operands=(c.ConstantF32Scalar(5.0),)) - self._ExecuteAndCompareClose(c, expected=10.0) + self._ExecuteAndCompareClose(c, expected=[10.0]) def testCallF64(self): c = self._NewComputation() c.Call( self._CreateMulF64By2Computation(), operands=(c.ConstantF64Scalar(5.0),)) - self._ExecuteAndCompareClose(c, expected=10.0) + self._ExecuteAndCompareClose(c, expected=[10.0]) def testMapEachElementToS32Constant(self): c = self._NewComputation() c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], self._CreateConstantS32Computation(), [0]) - self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) def testMapEachElementToS64Constant(self): c = self._NewComputation() c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], self._CreateConstantS64Computation(), [0]) - self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) def testMapMulBy2F32(self): c = self._NewComputation() c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], self._CreateMulF32By2Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) def testMapMulBy2F64(self): c = self._NewComputation() c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], self._CreateMulF64By2Computation(), [0]) - 
self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) def testSimpleMapChainF32(self): # Chains a map of constant-f32 with a map of mul-by-2 @@ -1695,7 +1706,7 @@ class EmbeddedComputationsTest(ComputationTest): const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], self._CreateConstantF32Computation(), [0]) c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) def testSimpleMapChainF64(self): # Chains a map of constant-f64 with a map of mul-by-2 @@ -1703,21 +1714,21 @@ class EmbeddedComputationsTest(ComputationTest): const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], self._CreateConstantF64Computation(), [0]) c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) def testDivVectorsWithMapF32(self): c = self._NewComputation() c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), self._CreateBinaryDivF32Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + self._ExecuteAndCompareClose(c, expected=[[0.2, 0.4, 0.75, 1.0]]) def testDivVectorsWithMapF64(self): c = self._NewComputation() c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), self._CreateBinaryDivF64Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + self._ExecuteAndCompareClose(c, expected=[[0.2, 0.4, 0.75, 1.0]]) def testSelectAndScatterF32(self): c = self._NewComputation() @@ -1730,7 +1741,7 @@ class EmbeddedComputationsTest(ComputationTest): source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), init_value=c.Constant(NumpyArrayF32(1)), scatter=self._CreateBinaryAddF32Computation()) - self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) + self._ExecuteAndCompareClose(c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]]) def testSelectAndScatterF64(self): c = self._NewComputation() @@ -1743,7 +1754,7 @@ class EmbeddedComputationsTest(ComputationTest): source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), init_value=c.Constant(NumpyArrayF64(1)), scatter=self._CreateBinaryAddF64Computation()) - self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) + self._ExecuteAndCompareClose(c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]]) def testReduce1DtoScalarF32(self): c = self._NewComputation() @@ -1752,7 +1763,7 @@ class EmbeddedComputationsTest(ComputationTest): init_value=c.ConstantF32Scalar(0), computation_to_apply=self._CreateBinaryAddF32Computation(), dimensions=[0]) - self._ExecuteAndCompareClose(c, expected=10) + self._ExecuteAndCompareClose(c, expected=[10]) def testReduce1DtoScalarF64(self): c = self._NewComputation() @@ -1761,7 +1772,7 @@ class EmbeddedComputationsTest(ComputationTest): init_value=c.ConstantF64Scalar(0), computation_to_apply=self._CreateBinaryAddF64Computation(), dimensions=[0]) - self._ExecuteAndCompareClose(c, expected=10) + self._ExecuteAndCompareClose(c, expected=[10]) def testReduce2DTo1DDim0F32(self): input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1771,7 +1782,7 @@ class EmbeddedComputationsTest(ComputationTest): init_value=c.ConstantF32Scalar(0), computation_to_apply=self._CreateBinaryAddF32Computation(), dimensions=[0]) - 
self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + self._ExecuteAndCompareClose(c, expected=[[5, 7, 9]]) def testReduce2DTo1DDim0F64(self): input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1781,7 +1792,7 @@ class EmbeddedComputationsTest(ComputationTest): init_value=c.ConstantF64Scalar(0), computation_to_apply=self._CreateBinaryAddF64Computation(), dimensions=[0]) - self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + self._ExecuteAndCompareClose(c, expected=[[5, 7, 9]]) def testReduce2DTo1DDim1F32(self): input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1791,7 +1802,7 @@ class EmbeddedComputationsTest(ComputationTest): init_value=c.ConstantF32Scalar(0), computation_to_apply=self._CreateBinaryAddF32Computation(), dimensions=[1]) - self._ExecuteAndCompareClose(c, expected=[6, 15]) + self._ExecuteAndCompareClose(c, expected=[[6, 15]]) def testReduce2DTo1DDim1F64(self): input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1801,7 +1812,7 @@ class EmbeddedComputationsTest(ComputationTest): init_value=c.ConstantF64Scalar(0), computation_to_apply=self._CreateBinaryAddF64Computation(), dimensions=[1]) - self._ExecuteAndCompareClose(c, expected=[6, 15]) + self._ExecuteAndCompareClose(c, expected=[[6, 15]]) def testReduce3DAllPossibleWaysF32(self): input_array = self._MakeSample3DArrayF32() @@ -1814,7 +1825,7 @@ class EmbeddedComputationsTest(ComputationTest): computation_to_apply=self._CreateBinaryAddF32Computation(), dimensions=dims) self._ExecuteAndCompareClose( - c, expected=np.sum(input_array, axis=tuple(dims))) + c, expected=[np.sum(input_array, axis=tuple(dims))]) _ReduceAndTest(0) _ReduceAndTest(0, 1) @@ -1833,7 +1844,7 @@ class EmbeddedComputationsTest(ComputationTest): computation_to_apply=self._CreateBinaryAddF64Computation(), dimensions=dims) self._ExecuteAndCompareClose( - c, expected=np.sum(input_array, axis=tuple(dims))) + c, expected=[np.sum(input_array, axis=tuple(dims))]) _ReduceAndTest(0) _ReduceAndTest(0) @@ -1852,7 +1863,7 @@ class EmbeddedComputationsTest(ComputationTest): window_dimensions=(2, 1), window_strides=(1, 1), padding=xla_client.PaddingType.VALID) - self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) def testReduceWindowSameUnitStridesF32(self): input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1864,7 +1875,7 @@ class EmbeddedComputationsTest(ComputationTest): window_dimensions=(2, 1), window_strides=(1, 1), padding=xla_client.PaddingType.SAME) - self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) def testReduceWindowValidGeneralStridesF32(self): input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1876,7 +1887,7 @@ class EmbeddedComputationsTest(ComputationTest): window_dimensions=(2, 1), window_strides=(1, 2), padding=xla_client.PaddingType.VALID) - self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) + self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) def testReduceWindowValidUnitStridesF64(self): input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1888,7 +1899,7 @@ class EmbeddedComputationsTest(ComputationTest): window_dimensions=(2, 1), window_strides=(1, 1), padding=xla_client.PaddingType.VALID) - self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) def testReduceWindowSameUnitStridesF64(self): input_array = NumpyArrayF64([[1.0, 
2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1900,7 +1911,7 @@ class EmbeddedComputationsTest(ComputationTest): window_dimensions=(2, 1), window_strides=(1, 1), padding=xla_client.PaddingType.SAME) - self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) def testReduceWindowValidGeneralStridesF64(self): input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -1912,7 +1923,7 @@ class EmbeddedComputationsTest(ComputationTest): window_dimensions=(2, 1), window_strides=(1, 2), padding=xla_client.PaddingType.VALID) - self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) + self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) def testWhileF32(self): cond = self._CreateTestF32Lt10Computation() @@ -1920,7 +1931,7 @@ class EmbeddedComputationsTest(ComputationTest): c = self._NewComputation() init = c.ConstantF32Scalar(1.) c.While(cond, body, init) - self._ExecuteAndCompareClose(c, expected=16.) + self._ExecuteAndCompareClose(c, expected=[16.]) def testWhileF64(self): cond = self._CreateTestF64Lt10Computation() @@ -1928,7 +1939,7 @@ class EmbeddedComputationsTest(ComputationTest): c = self._NewComputation() init = c.ConstantF64Scalar(1.) c.While(cond, body, init) - self._ExecuteAndCompareClose(c, expected=16.) + self._ExecuteAndCompareClose(c, expected=[16.]) def testConditionalTrue(self): c = self._NewComputation() @@ -1939,7 +1950,7 @@ class EmbeddedComputationsTest(ComputationTest): false_computation = self._CreateConstantF32Computation() c.Conditional(pred, true_operand, true_computation, false_operand, false_computation) - self._ExecuteAndCompareClose(c, expected=6.) + self._ExecuteAndCompareClose(c, expected=[6.]) def testConditionalFalse(self): c = self._NewComputation() @@ -1950,7 +1961,7 @@ class EmbeddedComputationsTest(ComputationTest): false_computation = self._CreateConstantF32Computation() c.Conditional(pred, true_operand, true_computation, false_operand, false_computation) - self._ExecuteAndCompareClose(c, expected=1.) 
+ self._ExecuteAndCompareClose(c, expected=[1.]) def testInfeedS32Values(self): to_infeed = NumpyArrayS32([1, 2, 3, 4]) @@ -1961,7 +1972,7 @@ class EmbeddedComputationsTest(ComputationTest): xla_client.transfer_to_infeed(item) for item in to_infeed: - result = xla_client.execute_with_python_values(compiled_c) + result, = xla_client.execute_with_python_values(compiled_c) self.assertEqual(result, item) def testInfeedTuple(self): @@ -1972,6 +1983,7 @@ class EmbeddedComputationsTest(ComputationTest): xla_client.transfer_to_infeed(to_infeed) result = xla_client.execute_with_python_values(compiled_c) + self.assertLen(result, 2) np.testing.assert_equal(result[0], to_infeed[0]) np.testing.assert_equal(result[1], to_infeed[1]) @@ -1986,7 +1998,8 @@ class EmbeddedComputationsTest(ComputationTest): compiled_c = c.Build().Compile() for want in to_round_trip: - execution = threading.Thread(target=lambda: compiled_c.Execute([])) + execution = threading.Thread( + target=lambda: compiled_c.Execute([], tuple_arguments=False)) execution.start() xla_client.transfer_to_infeed(want) got = xla_client.transfer_from_outfeed( @@ -2010,7 +2023,7 @@ class EmbeddedComputationsTest(ComputationTest): c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), self._CreateBinaryAddS32Computation(), dnums) expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) - self._ExecuteAndCompareClose(c, expected=expected) + self._ExecuteAndCompareClose(c, expected=[expected]) class ErrorTest(ComputationTest): @@ -2063,7 +2076,7 @@ class ComputationRootTest(ComputationTest): arg = NumpyArrayF32(1.0) compiled_c = c.Build(result).Compile() - ans = xla_client.execute_with_python_values(compiled_c, [arg]) + ans, = xla_client.execute_with_python_values(compiled_c, [arg]) np.testing.assert_allclose(ans, 4.14) @@ -2086,7 +2099,7 @@ class SetShardingTest(ComputationTest): extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable arg = NumpyArrayF32(1.0) compiled_c = c.Build(result).Compile() - ans = xla_client.execute_with_python_values(compiled_c, [arg]) + ans, = xla_client.execute_with_python_values(compiled_c, [arg]) np.testing.assert_allclose(ans, 4.14) From d8bccdb1b8e90e07bf7cb4232f9e120e28492d4b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Mar 2020 11:33:00 -0700 Subject: [PATCH 161/492] [XLA:Python] Remove PyLocalBuffer.make_tuple and PyLocalBuffer.destructure() from the API. Since Execute() now supports tupling and untupling, we no longer need tuples in the Python API. This change is in preparation for changing the aliasing behavior of Execute(). 
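A minimal usage sketch of the list-based convention that replaces make_tuple/destructure, assuming the ComputationBuilder/Build/Compile calls exercised by the tests above (the import path is taken from the file locations in this series): execute_with_python_values returns a flat list with one numpy array per output leaf, so there is no tuple buffer left to destructure.

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    # Build a computation whose root is a two-element tuple.
    c = xla_client.ComputationBuilder("tuple_outputs")
    c.Tuple(c.ConstantS32Scalar(42),
            c.Constant(np.array([1.0, 2.0], dtype=np.float32)))
    compiled = c.Build().Compile()

    # The result is a flat list of numpy arrays, one per output leaf;
    # unpacking it replaces the old make_tuple/destructure round trip.
    scalar, vector = xla_client.execute_with_python_values(compiled)
    np.testing.assert_equal(scalar, 42)
    np.testing.assert_allclose(vector, [1.0, 2.0])

This mirrors testTuple and testSortKeyVal above, which now assert on the length of the returned list instead of checking isinstance(result, tuple).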
PiperOrigin-RevId: 301631669 Change-Id: Idec8c5ebf0025052d6c0cef523f2c77c92e89e0a --- .../python/tpu_driver/client/tpu_client.py | 3 - .../tpu_driver/client/tpu_client_extension.cc | 16 ---- tensorflow/compiler/xla/python/xla.cc | 31 ------ tensorflow/compiler/xla/python/xla_client.py | 18 ---- .../compiler/xla/python/xla_client_test.py | 95 ------------------- 5 files changed, 163 deletions(-) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py index 9e44a3d7aed..2c4be78c9c5 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py @@ -86,9 +86,6 @@ class TpuBackend(xla_client.Backend): device = self.client.local_devices()[0] return _tpu_client.PyTpuBuffer.from_python(pyval, self.client, device) - def make_tuple(self, c_buffers, device): - return _tpu_client.PyTpuBuffer.make_tuple(c_buffers, self.client, device) - def compile(self, c_computation, compile_options): options = _xla.ExecutableBuildOptions() options.num_replicas = compile_options.num_replicas diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index b4e8afb5853..752ea4c4907 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -125,21 +125,6 @@ PYBIND11_MODULE(tpu_client_extension, m) { std::move(leaves), tree.shape, std::move(py_buffer_ref), std::move(client), std::move(device)); }) - .def_static("make_tuple", - [](const std::vector buffers, - std::shared_ptr client, - std::shared_ptr device) - -> StatusOr> { - CHECK(device != nullptr); - auto iter = client->id_to_device().find(device->id()); - if (iter->second != device) { - return InvalidArgument( - "Cannot make tuple on device '%s' with '%s' backend", - device->DebugString(), client->platform_name()); - } - return PyTpuBuffer::MakeTuple(buffers, std::move(client), - std::move(device)); - }) .def("copy_to_device", [](PyTpuBuffer* buffer, std::shared_ptr dst_device) { CHECK(dst_device != nullptr); @@ -148,7 +133,6 @@ PYBIND11_MODULE(tpu_client_extension, m) { return buffer->CopyToDevice(std::move(dst_device)); }) .def("delete", &PyTpuBuffer::Delete) - .def("destructure", &PyTpuBuffer::DestructureTuple) .def("block_host_until_ready", [](PyTpuBuffer* buffer) { GlobalPyRefManager()->CollectGarbage(); diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index d42636cde79..60952c393ab 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -960,23 +960,6 @@ PYBIND11_MODULE(xla_extension, m) { }, py::arg("argument"), py::arg("client"), py::arg("device"), py::arg("force_copy") = false) - .def_static( - "make_tuple", - [](std::vector buffers, - std::shared_ptr client, - Device* device) -> StatusOr> { - CHECK(device != nullptr); - auto iter = client->id_to_device().find(device->id()); - if (iter->second != device) { - return InvalidArgument( - "Cannot make tuple on device '%s' with '%s' backend", - device->DebugString(), client->platform_name()); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr buffer, - PyLocalBuffer::MakeTuple(buffers, client.get(), device)); - return WrapWithClient(std::move(client), std::move(buffer)); - }) .def("copy_to_device", [](PyLocalBuffer* buffer, const ClientAndPtr& dst_device) -> 
StatusOr> { @@ -988,20 +971,6 @@ PYBIND11_MODULE(xla_extension, m) { return WrapWithClient(dst_device.client, std::move(out)); }) .def("delete", &PyLocalBuffer::Delete) - .def("destructure", - [](const PyLocalBuffer& buffer) - -> StatusOr>> { - TF_ASSIGN_OR_RETURN( - std::vector> parts, - buffer.DestructureTuple()); - std::vector> output; - output.reserve(parts.size()); - for (auto& part : parts) { - output.push_back(WrapWithClient( - buffer.client()->shared_from_this(), std::move(part))); - } - return std::move(output); - }) .def("block_host_until_ready", [](PyLocalBuffer* buffer) { GlobalPyRefManager()->CollectGarbage(); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index d4df503677c..a7e8903b113 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -76,10 +76,6 @@ class Backend(object, metaclass=abc.ABCMeta): def buffer_from_pyval(self, pyval, device=None, force_copy=False): """Allocates a fresh buffer and populates it with `pyval`.""" - @abc.abstractmethod - def make_tuple(self, c_buffers, device): - """Makes a tuple from a sequence of backend buffer objects.""" - @abc.abstractmethod def compile(self, computation, compile_options): """Compiles a computation. Returns an executable.""" @@ -137,9 +133,6 @@ class LocalBackend(Backend): return _xla.PyLocalBuffer.from_python(pyval, self.client, device, force_copy) - def make_tuple(self, c_buffers, device): - return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device) - def compile(self, c_computation, compile_options): options = _xla.ExecutableBuildOptions() options.num_replicas = compile_options.num_replicas @@ -396,18 +389,12 @@ class Buffer(object): backend = backend or get_local_backend() return backend.buffer_from_pyval(pyval, device, force_copy=force_copy) - @staticmethod - def make_tuple(buffers, device, backend=None): - backend = backend or get_local_backend() - return backend.make_tuple(buffers, device) - # Buffer is not an instantiable type and exists only for its static methods. # The underlying buffer objects are C++ object with the following # API: # def shape(self) -> Shape: # def device(self) -> int: # def delete(self): - # def destructure(self) -> [Buffer] # def is_deleted(self) -> bool: # def block_host_until_ready(self): # """Blocks the calling thread until the buffer is ready on device.""" @@ -426,11 +413,6 @@ class Buffer(object): # clients call methods on Backend to create buffers. -# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops -# compatibility with Jaxlib versions older than 0.1.13. 
-LocalBuffer = Buffer - - def shape_from_pyval(pyval): """Returns a Shape that describes a tuple-tree of Numpy arrays.""" diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index b28a97837fe..848e8c881d2 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -496,84 +496,6 @@ class BufferTest(ComputationTest): with self.assertRaises(RuntimeError): compiled_c.Execute([arg_buffer], tuple_arguments=False) - def testDestructureTupleEmpty(self): - device = xla_client.get_local_backend().devices()[0] - local_buffer = xla_client.Buffer.make_tuple((), device=device) - pieces = local_buffer.destructure() - self.assertFalse(local_buffer.is_deleted()) - self.assertEmpty(pieces) - - def testDestructureTupleOneArrayElement(self): - device = xla_client.get_local_backend().devices()[0] - t = xla_client.Buffer.from_pyval(np.array([1, 2, 3, 4], dtype=np.int32)) - local_buffer = xla_client.Buffer.make_tuple((t,), device) - pieces = local_buffer.destructure() - self.assertFalse(local_buffer.is_deleted()) - self.assertLen(pieces, 1) - array = pieces[0] - got = array.to_py() - want = NumpyArrayS32([1, 2, 3, 4]) - np.testing.assert_equal(want, got) - - def testDestructureTupleTwoArrayElementDifferentType(self): - device = xla_client.get_local_backend().devices()[0] - t = ( - xla_client.Buffer.from_pyval( - np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)), - xla_client.Buffer.from_pyval(np.array([2, 3, 4, 5], dtype=np.int32)), - ) - local_buffer = xla_client.Buffer.make_tuple(t, device) - # Run the test twice to verify that the original tuple buffer remains valid - # even after destructuring. - for _ in range(2): - pieces = local_buffer.destructure() - self.assertFalse(local_buffer.is_deleted()) - self.assertLen(pieces, 2) - array0, array1 = pieces - got = array0.to_py() - want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) - np.testing.assert_equal(want, got) - got = array1.to_py() - want = NumpyArrayS32([2, 3, 4, 5]) - np.testing.assert_equal(want, got) - - def testDestructureTupleNested(self): - device = xla_client.get_local_backend().devices()[0] - t = xla_client.Buffer.make_tuple( - (xla_client.Buffer.from_pyval(NumpyArrayF32([1.0, 2.0])), - xla_client.Buffer.from_pyval(NumpyArrayS32([3, 4]))), device) - local_buffer = xla_client.Buffer.make_tuple( - (t, xla_client.Buffer.from_pyval(NumpyArrayS32([5]))), device) - pieces = local_buffer.destructure() - self.assertFalse(local_buffer.is_deleted()) - self.assertLen(pieces, 2) - tuple0, array1 = pieces - got = array1.to_py() - want = NumpyArrayS32([5]) - np.testing.assert_equal(want, got) - got = tuple0.to_py() - self.assertEqual(type(got), tuple) - self.assertLen(got, 2) - np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) - np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) - - def testMakeTuple(self): - t = ( - np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), - np.array([2, 3, 4, 5], dtype=np.int32), - ) - b0 = xla_client.Buffer.from_pyval(t[0]) - b1 = xla_client.Buffer.from_pyval(t[1]) - device = xla_client.get_local_backend().local_devices()[0] - btup = xla_client.Buffer.make_tuple([b0, b1], device=device) - pieces = btup.destructure() - self.assertLen(pieces, 2) - array0, array1 = pieces - np.testing.assert_equal( - np.array([1, 2, 3, 4], dtype=np.float32), array0.to_py()) - np.testing.assert_equal( - np.array([2, 3, 4, 5], dtype=np.int32), array1.to_py()) - def testShape(self): pyval = np.array([[1., 2.]], np.float32) 
local_buffer = xla_client.Buffer.from_pyval(pyval) @@ -581,23 +503,6 @@ class BufferTest(ComputationTest): self.assertEqual(xla_shape.dimensions(), (1, 2)) self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) - def testTupleShape(self): - t = ( - np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32), - np.array([2, 3, 4, 5], dtype=np.int32), - ) - b0 = xla_client.Buffer.from_pyval(t[0]) - b1 = xla_client.Buffer.from_pyval(t[1]) - device = xla_client.get_local_backend().local_devices()[0] - tuple_buffer = xla_client.Buffer.make_tuple([b0, b1], device=device) - tuple_shape = tuple_buffer.shape() - self.assertEqual(tuple_shape.leaf_count(), 2) - shapes = tuple_shape.tuple_shapes() - self.assertLen(shapes, 2) - shape1, shape2 = shapes - self.assertEqual(shape1.dimensions(), (1, 4)) - self.assertEqual(shape2.dimensions(), (4,)) - def testBlockHostUntilReadyWorks(self): arg = np.array([[1., 2.]], np.float32) arg_buffer = xla_client.Buffer.from_pyval(arg) From 09cf2e6d20969555b3ad3e7bb303efbece83654e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 11:46:40 -0700 Subject: [PATCH 162/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301634357 Change-Id: I0d0dd1279aea671584f7fc4800c730b84e59a463 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 52a9bf9551b..6456f104ad3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11757,7 +11757,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12014,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12025,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12243,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. 
-// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12254,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19095,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20166,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21434,7 +21434,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22142,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22338,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22407,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22522,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22581,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22755,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23136,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25576,7 +25576,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25639,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25973,7 +25973,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26023,7 +26023,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26273,7 +26273,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26903,7 +26903,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45738,7 +45738,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value From eaa1729a52952f2a541491a1ce2c34af7ab66fc8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 18 Mar 2020 13:03:19 -0700 Subject: [PATCH 163/492] Force CPU placement for ops that has DT_VARIANT inputs with host-only underlying data type. 
Fix for #28007 PiperOrigin-RevId: 301650148 Change-Id: I47fa9c1b0b7a7d56c5a519095687f36651892644 --- tensorflow/core/BUILD | 1 + .../core/common_runtime/colocation_graph.cc | 80 +++++++++++++++++++ .../core/common_runtime/colocation_graph.h | 8 ++ tensorflow/core/graph/algorithm.cc | 27 ++++--- tensorflow/core/graph/algorithm.h | 11 ++- .../python/kernel_tests/list_ops_test.py | 30 ++++++- 6 files changed, 142 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 188988d92c4..8efada20e24 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2668,6 +2668,7 @@ tf_cuda_library( "@com_google_absl//absl/base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc index ab58902f415..41058ae208a 100644 --- a/tensorflow/core/common_runtime/colocation_graph.cc +++ b/tensorflow/core/common_runtime/colocation_graph.cc @@ -23,7 +23,9 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/function.h" @@ -39,6 +41,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_node_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -726,6 +729,82 @@ Status ColocationGraph::ColocateResourceAndRefEdges( return Status::OK(); } +namespace { +// Returns tensor list element data type, if the node is one of the ops that +// operate with TensorLists. Otherwise returns DT_INVALID. +DataType GetElementDataType(const Node& node) { + static absl::flat_hash_set* tensor_list_ops = + new absl::flat_hash_set( + {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList", + "TensorListSplit", "TensorListScatter", "TensorListScatterV2", + "TensorListScatterIntoExistingList", "TensorListPushBack", + "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack", + "TensorListConcat", "TensorListConcatV2", "TensorListGetItem", + "TensorListSetItem", "TensorListGather", "TensorListConcatLists"}); + + if (tensor_list_ops->contains(node.type_string())) { + DataType element_type; + if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) { + return element_type; + } + } + + return DT_INVALID; +} +} // namespace + +Status ColocationGraph::AddHostOnlyDataTypesConstraints() { + auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; }; + + auto is_cpu_device = [](const std::pair& entry) -> bool { + return entry.first == DEVICE_CPU; + }; + + for (Node* node : graph_.nodes()) { + // Skip nodes that do not have DT_VARIANT inputs. + if (absl::c_none_of(node->input_types(), is_variant)) continue; + + // Skip nodes that can't be placed on GPU anyway. 
+ Member& root = members_[FindAndUpdateRoot(node->id())]; + if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) continue; + + // Stop DFS traversal when found the underlying data type of a variant. + absl::optional is_host_data_type; + + auto edge_filter = [&](const Edge& edge) -> bool { + return !is_host_data_type.has_value(); + }; + + auto enter = [&](Node* n) -> void { + DataType element_type = GetElementDataType(*n); + // To handle nested lists continue traversal after finding a TensorList + // operation that uses DT_VARIANT for element type. + if (element_type == DT_INVALID || element_type == DT_VARIANT) return; + is_host_data_type = DataTypeAlwaysOnHost(element_type); + }; + + ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr, + /*stable_comparator=*/nullptr, edge_filter); + + if (is_host_data_type.has_value() && *is_host_data_type) { + VLOG(2) << "Limit node possible devices to CPU only, because it has a " + "DT_VARIANT input with host-only underlying data type: " + << "node=" << node->name(); + + // Restrict possible device types to CPU only. + PossibleDevices possible_devices; + absl::c_copy_if(root.supported_device_types(), + std::back_inserter(possible_devices.device_types), + is_cpu_device); + + TF_RETURN_IF_ERROR(root.LimitToPossibleDevices( + possible_devices, /*allow_soft_placement=*/false)); + } + } + + return Status::OK(); +} + Status ColocationGraph::AddInspectionConstraints( const std::unordered_set& inspection_required) { for (Node* node : inspection_required) { @@ -744,6 +823,7 @@ Status ColocationGraph::Initialize() { std::unordered_set inspection_required; TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required)); + TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints()); TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required)); TF_RETURN_IF_ERROR(ColocateAllNodes()); diff --git a/tensorflow/core/common_runtime/colocation_graph.h b/tensorflow/core/common_runtime/colocation_graph.h index 65fddf931ef..d0714d54a5a 100644 --- a/tensorflow/core/common_runtime/colocation_graph.h +++ b/tensorflow/core/common_runtime/colocation_graph.h @@ -283,6 +283,14 @@ class ColocationGraph { Status ColocateResourceAndRefEdges( std::unordered_set* inspection_required); + // Updates this ColocationGraph by making sure that all nodes having inputs of + // a DT_VARIANT data type with a host-only underlying types (e.g. strings) can + // be placed only on CPU device. We do that by reverse-DFS traversal from all + // nodes that take variant inputs to the node that produces that variant. + // TODO(ezhulenev): This function does not yet support "deep op" inspection, + // that we have for DT_RESOURCE edges. 
+ Status AddHostOnlyDataTypesConstraints(); + Status AddInspectionConstraints( const std::unordered_set& inspection_required); diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc index 5524ab53c5a..f80822d5b00 100644 --- a/tensorflow/core/graph/algorithm.cc +++ b/tensorflow/core/graph/algorithm.cc @@ -112,8 +112,10 @@ void DFSFrom(const Graph& g, gtl::ArraySlice start, void ReverseDFS(const Graph& g, const std::function& enter, const std::function& leave, - const NodeComparator& stable_comparator) { - ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator, + edge_filter); } namespace { @@ -122,7 +124,8 @@ template void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice start, const std::function& enter, const std::function& leave, - const NodeComparator& stable_comparator) { + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { // Stack of work to do. struct Work { T node; @@ -161,7 +164,9 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice start, if (stable_comparator) { std::vector nodes_sorted; for (const Edge* in_edge : n->in_edges()) { - nodes_sorted.emplace_back(in_edge->src()); + if (!edge_filter || edge_filter(*in_edge)) { + nodes_sorted.emplace_back(in_edge->src()); + } } std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); for (T in : nodes_sorted) { @@ -169,7 +174,9 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice start, } } else { for (const Edge* in_edge : n->in_edges()) { - add_work(in_edge->src()); + if (!edge_filter || edge_filter(*in_edge)) { + add_work(in_edge->src()); + } } } } @@ -180,15 +187,17 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice start, void ReverseDFSFrom(const Graph& g, gtl::ArraySlice start, const std::function& enter, const std::function& leave, - const NodeComparator& stable_comparator) { - ReverseDFSFromHelper(g, start, enter, leave, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter); } void ReverseDFSFrom(const Graph& g, gtl::ArraySlice start, const std::function& enter, const std::function& leave, - const NodeComparator& stable_comparator) { - ReverseDFSFromHelper(g, start, enter, leave, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter); } void GetPostOrder(const Graph& g, std::vector* order, diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h index 8774a67a91e..9a9595a86d6 100644 --- a/tensorflow/core/graph/algorithm.h +++ b/tensorflow/core/graph/algorithm.h @@ -77,23 +77,28 @@ extern void DFSFrom(const Graph& g, gtl::ArraySlice start, // If leave is not empty, calls leave(n) after visiting all parents of n. // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. 
extern void ReverseDFS(const Graph& g, const std::function& enter, const std::function& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Perform a reverse depth-first-search on g starting at the 'start' nodes. // If enter is not empty, calls enter(n) before visiting any parents of n. // If leave is not empty, calls leave(n) after visiting all parents of n. // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice start, const std::function& enter, const std::function& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice start, const std::function& enter, const std::function& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Stores in *order the post-order numbering of all nodes // in graph found via a depth first search starting at the source node. diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 11f882b5bf3..e618e21ed9d 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -1632,14 +1633,37 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllEqual(t, [1.0, 2.0, 3.0]) def testTensorListStrings(self): - self.skipTest("b/150742232") - @def_function.function def f(): return map_fn.map_fn(string_ops.string_upper, constant_op.constant(["a", "b", "c"])) - self.assertAllEqual(f(), ["A", "B", "C"]) + self.assertAllEqual(f(), [b"A", b"B", b"C"]) + + def testTensorListStringsNoInline(self): + # Generator function output type is a variant with a host-only underlying + # data type. "ColocationGraph::AddHostOnlyDataTypesConstraints" needs to + # have "deep op inspection" to be able to correctly place the while loop + # generated from map_fn. + self.skipTest("b/150742232") + + @function.defun_with_attributes(attributes={"_noinline": True}) + def generator(): + c = constant_op.constant(["a", "b", "c"]) + return list_ops.tensor_list_from_tensor(c, element_shape=[]) + + @def_function.function + def f(): + l = generator() + + def upper(i): + e = list_ops.tensor_list_get_item(l, i, element_dtype=dtypes.string) + return string_ops.string_upper(e) + + return map_fn.map_fn( + upper, constant_op.constant([0, 1, 2]), dtype=dtypes.string) + + self.assertAllEqual(f(), [b"A", b"B", b"C"]) if __name__ == "__main__": From ef46437ea43d6bae6d31ddcfa949706803864ce5 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 18 Mar 2020 13:23:44 -0700 Subject: [PATCH 164/492] Update Eigen to: https://gitlab.com/libeigen/eigen/-/commit/7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83 PiperOrigin-RevId: 301653945 Change-Id: I04e44525e292f5bbaec3db3631a1b15a4de4bec9 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index f02e2eb1538..3d29648c1ba 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -201,11 +201,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "3d9cbec40e27093956ad46a4482bb03f968964cabb7b9f35807fd80852ec026a", # SHARED_EIGEN_SHA - strip_prefix = "eigen-b733b8b680885c0fcdfddea5423171468609b5a6", + sha256 = "ce221392c106e90fa28a2ffccf6e45869477b40e17a0b0728334e5e1970de294", # SHARED_EIGEN_SHA + strip_prefix = "eigen-7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/b733b8b680885c0fcdfddea5423171468609b5a6/eigen-b733b8b680885c0fcdfddea5423171468609b5a6.tar.gz", - "https://gitlab.com/libeigen/eigen/-/archive/b733b8b680885c0fcdfddea5423171468609b5a6/eigen-b733b8b680885c0fcdfddea5423171468609b5a6.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83/eigen-7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83/eigen-7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83.tar.gz", ], ) From b2a5472997ae9d366bdf6689caf2ea9526854ebb Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Wed, 18 Mar 2020 13:36:10 -0700 Subject: [PATCH 165/492] Creates `framework_lib` target. PiperOrigin-RevId: 301656623 Change-Id: Ied2158a67bfb0fa29753d201610d2967f57d3504 --- tensorflow/lite/BUILD | 66 ++++++++++++++++++++++++++--------- tensorflow/lite/kernels/BUILD | 4 +-- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 5e22b1fed5c..9c4740b8c0a 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -62,6 +62,22 @@ TFLITE_DEFAULT_COPTS = if_not_windows([ "-Wno-extern-c-compat", ]) +FRAMEWORK_LIB_HDRS = [ + "allocation.h", + "context.h", + "context_util.h", + "core/macros.h", + "core/subgraph.h", + "error_reporter.h", + "graph_info.h", + "interpreter.h", + "model.h", + "mutable_op_resolver.h", + "op_resolver.h", + "optional_debug_tools.h", + "stderr_reporter.h", +] + cc_library( name = "version", hdrs = ["version.h"], @@ -200,9 +216,8 @@ cc_library( ], ) -# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. 
cc_library( - name = "framework", + name = "framework_lib", srcs = [ "core/subgraph.cc", "graph_info.cc", @@ -212,23 +227,42 @@ cc_library( "optional_debug_tools.cc", "stderr_reporter.cc", ], - hdrs = [ - "allocation.h", - "context.h", - "context_util.h", - "core/macros.h", - "core/subgraph.h", - "error_reporter.h", - "graph_info.h", - "interpreter.h", - "model.h", - "mutable_op_resolver.h", - "op_resolver.h", - "optional_debug_tools.h", - "stderr_reporter.h", + hdrs = FRAMEWORK_LIB_HDRS, + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//tensorflow/lite:__subpackages__", ], + deps = [ + ":allocation", + ":arena_planner", + ":external_cpu_backend_context", + ":graph_info", + ":memory_planner", + ":minimal_logging", + ":simple_memory_arena", + ":string", + ":type_to_tflitetype", + ":util", + ":version", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + "//tensorflow/lite/experimental/resource", + "//tensorflow/lite/nnapi:nnapi_implementation", + "//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) + +# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. +cc_library( + name = "framework", + srcs = [ + ], + hdrs = FRAMEWORK_LIB_HDRS, copts = tflite_copts() + TFLITE_DEFAULT_COPTS, deps = [ + ":framework_lib", ":allocation", ":arena_planner", ":external_cpu_backend_context", diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 57e9b876ec1..1f04cc3ee47 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -526,7 +526,7 @@ cc_library( ":lstm_shared", ":op_macros", ":padding", - "//tensorflow/lite:framework", + "//tensorflow/lite:framework_lib", "//tensorflow/lite:minimal_logging", "//tensorflow/lite:string_util", "//tensorflow/lite/c:common", @@ -660,7 +660,7 @@ cc_library( ], deps = [ ":builtin_op_kernels", - "//tensorflow/lite:framework", + "//tensorflow/lite:framework_lib", "//tensorflow/lite/c:common", ], ) From b76fdc4d454d2de37b33413d4526dba9ef27532b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cjaketae=E2=80=9D?= Date: Thu, 19 Mar 2020 05:42:07 +0900 Subject: [PATCH 166/492] Replace arg with --- tensorflow/python/keras/preprocessing/text.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/preprocessing/text.py b/tensorflow/python/keras/preprocessing/text.py index 96f4b19660e..33740159d6b 100644 --- a/tensorflow/python/keras/preprocessing/text.py +++ b/tensorflow/python/keras/preprocessing/text.py @@ -28,7 +28,7 @@ Tokenizer = text.Tokenizer @keras_export('keras.preprocessing.text.text_to_word_sequence') -def text_to_word_sequence(text, +def text_to_word_sequence(input_text, filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', lower=True, split=" "): """Converts a text to a sequence of words (or tokens). @@ -36,12 +36,12 @@ def text_to_word_sequence(text, This function transforms a string of text into a list of words while ignoring `filters` which include punctuations by default. - >>> text = 'This is a sample sentence.' - >>> tf.keras.preprocessing.text.text_to_word_sequence(text) + >>> sample_text = 'This is a sample sentence.' + >>> tf.keras.preprocessing.text.text_to_word_sequence(sample_text) ['this', 'is', 'a', 'sample', 'sentence'] Arguments: - text: Input text (string). + input_text: Input text (string). filters: list (or concatenation) of characters to filter out, such as punctuation. 
Default: `'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n'`, includes basic punctuation, tabs, and newlines. @@ -52,11 +52,11 @@ def text_to_word_sequence(text, A list of words (or tokens). """ return text.text_to_word_sequence( - text, filters=filters, lower=lower, split=split) + input_text, filters=filters, lower=lower, split=split) @keras_export('keras.preprocessing.text.one_hot') -def one_hot(text, n, +def one_hot(input_text, n, filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', lower=True, split=' '): @@ -66,12 +66,12 @@ def one_hot(text, n, list of encoded integers each corresponding to a word (or token) in the given input string. - >>> text = 'This is a sample sentence.' - >>> tf.keras.preprocessing.text.one_hot(text, 20) + >>> sample_text = 'This is a sample sentence.' + >>> tf.keras.preprocessing.text.one_hot(sample_text, 20) [4, 18, 1, 15, 17] Arguments: - text: Input text (string). + input_text: Input text (string). n: int. Size of vocabulary. filters: list (or concatenation) of characters to filter out, such as punctuation. Default: ``!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n``, @@ -84,7 +84,7 @@ def one_hot(text, n, (unicity non-guaranteed). """ return text.one_hot( - text, n, filters=filters, lower=lower, split=split) + input_text, n, filters=filters, lower=lower, split=split) # text.tokenizer_from_json is only available if keras_preprocessing >= 1.1.0 From 1e0821e6017b159db48239b24584a82b6f471bc3 Mon Sep 17 00:00:00 2001 From: Philip Pham Date: Wed, 18 Mar 2020 13:42:34 -0700 Subject: [PATCH 167/492] Make Proximal Yogi available in Python TPU Embedding API PiperOrigin-RevId: 301657986 Change-Id: I3f1cbc88bdf3fbb729ca16e1597d6d29d76ec464 --- ...adTPUEmbeddingProximalYogiParameters.pbtxt | 4 + ...ProximalYogiParametersGradAccumDebug.pbtxt | 4 + ...veTPUEmbeddingProximalYogiParameters.pbtxt | 4 + ...ProximalYogiParametersGradAccumDebug.pbtxt | 4 + ...embedding_optimization_parameters_utils.cc | 6 +- tensorflow/python/tpu/tpu_embedding.py | 181 ++++++++++++++++++ .../api/golden/v1/tensorflow.raw_ops.pbtxt | 16 ++ .../api/golden/v2/tensorflow.raw_ops.pbtxt | 16 ++ 8 files changed, 232 insertions(+), 3 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingProximalYogiParameters.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingProximalYogiParameters.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingProximalYogiParameters.pbtxt b/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingProximalYogiParameters.pbtxt new file mode 100644 index 00000000000..b27fef9b304 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingProximalYogiParameters.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "LoadTPUEmbeddingProximalYogiParameters" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.pbtxt b/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.pbtxt new file mode 100644 index 00000000000..3804dc9e7a3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug" + visibility: HIDDEN 
+} diff --git a/tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingProximalYogiParameters.pbtxt b/tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingProximalYogiParameters.pbtxt new file mode 100644 index 00000000000..fd143b47510 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingProximalYogiParameters.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "RetrieveTPUEmbeddingProximalYogiParameters" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.pbtxt b/tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.pbtxt new file mode 100644 index 00000000000..58822443d82 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug" + visibility: HIDDEN +} diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc index f7a14aab3c2..acc1dfd765e 100644 --- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc +++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc @@ -546,13 +546,13 @@ Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg, case OptimizationAlgorithm::kCenteredRmsProp: case OptimizationAlgorithm::kMdlAdagradLight: case OptimizationAlgorithm::kAdadelta: - case OptimizationAlgorithm::kProximalAdagrad: { + case OptimizationAlgorithm::kProximalAdagrad: + case OptimizationAlgorithm::kProximalYogi: { *internal = false; return Status::OK(); } case OptimizationAlgorithm::kBoundedAdagrad: - case OptimizationAlgorithm::kOnlineYogi: - case OptimizationAlgorithm::kProximalYogi: { + case OptimizationAlgorithm::kOnlineYogi: { *internal = true; return Status::OK(); } diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py index a6677c82daf..e3dbe7fb93f 100644 --- a/tensorflow/python/tpu/tpu_embedding.py +++ b/tensorflow/python/tpu/tpu_embedding.py @@ -241,6 +241,9 @@ ProximalAdagradSlotVariableName = collections.namedtuple( FtrlSlotVariableName = collections.namedtuple( 'FtrlSlotVariableName', ['accumulator', 'linear']) +ProximalYogiSlotVariableNames = collections.namedtuple( + 'ProximalYogiSlotVariableNames', ['v', 'm']) + AdamSlotVariables = collections.namedtuple( 'AdamSlotVariables', ['m', 'v']) @@ -253,6 +256,9 @@ ProximalAdagradSlotVariable = collections.namedtuple( FtrlSlotVariable = collections.namedtuple( 'FtrlSlotVariable', ['accumulator', 'linear']) +ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables', + ['v', 'm']) + VariablesAndOps = collections.namedtuple( 'VariablesAndOps', ['embedding_variables_by_table', 'slot_variables_by_table', @@ -545,6 +551,83 @@ class FtrlParameters(_OptimizationParameters): self.l2_regularization_strength = l2_regularization_strength +class ProximalYogiParameters(_OptimizationParameters): + # pylint: disable=line-too-long + """Optimization parameters for Proximal Yogi with TPU embeddings. + + Implements the Yogi optimizer as described in + [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization). + + Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the + `optimization_parameters` argument to set the optimizer and its parameters. 
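A short sketch of that wiring, with illustrative hyperparameter values; `my_feature_columns` is hypothetical and stands in for TPU embedding feature columns defined elsewhere:

    import tensorflow.compat.v1 as tf
    from tensorflow.python.tpu import tpu_embedding

    # Hyperparameter values are illustrative, not recommended defaults.
    yogi = tpu_embedding.ProximalYogiParameters(
        learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-3)
    embedding_config = tf.estimator.tpu.experimental.EmbeddingConfigSpec(
        feature_columns=my_feature_columns,
        optimization_parameters=yogi)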
+ See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` + for more details. + """ + # pylint: enable=line-too-long + + def __init__(self, + learning_rate=0.01, + beta1=0.9, + beta2=0.999, + epsilon=1e-3, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0, + initial_accumulator_value=1e-6, + use_gradient_accumulation=True, + clip_weight_min=None, + clip_weight_max=None, + weight_decay_factor=None, + multiply_weight_decay_factor_by_learning_rate=None): + """Optimization parameters for Proximal Yogi. + + Args: + learning_rate: a floating point value. The learning rate. + beta1: A float value. The exponential decay rate for the 1st moment + estimates. + beta2: A float value. The exponential decay rate for the 2nd moment + estimates. + epsilon: A small constant for numerical stability. + l1_regularization_strength: A float value, must be greater than or equal + to zero. + l2_regularization_strength: A float value, must be greater than or equal + to zero. + initial_accumulator_value: The starting value for accumulators. Only zero + or positive values are allowed. + use_gradient_accumulation: setting this to `False` makes embedding + gradients calculation less accurate but faster. Please see + `optimization_parameters.proto` for details. for details. + clip_weight_min: the minimum value to clip by; None means -infinity. + clip_weight_max: the maximum value to clip by; None means +infinity. + weight_decay_factor: amount of weight decay to apply; None means that the + weights are not decayed. + multiply_weight_decay_factor_by_learning_rate: if true, + `weight_decay_factor` is multiplied by the current learning rate. + """ + super(ProximalYogiParameters, + self).__init__(learning_rate, use_gradient_accumulation, + clip_weight_min, clip_weight_max, weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate) + if beta1 < 0. or beta1 >= 1.: + raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1)) + if beta2 < 0. or beta2 >= 1.: + raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2)) + if epsilon <= 0.: + raise ValueError('epsilon must be positive; got {}.'.format(epsilon)) + if l1_regularization_strength < 0.: + raise ValueError('l1_regularization_strength must be greater than or ' + 'equal to 0. got {}.'.format(l1_regularization_strength)) + if l2_regularization_strength < 0.: + raise ValueError('l2_regularization_strength must be greater than or ' + 'equal to 0. got {}.'.format(l2_regularization_strength)) + + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.l1_regularization_strength = l1_regularization_strength + self.l2_regularization_strength = l2_regularization_strength + self.initial_accumulator_value = initial_accumulator_value + + @tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters']) class StochasticGradientDescentParameters(_OptimizationParameters): """Optimization parameters for stochastic gradient descent for TPU embeddings. 
@@ -1706,6 +1789,102 @@ class _FtrlHandler(_OptimizerHandler): return slot_variables, load_ops_fn, retrieve_ops_fn +class _ProximalYogiHandler(_OptimizerHandler): + """Handles Proximal Yogi specific logic.""" + + def set_optimization_parameters(self, table_descriptor): + table_descriptor.optimization_parameters.proximal_yogi.SetInParent() + table_descriptor.optimization_parameters.proximal_yogi.beta1 = ( + self._optimization_parameters.beta1) + table_descriptor.optimization_parameters.proximal_yogi.beta2 = ( + self._optimization_parameters.beta2) + table_descriptor.optimization_parameters.proximal_yogi.epsilon = ( + self._optimization_parameters.epsilon) + table_descriptor.optimization_parameters.proximal_yogi.l1 = ( + self._optimization_parameters.l1_regularization_strength) + table_descriptor.optimization_parameters.proximal_yogi.l2 = ( + self._optimization_parameters.l2_regularization_strength) + + def get_default_slot_variable_names(self, table): + return ProximalYogiSlotVariableNames( + '{}/{}'.format(table, 'ProximalYogi'), # v + '{}/{}_1'.format(table, 'ProximalYogi')) # m + + def create_variables_and_ops(self, table, slot_variable_names, num_hosts, + table_config, table_variables, config_proto): + v_initializer = init_ops.constant_initializer( + self._optimization_parameters.initial_accumulator_value) + v_variables = _create_partitioned_variables( + name=slot_variable_names.v, + num_hosts=num_hosts, + vocabulary_size=table_config.vocabulary_size, + embedding_dimension=table_config.dimension, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + initializer=v_initializer) + m_initializer = init_ops.zeros_initializer() + m_variables = _create_partitioned_variables( + name=slot_variable_names.m, + num_hosts=num_hosts, + vocabulary_size=table_config.vocabulary_size, + embedding_dimension=table_config.dimension, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + initializer=m_initializer) + slot_variables = ProximalYogiSlotVariables(v_variables, m_variables) + + def load_ops_fn(): + """Returns the load ops for Proximal Yogi embedding tables. + + Returns: + A list of ops to load embedding and slot variables from CPU to TPU. + """ + load_op_list = [] + config = config_proto + for host_id, table_variable, v_variable, m_variable in (zip( + range(num_hosts), table_variables, v_variables, m_variables)): + with ops.colocate_with(table_variable): + load_parameters_op = ( + tpu_ops.load_tpu_embedding_proximal_yogi_parameters( + parameters=table_variable, + v=v_variable, + m=m_variable, + table_name=table, + num_shards=num_hosts, + shard_id=host_id, + config=config)) + # Set config to None to enforce that config is only loaded to the first + # table. + config = None + load_op_list.append(load_parameters_op) + return load_op_list + + def retrieve_ops_fn(): + """Returns the retrieve ops for Proximal Yogi embedding tables. + + Returns: + A list of ops to retrieve embedding and slot variables from TPU to CPU. 
+ """ + retrieve_op_list = [] + config = config_proto + for host_id, table_variable, v_variable, m_variable in (zip( + range(num_hosts), table_variables, v_variables, m_variables)): + with ops.colocate_with(table_variable): + retrieved_table, retrieved_v, retrieved_m = ( + tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters( + table_name=table, + num_shards=num_hosts, + shard_id=host_id, + config=config)) + retrieve_parameters_op = control_flow_ops.group( + state_ops.assign(table_variable, retrieved_table), + state_ops.assign(v_variable, retrieved_v), + state_ops.assign(m_variable, retrieved_m)) + config = None + retrieve_op_list.append(retrieve_parameters_op) + return retrieve_op_list + + return slot_variables, load_ops_fn, retrieve_ops_fn + + class _StochasticGradientDescentHandler(_OptimizerHandler): """Handles stochastic gradient descent specific logic.""" @@ -1779,6 +1958,8 @@ def _get_optimization_handler(optimization_parameters): return _AdamHandler(optimization_parameters) elif isinstance(optimization_parameters, FtrlParameters): return _FtrlHandler(optimization_parameters) + elif isinstance(optimization_parameters, ProximalYogiParameters): + return _ProximalYogiHandler(optimization_parameters) elif isinstance(optimization_parameters, StochasticGradientDescentParameters): return _StochasticGradientDescentHandler(optimization_parameters) else: diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 8df5fe219f6..fa3462e6d44 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -2072,6 +2072,14 @@ tf_module { name: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug" argspec: "args=[\'parameters\', \'accumulators\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " } + member_method { + name: "LoadTPUEmbeddingProximalYogiParameters" + argspec: "args=[\'parameters\', \'v\', \'m\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " + } + member_method { + name: "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug" + argspec: "args=[\'parameters\', \'v\', \'m\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " + } member_method { name: "LoadTPUEmbeddingRMSPropParameters" argspec: "args=[\'parameters\', \'ms\', \'mom\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " @@ -3608,6 +3616,14 @@ tf_module { name: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug" argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " } + member_method { + name: "RetrieveTPUEmbeddingProximalYogiParameters" + argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " + } + member_method { + name: "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug" + argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', 
\'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " + } member_method { name: "RetrieveTPUEmbeddingRMSPropParameters" argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 8df5fe219f6..fa3462e6d44 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -2072,6 +2072,14 @@ tf_module { name: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug" argspec: "args=[\'parameters\', \'accumulators\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " } + member_method { + name: "LoadTPUEmbeddingProximalYogiParameters" + argspec: "args=[\'parameters\', \'v\', \'m\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " + } + member_method { + name: "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug" + argspec: "args=[\'parameters\', \'v\', \'m\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " + } member_method { name: "LoadTPUEmbeddingRMSPropParameters" argspec: "args=[\'parameters\', \'ms\', \'mom\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " @@ -3608,6 +3616,14 @@ tf_module { name: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug" argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " } + member_method { + name: "RetrieveTPUEmbeddingProximalYogiParameters" + argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " + } + member_method { + name: "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug" + argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " + } member_method { name: "RetrieveTPUEmbeddingRMSPropParameters" argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], " From 019639ca42640cb98834b23ae977dc15cdf64197 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 13:46:16 -0700 Subject: [PATCH 168/492] Use same var key in _create_slots/get_slot in V1 optimizer We have special handling for distributed variable in get_slot, but not create_slot. This happens to work before but upcoming change in distributed library will break it. 
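A minimal sketch of the slot lookup this keying serves, using a stock V1 optimizer in graph mode; the variable, gradient, and slot name are illustrative only:

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    # Slots are created inside apply_gradients via _create_slots; get_slot
    # must then resolve the variable with the same key to find them.
    var = tf.Variable([1.0], name="w")
    opt = tf.train.AdamOptimizer()
    train_op = opt.apply_gradients([(tf.constant([0.5]), var)])
    m_slot = opt.get_slot(var, "m")  # first-moment slot created above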
PiperOrigin-RevId: 301658655 Change-Id: I9fc3dd9bacb277a9a6c7d9dba743b5885cad59e4 --- tensorflow/python/BUILD | 4 --- tensorflow/python/training/optimizer.py | 38 ++++++++++---------- tensorflow/python/training/optimizer_test.py | 26 -------------- 3 files changed, 20 insertions(+), 48 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 94d52a8ab06..d932899ab0d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5220,7 +5220,6 @@ py_library( "//tensorflow/python/distribute:distribute_coordinator_context", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:reduce_util", - "//tensorflow/python/distribute:values", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/keras/optimizer_v2:learning_rate_schedule", @@ -6591,9 +6590,6 @@ cuda_py_tests( ":variable_scope", ":variables", "//tensorflow/core:protos_all_py", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/distribute:values", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index f89dc362cf8..f1a31d01dd4 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -27,7 +27,6 @@ import six from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util -from tensorflow.python.distribute import values as ds_values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -82,17 +81,10 @@ def _deduplicate_indexed_slices(values, indices): def _var_key(var): - """Returns slot key for `var`.""" - # pylint: disable=protected-access - if hasattr(var, "_distributed_container"): - var = var._distributed_container() - if ops.executing_eagerly_outside_functions(): - return var._unique_id - if ds_values.is_distributed_variable(var): - return (var.graph, var._shared_name) - else: + # TODO(ashankar): Consolidate handling for eager and graph + if hasattr(var, "op"): return (var.op.graph, var.op.name) - # pylint: enable=protected-access + return var._unique_id # pylint: disable=protected-access @six.add_metaclass(abc.ABCMeta) @@ -759,16 +751,26 @@ class Optimizer( Returns: The `Variable` for the slot if it was created, `None` otherwise. """ + # pylint: disable=protected-access named_slots = self._slots.get(name, None) if not named_slots: return None - slot = named_slots.get(_var_key(var), None) - if (ds_values.is_distributed_variable(slot) and - not ds_values.is_distributed_variable(var)): - # Make sure var and slot are either both DistributedVariable, or both - # per replica variables. - slot = slot._get_closest() # pylint: disable=protected-access - return slot + + if hasattr(var, "_distributed_container"): + # NOTE: If this isn't patched, then there is no `handle` in + # `_resource_apply_dense`. 
+ distributed_container = var._distributed_container() + assert distributed_container is not None + if ops.executing_eagerly_outside_functions(): + key = distributed_container._unique_id + else: + key = (distributed_container.graph, distributed_container._shared_name) + # pylint: enable=protected-access + mirrored_slot = named_slots.get(key, None) + if mirrored_slot is None: return None + return mirrored_slot._get_closest() # pylint: disable=protected-access + + return named_slots.get(_var_key(var), None) def get_slot_names(self): """Return a list of the names of slots created by the `Optimizer`. diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index 30fa5cd0388..5775d0b8091 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -18,9 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.distribute import cross_device_ops -from tensorflow.python.distribute import mirrored_strategy -from tensorflow.python.distribute import values as ds_values from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,7 +29,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent @@ -273,28 +269,6 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-0.1, -0.1], self.evaluate(var0)) self.assertAllClose([0., 0.], self.evaluate(var1)) - @test_util.run_deprecated_v1 - def testGetSlotUnderDistributedStrategy(self): - # Only run this test in graph mode so we don't need actual GPU. - ds = mirrored_strategy.MirroredStrategy( - ['CPU:0', 'GPU:0'], - cross_device_ops=cross_device_ops.HierarchicalCopyAllReduce()) - # We need an optimizer that creates slots. - optimizer = adam.AdamOptimizer() - - def f(): - v = variables.Variable([1.0]) - self.assertTrue(ds_values.is_distributed_variable(v)) - # Slot variables are created in the first call to apply_gradients. - optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)]) - self.assertTrue(optimizer.get_slot_names()) - for name in optimizer.get_slot_names(): - slot = optimizer.get_slot(v, name) - self.assertIsNotNone(slot) - self.assertTrue(ds_values.is_distributed_variable(slot)) - - ds.experimental_run_v2(f) - if __name__ == '__main__': test.main() From 4ecf20e800330d335c4b091fd805c073d824f2d3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 14:00:24 -0700 Subject: [PATCH 169/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301661813 Change-Id: I5ba0e5f243300b93dea4a34c0c9cb269e9d92c8a --- tensorflow/go/op/wrappers.go | 2594 +++++++++++++++++----------------- 1 file changed, 1297 insertions(+), 1297 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 6456f104ad3..7be0c66548c 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -8483,62 +8483,6 @@ func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Out return op.Output(0) } -// DatasetToGraphAttr is an optional argument to DatasetToGraph. 
-type DatasetToGraphAttr func(optionalAttr) - -// DatasetToGraphStatefulWhitelist sets the optional stateful_whitelist attribute to value. -// If not specified, defaults to {} -// -// REQUIRES: len(value) >= 0 -func DatasetToGraphStatefulWhitelist(value []string) DatasetToGraphAttr { - return func(m optionalAttr) { - m["stateful_whitelist"] = value - } -} - -// DatasetToGraphAllowStateful sets the optional allow_stateful attribute to value. -// If not specified, defaults to false -func DatasetToGraphAllowStateful(value bool) DatasetToGraphAttr { - return func(m optionalAttr) { - m["allow_stateful"] = value - } -} - -// DatasetToGraphStripDeviceAssignment sets the optional strip_device_assignment attribute to value. -// If not specified, defaults to false -func DatasetToGraphStripDeviceAssignment(value bool) DatasetToGraphAttr { - return func(m optionalAttr) { - m["strip_device_assignment"] = value - } -} - -// Returns a serialized GraphDef representing `input_dataset`. -// -// Returns a graph representation for `input_dataset`. -// -// Arguments: -// input_dataset: A variant tensor representing the dataset to return the graph representation for. -// -// Returns The graph representation of the dataset (as serialized GraphDef). -func DatasetToGraph(scope *Scope, input_dataset tf.Output, optional ...DatasetToGraphAttr) (graph tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DatasetToGraph", - Input: []tf.Input{ - input_dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Converts the given `resource_handle` representing an iterator to a string. // // Arguments: @@ -11733,6 +11677,120 @@ func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_ou return op.Output(0) } +// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. +type NonMaxSuppressionAttr func(optionalAttr) + +// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. +// +// value: A float representing the threshold for deciding whether boxes +// overlap too much with respect to IOU. +// If not specified, defaults to 0.5 +func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { + return func(m optionalAttr) { + m["iou_threshold"] = value + } +} + +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system. Note that this +// algorithm is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. 
For example: +// selected_indices = tf.image.non_max_suppression( +// boxes, scores, max_output_size, iou_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) +// +// Arguments: +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "NonMaxSuppression", + Input: []tf.Input{ + boxes, scores, max_output_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. +type CropAndResizeGradBoxesAttr func(optionalAttr) + +// CropAndResizeGradBoxesMethod sets the optional method attribute to value. +// +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { + return func(m optionalAttr) { + m["method"] = value + } +} + +// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. +// +// Arguments: +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +// Both `image_height` and `image_width` need to be positive. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// +// Returns A 2-D tensor of shape `[num_boxes, 4]`. +func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CropAndResizeGradBoxes", + Input: []tf.Input{ + grads, image, boxes, box_ind, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. 
type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) @@ -11757,7 +11815,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12014,7 +12072,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12025,7 +12083,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12243,7 +12301,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12254,7 +12312,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19095,7 +19153,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20166,7 +20224,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21063,102 +21121,6 @@ func Atanh(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// OutfeedDequeueAttr is an optional argument to OutfeedDequeue. -type OutfeedDequeueAttr func(optionalAttr) - -// OutfeedDequeueDeviceOrdinal sets the optional device_ordinal attribute to value. -// -// value: The TPU device to use. This should be -1 when the Op -// is running on a TPU device, and >= 0 when the Op is running on the CPU -// device. -// If not specified, defaults to -1 -func OutfeedDequeueDeviceOrdinal(value int64) OutfeedDequeueAttr { - return func(m optionalAttr) { - m["device_ordinal"] = value - } -} - -// Retrieves a single tensor from the computation outfeed. -// -// This operation will block indefinitely until data is available. -// -// Arguments: -// dtype: The type of elements in the tensor. -// shape: The shape of the tensor. -// -// Returns A tensor that will be read from the device outfeed. -func OutfeedDequeue(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...OutfeedDequeueAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OutfeedDequeue", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. -type CropAndResizeGradImageAttr func(optionalAttr) - -// CropAndResizeGradImageMethod sets the optional method attribute to value. -// -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input image tensor. -// -// Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` -// containing the original image size. Both `image_height` and `image_width` need -// to be positive. -// -// -// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. 
-func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"T": T} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CropAndResizeGradImage", - Input: []tf.Input{ - grads, boxes, box_ind, image_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes hyperbolic tangent of `x` element-wise. // // Given an input tensor, this function computes hyperbolic tangent of every @@ -21434,7 +21396,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22142,7 +22104,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22338,7 +22300,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22407,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22522,7 +22484,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22581,7 +22543,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22755,7 +22717,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23136,7 +23098,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -24832,103 +24794,6 @@ func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backpr return op.Output(0) } -// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. -type QuantizeAndDequantizeV3Attr func(optionalAttr) - -// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { - return func(m optionalAttr) { - m["signed_input"] = value - } -} - -// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { - return func(m optionalAttr) { - m["range_given"] = value - } -} - -// QuantizeAndDequantizeV3NarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func QuantizeAndDequantizeV3NarrowRange(value bool) QuantizeAndDequantizeV3Attr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// QuantizeAndDequantizeV3Axis sets the optional axis attribute to value. -// If not specified, defaults to -1 -func QuantizeAndDequantizeV3Axis(value int64) QuantizeAndDequantizeV3Attr { - return func(m optionalAttr) { - m["axis"] = value - } -} - -// Quantizes then dequantizes a tensor. -// -// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a -// tensor, so its value can change during training. -func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizeAndDequantizeV3", - Input: []tf.Input{ - input, input_min, input_max, num_bits, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns x * y element-wise. -// -// *NOTE*: `Multiply` supports broadcasting. 
More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Mul(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Mul", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes softplus gradients for a softplus operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding softplus operation. -// features: The features passed as input to the corresponding softplus operation. -// -// Returns The gradients: `gradients / (1 + exp(-features))`. -func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SoftplusGrad", - Input: []tf.Input{ - gradients, features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the LSTM cell backward propagation for 1 timestep. // // This implementation is to be used in conjunction of LSTMBlockCell. @@ -25576,7 +25441,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25639,7 +25504,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25680,97 +25545,6 @@ func Conv3D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, pa return op.Output(0) } -// Compute the lower regularized incomplete Gamma function `P(a, x)`. -// -// The lower regularized incomplete Gamma function is defined as: -// -// -// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) -// -// where -// -// \\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) -// -// is the lower incomplete Gamma function. -// -// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete -// Gamma function. -func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Igamma", - Input: []tf.Input{ - a, x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StringSplitV2Attr is an optional argument to StringSplitV2. -type StringSplitV2Attr func(optionalAttr) - -// StringSplitV2Maxsplit sets the optional maxsplit attribute to value. -// -// value: An `int`. If `maxsplit > 0`, limit of the split of the result. -// If not specified, defaults to -1 -func StringSplitV2Maxsplit(value int64) StringSplitV2Attr { - return func(m optionalAttr) { - m["maxsplit"] = value - } -} - -// Split elements of `source` based on `sep` into a `SparseTensor`. -// -// Let N be the size of source (typically N will be the batch size). Split each -// element of `source` based on `sep` and return a `SparseTensor` -// containing the split tokens. Empty tokens are ignored. 
-// -// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', -// then the output will be -// ``` -// st.indices = [0, 0; -// 0, 1; -// 1, 0; -// 1, 1; -// 1, 2] -// st.shape = [2, 3] -// st.values = ['hello', 'world', 'a', 'b', 'c'] -// ``` -// -// If `sep` is given, consecutive delimiters are not grouped together and are -// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and -// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty -// string, consecutive whitespace are regarded as a single separator, and the -// result will contain no empty strings at the startor end if the string has -// leading or trailing whitespace. -// -// Note that the above mentioned behavior matches python's str.split. -// -// Arguments: -// input: `1-D` string `Tensor`, the strings to split. -// sep: `0-D` string `Tensor`, the delimiter character. -func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StringSplitV2", - Input: []tf.Input{ - input, sep, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // UniqueWithCountsAttr is an optional argument to UniqueWithCounts. type UniqueWithCountsAttr func(optionalAttr) @@ -25973,7 +25747,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26019,313 +25793,6 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil return op.Output(0) } -// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. -type Conv3DBackpropInputAttr func(optionalAttr) - -// Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} -func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of 3-D convolution with respect to the input. -// -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 -// -// Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. 
-func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Conv3DBackpropInput", - Input: []tf.Input{ - input, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates an all-zeros CSRSparseMatrix with shape `dense_shape`. -// -// Arguments: -// dense_shape: The desired matrix shape. -// -// -// Returns An empty CSR matrix with shape `dense_shape`. -func SparseMatrixZeros(scope *Scope, dense_shape tf.Output, type_ tf.DataType) (sparse_matrix tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"type": type_} - opspec := tf.OpSpec{ - Type: "SparseMatrixZeros", - Input: []tf.Input{ - dense_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Forwards `data` to the output port determined by `pred`. -// -// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, -// the data goes to `output_false`. -// -// See also `RefSwitch` and `Merge`. -// -// Arguments: -// data: The tensor to be forwarded to the appropriate output. -// pred: A scalar that specifies which output port will receive data. -// -// Returns: -// output_false: If `pred` is false, data will be forwarded to this output. -// output_true: If `pred` is true, data will be forwarded to this output. -func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Output, output_true tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Switch", - Input: []tf.Input{ - data, pred, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// UnicodeEncodeAttr is an optional argument to UnicodeEncode. -type UnicodeEncodeAttr func(optionalAttr) - -// UnicodeEncodeErrors sets the optional errors attribute to value. -// -// value: Error handling policy when there is invalid formatting found in the input. -// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. -// If not specified, defaults to "replace" -func UnicodeEncodeErrors(value string) UnicodeEncodeAttr { - return func(m optionalAttr) { - m["errors"] = value - } -} - -// UnicodeEncodeReplacementChar sets the optional replacement_char attribute to value. -// -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD (U+65533). -// If not specified, defaults to 65533 -func UnicodeEncodeReplacementChar(value int64) UnicodeEncodeAttr { - return func(m optionalAttr) { - m["replacement_char"] = value - } -} - -// Encode a tensor of ints into unicode strings. 
-// -// Returns a vector of strings, where `output[i]` is constructed by encoding the -// Unicode codepoints in `input_values[input_splits[i]:input_splits[i+1]]` -// using `output_encoding`. -// -// --- -// -// Example: -// -// ``` -// input_values = [72, 101, 108, 108, 111, 87, 111, 114, 108, 100] -// input_splits = [0, 5, 10] -// output_encoding = 'UTF-8' -// -// output = ['Hello', 'World'] -// ``` -// -// Arguments: -// input_values: A 1D tensor containing the unicode codepoints that should be encoded. -// input_splits: A 1D tensor specifying how the unicode codepoints should be split into strings. -// In particular, `output[i]` is constructed by encoding the codepoints in the -// slice `input_values[input_splits[i]:input_splits[i+1]]`. -// output_encoding: Unicode encoding of the output strings. Valid encodings are: `"UTF-8", -// "UTF-16-BE", and "UTF-32-BE"`. -// -// Returns The 1-D Tensor of strings encoded from the provided unicode codepoints. -func UnicodeEncode(scope *Scope, input_values tf.Output, input_splits tf.Output, output_encoding string, optional ...UnicodeEncodeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_encoding": output_encoding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UnicodeEncode", - Input: []tf.Input{ - input_values, input_splits, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingFTRLParametersGradAccumDebug. -type RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Retrieve FTRL embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns: -// parameters: Parameter parameters updated by the FTRL optimization algorithm. -// accumulators: Parameter accumulators updated by the FTRL optimization algorithm. -// linears: Parameter linears updated by the FTRL optimization algorithm. -// gradient_accumulators: Parameter gradient_accumulators updated by the FTRL optimization algorithm. 
-func RetrieveTPUEmbeddingFTRLParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. -type DepthwiseConv2dNativeAttr func(optionalAttr) - -// DepthwiseConv2dNativeDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthwiseConv2dNativeDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} -func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. -// -// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -// and a filter / kernel tensor of shape -// `[filter_height, filter_width, in_channels, channel_multiplier]`, containing -// `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies -// a different filter to each input channel (expanding from 1 channel to -// `channel_multiplier` channels for each), then concatenates the results -// together. Thus, the output has `in_channels * channel_multiplier` channels. -// -// ``` -// for k in 0..in_channels-1 -// for q in 0..channel_multiplier-1 -// output[b, i, j, k * channel_multiplier + q] = -// sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * -// filter[di, dj, k, q] -// ``` -// -// Must have `strides[0] = strides[3] = 1`. For the most common case of the same -// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. -// -// Arguments: -// -// -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. -// padding: The type of padding algorithm to use. 
-func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNative", - Input: []tf.Input{ - input, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Updates the table to associates keys with values. // // The tensor `keys` must be of the same type as the keys of the table. @@ -26903,7 +26370,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -29414,59 +28881,6 @@ func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ... return op.Output(0) } -// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. -type CropAndResizeGradBoxesAttr func(optionalAttr) - -// CropAndResizeGradBoxesMethod sets the optional method attribute to value. -// -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. -// -// Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -// Both `image_height` and `image_width` need to be positive. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// -// Returns A 2-D tensor of shape `[num_boxes, 4]`. 
-func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CropAndResizeGradBoxes", - Input: []tf.Input{ - grads, image, boxes, box_ind, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the derivative of a Gamma random sample w.r.t. `alpha`. func RandomGammaGrad(scope *Scope, alpha tf.Output, sample tf.Output) (output tf.Output) { if scope.Err() != nil { @@ -33025,6 +32439,97 @@ func CSRSparseMatrixComponents(scope *Scope, csr_sparse_matrix tf.Output, index return op.Output(0), op.Output(1), op.Output(2) } +// StringSplitV2Attr is an optional argument to StringSplitV2. +type StringSplitV2Attr func(optionalAttr) + +// StringSplitV2Maxsplit sets the optional maxsplit attribute to value. +// +// value: An `int`. If `maxsplit > 0`, limit of the split of the result. +// If not specified, defaults to -1 +func StringSplitV2Maxsplit(value int64) StringSplitV2Attr { + return func(m optionalAttr) { + m["maxsplit"] = value + } +} + +// Split elements of `source` based on `sep` into a `SparseTensor`. +// +// Let N be the size of source (typically N will be the batch size). Split each +// element of `source` based on `sep` and return a `SparseTensor` +// containing the split tokens. Empty tokens are ignored. +// +// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', +// then the output will be +// ``` +// st.indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// st.shape = [2, 3] +// st.values = ['hello', 'world', 'a', 'b', 'c'] +// ``` +// +// If `sep` is given, consecutive delimiters are not grouped together and are +// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and +// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty +// string, consecutive whitespace are regarded as a single separator, and the +// result will contain no empty strings at the startor end if the string has +// leading or trailing whitespace. +// +// Note that the above mentioned behavior matches python's str.split. +// +// Arguments: +// input: `1-D` string `Tensor`, the strings to split. +// sep: `0-D` string `Tensor`, the delimiter character. +func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StringSplitV2", + Input: []tf.Input{ + input, sep, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Compute the lower regularized incomplete Gamma function `P(a, x)`. +// +// The lower regularized incomplete Gamma function is defined as: +// +// +// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) +// +// where +// +// \\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) +// +// is the lower incomplete Gamma function. +// +// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete +// Gamma function. 
+func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Igamma", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Convert a (possibly batched) CSRSparseMatrix to dense. // // Arguments: @@ -36500,67 +36005,6 @@ func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Outpu return op.Output(0), op.Output(1) } -// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. -type NonMaxSuppressionAttr func(optionalAttr) - -// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. -// -// value: A float representing the threshold for deciding whether boxes -// overlap too much with respect to IOU. -// If not specified, defaults to 0.5 -func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { - return func(m optionalAttr) { - m["iou_threshold"] = value - } -} - -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) -// -// Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "NonMaxSuppression", - Input: []tf.Input{ - boxes, scores, max_output_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Removes keys and its associated values from a table. // // The tensor `keys` must of the same type as the keys of the table. Keys not @@ -37799,6 +37243,400 @@ func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// Creates a dataset that passes a sliding window over `input_dataset`. 
+// +// Arguments: +// +// window_size: A scalar representing the number of elements in the +// sliding window. +// window_shift: A scalar representing the steps moving the sliding window +// forward in one iteration. It must be positive. +// window_stride: A scalar representing the stride of the input elements of the sliding window. +// It must be positive. +// +// +func SlidingWindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, window_shift tf.Output, window_stride tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "SlidingWindowDataset", + Input: []tf.Input{ + input_dataset, window_size, window_shift, window_stride, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Locks a mutex resource. The output is the lock. So long as the lock tensor +// +// is alive, any other request to use `MutexLock` with this mutex will wait. +// +// This is particularly useful for creating a critical section when used in +// conjunction with `MutexLockIdentity`: +// +// ```python +// +// mutex = mutex_v2( +// shared_name=handle_name, container=container, name=name) +// +// def execute_in_critical_section(fn, *args, **kwargs): +// lock = gen_resource_variable_ops.mutex_lock(mutex) +// +// with ops.control_dependencies([lock]): +// r = fn(*args, **kwargs) +// +// with ops.control_dependencies(nest.flatten(r)): +// with ops.colocate_with(mutex): +// ensure_lock_exists = mutex_lock_identity(lock) +// +// # Make sure that if any element of r is accessed, all of +// # them are executed together. +// r = nest.map_structure(tf.identity, r) +// +// with ops.control_dependencies([ensure_lock_exists]): +// return nest.map_structure(tf.identity, r) +// ``` +// +// While `fn` is running in the critical section, no other functions which wish to +// use this critical section may run. +// +// Often the use case is that two executions of the same graph, in parallel, +// wish to run `fn`; and we wish to ensure that only one of them executes +// at a time. This is especially important if `fn` modifies one or more +// variables at a time. +// +// It is also useful if two separate functions must share a resource, but we +// wish to ensure the usage is exclusive. +// +// Arguments: +// mutex: The mutex resource to lock. +// +// Returns A tensor that keeps a shared pointer to a lock on the mutex; +// when the Tensor is destroyed, the use count on the shared pointer is decreased +// by 1. When it reaches 0, the lock is released. +func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MutexLock", + Input: []tf.Input{ + mutex, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolGradWithArgmaxAttr is an optional argument to MaxPoolGradWithArgmax. +type MaxPoolGradWithArgmaxAttr func(optionalAttr) + +// MaxPoolGradWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. +// +// value: Whether to include batch dimension in flattened index of `argmax`. +// If not specified, defaults to false +func MaxPoolGradWithArgmaxIncludeBatchInIndex(value bool) MaxPoolGradWithArgmaxAttr { + return func(m optionalAttr) { + m["include_batch_in_index"] = value + } +} + +// Computes gradients of the maxpooling function. 
+// +// Arguments: +// input: The original input. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the +// output of `max_pool`. +// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients w.r.t. the input of `max_pool`. +func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradWithArgmaxAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolGradWithArgmax", + Input: []tf.Input{ + input, grad, argmax, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 2D fast Fourier transform. +// +// Computes the 2-dimensional discrete Fourier transform over the inner-most +// 2 dimensions of `input`. +// +// Arguments: +// input: A complex tensor. +// +// Returns A complex tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft2 +// @end_compatibility +func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SdcaOptimizerAttr is an optional argument to SdcaOptimizer. +type SdcaOptimizerAttr func(optionalAttr) + +// SdcaOptimizerAdaptative sets the optional adaptative attribute to value. +// +// value: Whether to use Adaptive SDCA for the inner loop. +// If not specified, defaults to true +func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { + return func(m optionalAttr) { + m["adaptative"] = value + } +} + +// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for +// +// linear models with L1 + L2 regularization. As global optimization objective is +// strongly-convex, the optimizer optimizes the dual objective at each step. The +// optimizer applies each update one example at a time. Examples are sampled +// uniformly, and the optimizer is learning rate free and enjoys linear convergence +// rate. +// +// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
+// Shai Shalev-Shwartz, Tong Zhang. 2012 +// +// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ +// +// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
+// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, +// Peter Richtarik, Martin Takac. 2015 +// +// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
+// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 +// +// Arguments: +// sparse_example_indices: a list of vectors which contain example indices. +// sparse_feature_indices: a list of vectors which contain feature indices. +// sparse_feature_values: a list of vectors which contains feature value +// associated with each feature group. +// dense_features: a list of matrices which contains the dense feature values. +// example_weights: a vector which contains the weight associated with each +// example. +// example_labels: a vector which contains the label/target associated with each +// example. +// sparse_indices: a list of vectors where each value is the indices which has +// corresponding weights in sparse_weights. This field maybe omitted for the +// dense approach. +// sparse_weights: a list of vectors where each value is the weight associated with +// a sparse feature group. +// dense_weights: a list of vectors where the values are the weights associated +// with a dense feature group. +// example_state_data: a list of vectors containing the example state data. +// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, +// squared and hinge losses. +// l1: Symmetric l1 regularization strength. +// l2: Symmetric l2 regularization strength. +// num_loss_partitions: Number of partitions of the global loss function. +// num_inner_iterations: Number of iterations per mini-batch. +// +// Returns: +// out_example_state_data: a list of vectors containing the updated example state +// data. +// out_delta_sparse_weights: a list of vectors where each value is the delta +// weights associated with a sparse feature group. +// out_delta_dense_weights: a list of vectors where the values are the delta +// weights associated with a dense feature group. 
+func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SdcaOptimizer", + Input: []tf.Input{ + tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + out_example_state_data = op.Output(idx) + if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return + } + if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return + } + return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights +} + +// Inverse fast Fourier transform. +// +// Computes the inverse 1-dimensional discrete Fourier transform over the +// inner-most dimension of `input`. +// +// Arguments: +// input: A complex tensor. +// +// Returns A complex tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its inverse 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft +// @end_compatibility +func IFFT(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IFFT", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the next record (key, value pair) produced by a Reader. +// +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). +// +// Arguments: +// reader_handle: Handle to a Reader. +// queue_handle: Handle to a Queue, with string work items. +// +// Returns: +// key: A scalar. +// value: A scalar. +func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderReadV2", + Input: []tf.Input{ + reader_handle, queue_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// CumprodAttr is an optional argument to Cumprod. +type CumprodAttr func(optionalAttr) + +// CumprodExclusive sets the optional exclusive attribute to value. +// +// value: If `True`, perform exclusive cumprod. 
+// If not specified, defaults to false +func CumprodExclusive(value bool) CumprodAttr { + return func(m optionalAttr) { + m["exclusive"] = value + } +} + +// CumprodReverse sets the optional reverse attribute to value. +// +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumprodReverse(value bool) CumprodAttr { + return func(m optionalAttr) { + m["reverse"] = value + } +} + +// Compute the cumulative product of the tensor `x` along `axis`. +// +// By default, this op performs an inclusive cumprod, which means that the first +// element of the input is identical to the first element of the output: +// +// ```python +// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] +// ``` +// +// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +// performed instead: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +// ``` +// +// By setting the `reverse` kwarg to `True`, the cumprod is performed in the +// opposite direction: +// +// ```python +// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +// ``` +// +// This is more efficient than using separate `tf.reverse` ops. +// +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +// ``` +// +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. +func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Cumprod", + Input: []tf.Input{ + x, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // CollectiveGatherAttr is an optional argument to CollectiveGather. type CollectiveGatherAttr func(optionalAttr) @@ -38039,112 +37877,6 @@ func StatefulTruncatedNormal(scope *Scope, resource tf.Output, algorithm tf.Outp return op.Output(0) } -// Returns the next record (key, value pair) produced by a Reader. -// -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). -// -// Arguments: -// reader_handle: Handle to a Reader. -// queue_handle: Handle to a Queue, with string work items. -// -// Returns: -// key: A scalar. -// value: A scalar. -func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderReadV2", - Input: []tf.Input{ - reader_handle, queue_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// CumprodAttr is an optional argument to Cumprod. -type CumprodAttr func(optionalAttr) - -// CumprodExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumprod. -// If not specified, defaults to false -func CumprodExclusive(value bool) CumprodAttr { - return func(m optionalAttr) { - m["exclusive"] = value - } -} - -// CumprodReverse sets the optional reverse attribute to value. -// -// value: A `bool` (default: False). 
-// If not specified, defaults to false -func CumprodReverse(value bool) CumprodAttr { - return func(m optionalAttr) { - m["reverse"] = value - } -} - -// Compute the cumulative product of the tensor `x` along `axis`. -// -// By default, this op performs an inclusive cumprod, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is -// performed instead: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] -// ``` -// -// By setting the `reverse` kwarg to `True`, the cumprod is performed in the -// opposite direction: -// -// ```python -// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] -// ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] -// ``` -// -// Arguments: -// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Cumprod", - Input: []tf.Input{ - x, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns the rank of a tensor. // // This operation returns an integer representing the rank of `input`. @@ -38332,132 +38064,6 @@ func IsBoostedTreesQuantileStreamResourceInitialized(scope *Scope, quantile_stre return op.Output(0) } -// SdcaOptimizerAttr is an optional argument to SdcaOptimizer. -type SdcaOptimizerAttr func(optionalAttr) - -// SdcaOptimizerAdaptative sets the optional adaptative attribute to value. -// -// value: Whether to use Adaptive SDCA for the inner loop. -// If not specified, defaults to true -func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { - return func(m optionalAttr) { - m["adaptative"] = value - } -} - -// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for -// -// linear models with L1 + L2 regularization. As global optimization objective is -// strongly-convex, the optimizer optimizes the dual objective at each step. The -// optimizer applies each update one example at a time. Examples are sampled -// uniformly, and the optimizer is learning rate free and enjoys linear convergence -// rate. -// -// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
-// Shai Shalev-Shwartz, Tong Zhang. 2012 -// -// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ -// -// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
-// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, -// Peter Richtarik, Martin Takac. 2015 -// -// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
-// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 -// -// Arguments: -// sparse_example_indices: a list of vectors which contain example indices. -// sparse_feature_indices: a list of vectors which contain feature indices. -// sparse_feature_values: a list of vectors which contains feature value -// associated with each feature group. -// dense_features: a list of matrices which contains the dense feature values. -// example_weights: a vector which contains the weight associated with each -// example. -// example_labels: a vector which contains the label/target associated with each -// example. -// sparse_indices: a list of vectors where each value is the indices which has -// corresponding weights in sparse_weights. This field maybe omitted for the -// dense approach. -// sparse_weights: a list of vectors where each value is the weight associated with -// a sparse feature group. -// dense_weights: a list of vectors where the values are the weights associated -// with a dense feature group. -// example_state_data: a list of vectors containing the example state data. -// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, -// squared and hinge losses. -// l1: Symmetric l1 regularization strength. -// l2: Symmetric l2 regularization strength. -// num_loss_partitions: Number of partitions of the global loss function. -// num_inner_iterations: Number of iterations per mini-batch. -// -// Returns: -// out_example_state_data: a list of vectors containing the updated example state -// data. -// out_delta_sparse_weights: a list of vectors where each value is the delta -// weights associated with a sparse feature group. -// out_delta_dense_weights: a list of vectors where the values are the delta -// weights associated with a dense feature group. 
-func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SdcaOptimizer", - Input: []tf.Input{ - tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - out_example_state_data = op.Output(idx) - if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return - } - if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return - } - return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights -} - -// Inverse fast Fourier transform. -// -// Computes the inverse 1-dimensional discrete Fourier transform over the -// inner-most dimension of `input`. -// -// Arguments: -// input: A complex tensor. -// -// Returns A complex tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its inverse 1D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.ifft -// @end_compatibility -func IFFT(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IFFT", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Fast Fourier transform. // // Computes the 1-dimensional discrete Fourier transform over the inner-most @@ -39251,6 +38857,170 @@ func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_f return op.Output(0), op.Output(1), op.Output(2) } +// Forwards `data` to the output port determined by `pred`. +// +// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, +// the data goes to `output_false`. +// +// See also `RefSwitch` and `Merge`. +// +// Arguments: +// data: The tensor to be forwarded to the appropriate output. +// pred: A scalar that specifies which output port will receive data. +// +// Returns: +// output_false: If `pred` is false, data will be forwarded to this output. +// output_true: If `pred` is true, data will be forwarded to this output. 
+func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Output, output_true tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Switch", + Input: []tf.Input{ + data, pred, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingFTRLParametersGradAccumDebug. +type RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve FTRL embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns: +// parameters: Parameter parameters updated by the FTRL optimization algorithm. +// accumulators: Parameter accumulators updated by the FTRL optimization algorithm. +// linears: Parameter linears updated by the FTRL optimization algorithm. +// gradient_accumulators: Parameter gradient_accumulators updated by the FTRL optimization algorithm. +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// UnicodeEncodeAttr is an optional argument to UnicodeEncode. +type UnicodeEncodeAttr func(optionalAttr) + +// UnicodeEncodeErrors sets the optional errors attribute to value. +// +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. 
A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeEncodeErrors(value string) UnicodeEncodeAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeEncodeReplacementChar sets the optional replacement_char attribute to value. +// +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character is +// 0xFFFD (U+65533). +// If not specified, defaults to 65533 +func UnicodeEncodeReplacementChar(value int64) UnicodeEncodeAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// Encode a tensor of ints into unicode strings. +// +// Returns a vector of strings, where `output[i]` is constructed by encoding the +// Unicode codepoints in `input_values[input_splits[i]:input_splits[i+1]]` +// using `output_encoding`. +// +// --- +// +// Example: +// +// ``` +// input_values = [72, 101, 108, 108, 111, 87, 111, 114, 108, 100] +// input_splits = [0, 5, 10] +// output_encoding = 'UTF-8' +// +// output = ['Hello', 'World'] +// ``` +// +// Arguments: +// input_values: A 1D tensor containing the unicode codepoints that should be encoded. +// input_splits: A 1D tensor specifying how the unicode codepoints should be split into strings. +// In particular, `output[i]` is constructed by encoding the codepoints in the +// slice `input_values[input_splits[i]:input_splits[i+1]]`. +// output_encoding: Unicode encoding of the output strings. Valid encodings are: `"UTF-8", +// "UTF-16-BE", and "UTF-32-BE"`. +// +// Returns The 1-D Tensor of strings encoded from the provided unicode codepoints. +func UnicodeEncode(scope *Scope, input_values tf.Output, input_splits tf.Output, output_encoding string, optional ...UnicodeEncodeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_encoding": output_encoding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UnicodeEncode", + Input: []tf.Input{ + input_values, input_splits, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // PrelinearizeTupleAttr is an optional argument to PrelinearizeTuple. type PrelinearizeTupleAttr func(optionalAttr) @@ -40497,6 +40267,66 @@ func ShutdownDistributedTPU(scope *Scope) (o *tf.Operation) { return scope.AddOperation(opspec) } +// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. +type ResourceApplyMomentumAttr func(optionalAttr) + +// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, the tensor passed to compute grad will be +// var - lr * momentum * accum, so in the end, the var you get is actually +// var - lr * momentum * accum. 
+// If not specified, defaults to false +func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the momentum scheme. Set use_nesterov = True if you +// +// want to use Nesterov momentum. +// +// accum = accum * momentum + grad +// var -= lr * accum +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// grad: The gradient. +// momentum: Momentum. Must be a scalar. +// +// Returns the created operation. +func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyMomentum", + Input: []tf.Input{ + var_, accum, lr, grad, momentum, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // Returns the value stored in an Optional variant or raises an error if none exists. func OptionalGetValue(scope *Scope, optional tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { @@ -42073,6 +41903,68 @@ func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Ou return op.Output(0) } +// Computes sigmoid of `x` element-wise. +// +// Specifically, `y = 1 / (1 + exp(-x))`. +func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sigmoid", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. +type ResourceSparseApplyAdadeltaAttr func(optionalAttr) + +// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// var: Should be from a Variable(). +// +// Arguments: +// +// accum: Should be from a Variable(). +// accum_update: : Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyAdadelta", + Input: []tf.Input{ + var_, accum, accum_update, lr, rho, epsilon, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. 
type ResourceApplyAdamAttr func(optionalAttr) @@ -42136,68 +42028,6 @@ func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, b return scope.AddOperation(opspec) } -// Computes sigmoid of `x` element-wise. -// -// Specifically, `y = 1 / (1 + exp(-x))`. -func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sigmoid", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. -type ResourceSparseApplyAdadeltaAttr func(optionalAttr) - -// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// var: Should be from a Variable(). -// -// Arguments: -// -// accum: Should be from a Variable(). -// accum_update: : Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// -// Returns the created operation. -func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdadelta", - Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingADAMParametersGradAccumDebug. type RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) @@ -42539,6 +42369,159 @@ func RetrieveTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, num_s return op.Output(0) } +// DatasetToGraphAttr is an optional argument to DatasetToGraph. +type DatasetToGraphAttr func(optionalAttr) + +// DatasetToGraphStatefulWhitelist sets the optional stateful_whitelist attribute to value. +// If not specified, defaults to {} +// +// REQUIRES: len(value) >= 0 +func DatasetToGraphStatefulWhitelist(value []string) DatasetToGraphAttr { + return func(m optionalAttr) { + m["stateful_whitelist"] = value + } +} + +// DatasetToGraphAllowStateful sets the optional allow_stateful attribute to value. +// If not specified, defaults to false +func DatasetToGraphAllowStateful(value bool) DatasetToGraphAttr { + return func(m optionalAttr) { + m["allow_stateful"] = value + } +} + +// DatasetToGraphStripDeviceAssignment sets the optional strip_device_assignment attribute to value. +// If not specified, defaults to false +func DatasetToGraphStripDeviceAssignment(value bool) DatasetToGraphAttr { + return func(m optionalAttr) { + m["strip_device_assignment"] = value + } +} + +// Returns a serialized GraphDef representing `input_dataset`. 
+// +// Returns a graph representation for `input_dataset`. +// +// Arguments: +// input_dataset: A variant tensor representing the dataset to return the graph representation for. +// +// Returns The graph representation of the dataset (as serialized GraphDef). +func DatasetToGraph(scope *Scope, input_dataset tf.Output, optional ...DatasetToGraphAttr) (graph tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DatasetToGraph", + Input: []tf.Input{ + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. +type QuantizeAndDequantizeV3Attr func(optionalAttr) + +// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["signed_input"] = value + } +} + +// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// QuantizeAndDequantizeV3NarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func QuantizeAndDequantizeV3NarrowRange(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// QuantizeAndDequantizeV3Axis sets the optional axis attribute to value. +// If not specified, defaults to -1 +func QuantizeAndDequantizeV3Axis(value int64) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["axis"] = value + } +} + +// Quantizes then dequantizes a tensor. +// +// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a +// tensor, so its value can change during training. +func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizeAndDequantizeV3", + Input: []tf.Input{ + input, input_min, input_max, num_bits, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x * y element-wise. +// +// *NOTE*: `Multiply` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Mul(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Mul", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes softplus gradients for a softplus operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding softplus operation. +// features: The features passed as input to the corresponding softplus operation. +// +// Returns The gradients: `gradients / (1 + exp(-features))`. 
+func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SoftplusGrad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns x / y element-wise. // // *NOTE*: `Div` supports broadcasting. More about broadcasting @@ -43857,66 +43840,6 @@ func LoadTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, parameter return scope.AddOperation(opspec) } -// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. -type ResourceApplyMomentumAttr func(optionalAttr) - -// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var - lr * momentum * accum, so in the end, the var you get is actually -// var - lr * momentum * accum. -// If not specified, defaults to false -func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update '*var' according to the momentum scheme. Set use_nesterov = True if you -// -// want to use Nesterov momentum. -// -// accum = accum * momentum + grad -// var -= lr * accum -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. -// momentum: Momentum. Must be a scalar. -// -// Returns the created operation. -func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyMomentum", - Input: []tf.Input{ - var_, accum, lr, grad, momentum, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // Strip leading and trailing whitespaces from the Tensor. // // Arguments: @@ -45531,168 +45454,6 @@ func TPUOrdinalSelector(scope *Scope) (device_ordinals tf.Output) { return op.Output(0) } -// Creates a dataset that passes a sliding window over `input_dataset`. -// -// Arguments: -// -// window_size: A scalar representing the number of elements in the -// sliding window. -// window_shift: A scalar representing the steps moving the sliding window -// forward in one iteration. It must be positive. -// window_stride: A scalar representing the stride of the input elements of the sliding window. -// It must be positive. 
-// -// -func SlidingWindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, window_shift tf.Output, window_stride tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "SlidingWindowDataset", - Input: []tf.Input{ - input_dataset, window_size, window_shift, window_stride, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Locks a mutex resource. The output is the lock. So long as the lock tensor -// -// is alive, any other request to use `MutexLock` with this mutex will wait. -// -// This is particularly useful for creating a critical section when used in -// conjunction with `MutexLockIdentity`: -// -// ```python -// -// mutex = mutex_v2( -// shared_name=handle_name, container=container, name=name) -// -// def execute_in_critical_section(fn, *args, **kwargs): -// lock = gen_resource_variable_ops.mutex_lock(mutex) -// -// with ops.control_dependencies([lock]): -// r = fn(*args, **kwargs) -// -// with ops.control_dependencies(nest.flatten(r)): -// with ops.colocate_with(mutex): -// ensure_lock_exists = mutex_lock_identity(lock) -// -// # Make sure that if any element of r is accessed, all of -// # them are executed together. -// r = nest.map_structure(tf.identity, r) -// -// with ops.control_dependencies([ensure_lock_exists]): -// return nest.map_structure(tf.identity, r) -// ``` -// -// While `fn` is running in the critical section, no other functions which wish to -// use this critical section may run. -// -// Often the use case is that two executions of the same graph, in parallel, -// wish to run `fn`; and we wish to ensure that only one of them executes -// at a time. This is especially important if `fn` modifies one or more -// variables at a time. -// -// It is also useful if two separate functions must share a resource, but we -// wish to ensure the usage is exclusive. -// -// Arguments: -// mutex: The mutex resource to lock. -// -// Returns A tensor that keeps a shared pointer to a lock on the mutex; -// when the Tensor is destroyed, the use count on the shared pointer is decreased -// by 1. When it reaches 0, the lock is released. -func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MutexLock", - Input: []tf.Input{ - mutex, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolGradWithArgmaxAttr is an optional argument to MaxPoolGradWithArgmax. -type MaxPoolGradWithArgmaxAttr func(optionalAttr) - -// MaxPoolGradWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. -// -// value: Whether to include batch dimension in flattened index of `argmax`. -// If not specified, defaults to false -func MaxPoolGradWithArgmaxIncludeBatchInIndex(value bool) MaxPoolGradWithArgmaxAttr { - return func(m optionalAttr) { - m["include_batch_in_index"] = value - } -} - -// Computes gradients of the maxpooling function. -// -// Arguments: -// input: The original input. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the -// output of `max_pool`. -// argmax: The indices of the maximum values chosen for each output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. 
-// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients w.r.t. the input of `max_pool`. -func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradWithArgmaxAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradWithArgmax", - Input: []tf.Input{ - input, grad, argmax, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 2D fast Fourier transform. -// -// Computes the 2-dimensional discrete Fourier transform over the inner-most -// 2 dimensions of `input`. -// -// Arguments: -// input: A complex tensor. -// -// Returns A complex tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fft2 -// @end_compatibility -func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. type Conv2DBackpropFilterAttr func(optionalAttr) @@ -45738,7 +45499,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46522,6 +46283,149 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms return scope.AddOperation(opspec) } +// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. +type Conv3DBackpropInputAttr func(optionalAttr) + +// Conv3DBackpropInputDilations sets the optional dilations attribute to value. +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of 3-D convolution with respect to the input. +// +// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. 
+func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv3DBackpropInput", + Input: []tf.Input{ + input, filter, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. +type DepthwiseConv2dNativeAttr func(optionalAttr) + +// DepthwiseConv2dNativeDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthwiseConv2dNativeDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to {i:1 i:1 i:1 i:1} +func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. +// +// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +// and a filter / kernel tensor of shape +// `[filter_height, filter_width, in_channels, channel_multiplier]`, containing +// `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies +// a different filter to each input channel (expanding from 1 channel to +// `channel_multiplier` channels for each), then concatenates the results +// together. Thus, the output has `in_channels * channel_multiplier` channels. +// +// ``` +// for k in 0..in_channels-1 +// for q in 0..channel_multiplier-1 +// output[b, i, j, k * channel_multiplier + q] = +// sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * +// filter[di, dj, k, q] +// ``` +// +// Must have `strides[0] = strides[3] = 1`. For the most common case of the same +// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// +// Arguments: +// +// +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. +// padding: The type of padding algorithm to use. 
+func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DepthwiseConv2dNative", + Input: []tf.Input{ + input, filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates an all-zeros CSRSparseMatrix with shape `dense_shape`. +// +// Arguments: +// dense_shape: The desired matrix shape. +// +// +// Returns An empty CSR matrix with shape `dense_shape`. +func SparseMatrixZeros(scope *Scope, dense_shape tf.Output, type_ tf.DataType) (sparse_matrix tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"type": type_} + opspec := tf.OpSpec{ + Type: "SparseMatrixZeros", + Input: []tf.Input{ + dense_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // EnqueueTPUEmbeddingIntegerBatchAttr is an optional argument to EnqueueTPUEmbeddingIntegerBatch. type EnqueueTPUEmbeddingIntegerBatchAttr func(optionalAttr) @@ -48497,6 +48401,102 @@ func StatelessRandomUniformInt(scope *Scope, shape tf.Output, seed tf.Output, mi return op.Output(0) } +// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. +type CropAndResizeGradImageAttr func(optionalAttr) + +// CropAndResizeGradImageMethod sets the optional method attribute to value. +// +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { + return func(m optionalAttr) { + m["method"] = value + } +} + +// Computes the gradient of the crop_and_resize op wrt the input image tensor. +// +// Arguments: +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` +// containing the original image size. Both `image_height` and `image_width` need +// to be positive. +// +// +// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. 
+func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CropAndResizeGradImage", + Input: []tf.Input{ + grads, boxes, box_ind, image_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// OutfeedDequeueAttr is an optional argument to OutfeedDequeue. +type OutfeedDequeueAttr func(optionalAttr) + +// OutfeedDequeueDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. +// If not specified, defaults to -1 +func OutfeedDequeueDeviceOrdinal(value int64) OutfeedDequeueAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// Retrieves a single tensor from the computation outfeed. +// +// This operation will block indefinitely until data is available. +// +// Arguments: +// dtype: The type of elements in the tensor. +// shape: The shape of the tensor. +// +// Returns A tensor that will be read from the device outfeed. +func OutfeedDequeue(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...OutfeedDequeueAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OutfeedDequeue", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // LSTMBlockCellAttr is an optional argument to LSTMBlockCell. type LSTMBlockCellAttr func(optionalAttr) From d651e3ea4bf452824ea707cbc0d15a0efa170f2c Mon Sep 17 00:00:00 2001 From: Robert David Date: Wed, 18 Mar 2020 14:16:28 -0700 Subject: [PATCH 170/492] Change indices from int to pointer-sized integers. This removes large number of mov r32,r32 instructions in the generated code. PiperOrigin-RevId: 301665968 Change-Id: Ic2b07354bd7d3fbca711656eb05c0f0c4c57e407 --- .../internal/optimized/sse_tensor_utils.cc | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc index 26395a2a704..fe970dd8b39 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc @@ -22,6 +22,8 @@ limitations under the License. #include // SSE4.1 #endif +#include + #include "tensorflow/lite/kernels/internal/compatibility.h" namespace tflite { @@ -92,16 +94,16 @@ void SseMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ vectors, const float* __restrict__ scaling_factors, int n_batch, float* __restrict__ result) { - for (int batch = 0; batch < n_batch; ++batch) { + for (std::intptr_t batch = 0; batch < n_batch; ++batch) { const float batch_scaling_factor = scaling_factors[batch]; // Compute dot-product for every column. - for (int row = 0; row < m_rows; ++row) { + for (std::intptr_t row = 0; row < m_rows; ++row) { // Get the address of the first element of the row. const int8_t* __restrict__ row_ptr = matrix + row * m_cols; // Initialize the dot product sum for the row to 0. 
__m128i dotprod_32x4 = _mm_setzero_si128(); - int col = 0; + std::intptr_t col = 0; // For every block of 16x 8-bit inputs. while (col < (m_cols & ~15)) { const __m128i vec_8x16 = @@ -165,10 +167,10 @@ void SseMatrixBatchVectorMultiplyAccumulate( const float* __restrict__ scaling_factors, int n_batch, float* __restrict__ result, const float* __restrict__ per_channel_scale, const int32_t* __restrict__ input_offset) { - static constexpr int kBlockSize = 16; - for (int batch = 0; batch < n_batch; ++batch) { + static constexpr std::intptr_t kBlockSize = 16; + for (std::intptr_t batch = 0; batch < n_batch; ++batch) { const float batch_scaling_factor = scaling_factors[batch]; - for (int row = 0; row < m_rows; ++row) { + for (std::intptr_t row = 0; row < m_rows; ++row) { const int8_t* __restrict__ row_ptr = matrix + row * m_cols; float scale = batch_scaling_factor; if (per_channel_scale != nullptr) { @@ -176,7 +178,7 @@ void SseMatrixBatchVectorMultiplyAccumulate( } __m128i dotprod_32x4 = _mm_setzero_si128(); __m128i row_sum_16x8 = _mm_setzero_si128(); - int col = 0; + std::intptr_t col = 0; for (; col < (m_cols & ~(kBlockSize - 1)); col += kBlockSize) { const __m128i vec_8x16 = _mm_loadu_si128(reinterpret_cast(vectors + col)); @@ -217,15 +219,15 @@ inline void SseSparseMatrixVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger, const int m_rows, const int m_cols, const int8_t* __restrict__ vector, const float scaling_factor, float* __restrict__ result) { - static const int kBlockSize = 16; + static const std::intptr_t kBlockSize = 16; TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0); const uint8_t* __restrict__ ledger_ptr = ledger; - for (int row = 0; row < m_rows; ++row) { + for (std::intptr_t row = 0; row < m_rows; ++row) { // Initialize the dot product sum for the row to 0. __m128i dotprod_32x4 = _mm_setzero_si128(); - int num_nonzero_blocks = *ledger_ptr++; - for (int i = 0; i < num_nonzero_blocks; i++) { - const int col_index = *ledger_ptr++ * kBlockSize; + std::intptr_t num_nonzero_blocks = *ledger_ptr++; + for (std::intptr_t i = 0; i < num_nonzero_blocks; i++) { + const std::intptr_t col_index = *ledger_ptr++ * kBlockSize; const __m128i vec_8x16 = _mm_loadu_si128(reinterpret_cast(vector + col_index)); const __m128i row_8x16 = @@ -251,7 +253,7 @@ inline void SseSparseMatrix4VectorsMultiplyAccumulate( const int m_rows, const int m_cols, const int8_t* __restrict__ const vectors, const __m128 scaling_factors_fx4, float* __restrict__ const results) { - static const int kBlockSize = 16; + static const std::intptr_t kBlockSize = 16; TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0); const int8_t* __restrict__ vector0 = vectors + 0 * m_cols; @@ -263,16 +265,16 @@ inline void SseSparseMatrix4VectorsMultiplyAccumulate( float* __restrict__ result2 = results + 2 * m_rows; float* __restrict__ result3 = results + 3 * m_rows; - for (int row = 0; row < m_rows; ++row) { + for (std::intptr_t row = 0; row < m_rows; ++row) { // Initialize the dot product sum for the row to 0. 
__m128i dp0_32x4 = _mm_setzero_si128(); __m128i dp1_32x4 = _mm_setzero_si128(); __m128i dp2_32x4 = _mm_setzero_si128(); __m128i dp3_32x4 = _mm_setzero_si128(); - int num_nonzero_blocks = *ledger++; - for (int i = 0; i < num_nonzero_blocks; i++) { - const int col_index = *ledger++ * kBlockSize; + std::intptr_t num_nonzero_blocks = *ledger++; + for (std::intptr_t i = 0; i < num_nonzero_blocks; i++) { + const std::intptr_t col_index = *ledger++ * kBlockSize; // vecN are for different batches const __m128i vec0_8x16 = _mm_loadu_si128( reinterpret_cast(vector0 + col_index)); From 5041af787968ee9614341af883d370ed73f7d84f Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Wed, 18 Mar 2020 14:23:06 -0700 Subject: [PATCH 171/492] Restrict Keras package visibility. PiperOrigin-RevId: 301667734 Change-Id: I0406105cf93d1d2296905f434a80b643618133f7 --- tensorflow/python/keras/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index f3078ddaabb..6af56e7ab77 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -19,7 +19,6 @@ py_library( "ops.py", ], srcs_version = "PY2AND3", - visibility = ["//visibility:public"], deps = [ ":backend", ":engine", From 5b00c31c4b517f94b60ce6c1723fe2702eddcd02 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 14:23:48 -0700 Subject: [PATCH 172/492] makes gradient_checker_v2 work with sparse tensor reshape. PiperOrigin-RevId: 301667976 Change-Id: I133c918c8a5703b40d4b5a69ead4bc0ec74f27b8 --- tensorflow/python/ops/gradient_checker_v2.py | 4 +++- .../python/ops/gradient_checker_v2_test.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/gradient_checker_v2.py b/tensorflow/python/ops/gradient_checker_v2.py index 4edb5adb113..633b5e57d95 100644 --- a/tensorflow/python/ops/gradient_checker_v2.py +++ b/tensorflow/python/ops/gradient_checker_v2.py @@ -174,7 +174,6 @@ def _compute_theoretical_jacobian(f, y_shape, y_dtype, xs, param): dy_data_flat[row] = 1 grad = _to_numpy(grad_fn(dy_data, *xs)[0]) grad = _eval_indexed_slices(grad) - dy_data_flat[row] = 0 if isinstance(grad, ops.IndexedSlicesValue): for i, v in zip(grad.indices, grad.values): c_begin = i * x_val_size @@ -182,6 +181,9 @@ def _compute_theoretical_jacobian(f, y_shape, y_dtype, xs, param): jacobian[row, c_begin:c_end] += v.flat elif grad is not None: jacobian[row, :] = grad.ravel().view(jacobian.dtype) + # This reset of `dy_data_flat` needs to happen after `grad` is copied to + # `jacobian` because `grad` and `dy_data_flat` may share memory. + dy_data_flat[row] = 0 # If the output is empty, run the gradients at least once and make sure # they produce zeros. 
diff --git a/tensorflow/python/ops/gradient_checker_v2_test.py b/tensorflow/python/ops/gradient_checker_v2_test.py index 96a942d8a90..b77c95d8968 100644 --- a/tensorflow/python/ops/gradient_checker_v2_test.py +++ b/tensorflow/python/ops/gradient_checker_v2_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import custom_gradient @@ -30,6 +31,7 @@ from tensorflow.python.ops import \ gradient_checker_v2 as gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import sparse_ops # needs this to register gradient for SoftmaxCrossEntropyWithLogits: import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -46,6 +48,20 @@ def _random_complex(shape, dtype): @test_util.run_all_in_graph_and_eager_modes class GradientCheckerTest(test.TestCase): + def testSparseTensorReshape(self): + x = constant_op.constant(2.0, shape=(2,)) + + def sparse_tensor_reshape(values): + sparse = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], values=values, dense_shape=[3, 4]) + sparse = sparse_ops.sparse_reshape(sparse, shape=(12,)) + return sparse.values + + error = gradient_checker.max_error( + *gradient_checker.compute_gradient(sparse_tensor_reshape, [x])) + + self.assertLess(error, 1e-4) + def testWithStaticShape(self): size = (2, 3) constant = constant_op.constant(2.0, shape=size, name="const") From bfafc1acef59ff5a7ba2bf2675350812e552d5ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 14:38:13 -0700 Subject: [PATCH 173/492] don't trace arguments (include tensor shapes and op attributes used for cost analysis) for dataset ops. PiperOrigin-RevId: 301671357 Change-Id: I4cf44855603ea26007a9652d6c01866db4d83c5b --- tensorflow/core/framework/dataset.cc | 4 ++++ tensorflow/core/framework/dataset.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 97c4d212223..cccbdd5d8e4 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -508,6 +508,10 @@ void DatasetOpKernel::Compute(OpKernelContext* ctx) { } } +string DatasetOpKernel::TraceString(OpKernelContext* ctx, bool verbose) { + return strings::StrCat(name_view(), ":", type_string_view()); +} + // static bool DatasetOpKernel::IsDatasetOp(const OpDef* op_def) { if (DatasetOpRegistry::IsRegistered(op_def->name())) { diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 92f7a52b632..25cc8fd759e 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -1073,6 +1073,8 @@ class DatasetOpKernel : public OpKernel { // the `DatasetOpKernel` class. static bool IsDatasetOp(const OpDef* op_def); + string TraceString(OpKernelContext* ctx, bool verbose) override; + protected: // Subclasses should implement this method. It will be called during Compute // execution. 
From 374a0d3a4a6d938b36973eb7eea9d31265cdbf71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cjaketae=E2=80=9D?= Date: Thu, 19 Mar 2020 07:06:10 +0900 Subject: [PATCH 174/492] Erase example as hashing is not always reproducible --- tensorflow/python/keras/preprocessing/text.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorflow/python/keras/preprocessing/text.py b/tensorflow/python/keras/preprocessing/text.py index 33740159d6b..7c17ae9cfcc 100644 --- a/tensorflow/python/keras/preprocessing/text.py +++ b/tensorflow/python/keras/preprocessing/text.py @@ -66,10 +66,6 @@ def one_hot(input_text, n, list of encoded integers each corresponding to a word (or token) in the given input string. - >>> sample_text = 'This is a sample sentence.' - >>> tf.keras.preprocessing.text.one_hot(sample_text, 20) - [4, 18, 1, 15, 17] - Arguments: input_text: Input text (string). n: int. Size of vocabulary. From 4f315c18bcfb7118c188082fe6ac7643e1e1c532 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 15:15:03 -0700 Subject: [PATCH 175/492] Add Checkpoint.read() that just loads a checkpoint without adding a `save_counter`. This is symmetrical to write(). Loading a checkpoint written with write() with read() instead of restore() allows to reliably call assert_existing_objects_matched() to check that all objects where read from the checkpoint. PiperOrigin-RevId: 301679428 Change-Id: I4acf3c2b7eb63ad25bb4db163bfca365e18bea6f --- tensorflow/python/training/tracking/util.py | 56 ++++++++++++++++--- .../python/training/tracking/util_test.py | 18 ++++-- .../v2/tensorflow.train.-checkpoint.pbtxt | 4 ++ 3 files changed, 64 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py index e4138864bd3..eeaa2a541c5 100644 --- a/tensorflow/python/training/tracking/util.py +++ b/tensorflow/python/training/tracking/util.py @@ -1847,6 +1847,8 @@ class Checkpoint(tracking.AutoTrackable): use by higher level checkpoint management utilities. `save` provides a very basic implementation of these features. + Checkpoints written with `write` must be read with `read`. + Args: file_prefix: A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). @@ -1888,7 +1890,7 @@ class Checkpoint(tracking.AutoTrackable): sequentially numbering checkpoints using `save_counter` and updating the metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint management, for example garbage collection and custom numbering, may be - provided by other utilities which also wrap `write` + provided by other utilities which also wrap `write` and `read`. (`tf.train.CheckpointManager` for example). Args: @@ -1932,20 +1934,58 @@ class Checkpoint(tracking.AutoTrackable): save_relative_paths=True) return file_path + def read(self, save_path): + """Read a training checkpoint written with `write`. + + Reads this `Checkpoint` and any objects it depends on. + + This method is just like `restore()` but does not expect the `save_counter` + variable in the checkpoint. It only restores the objects that the checkpoint + already depends on. + + The method is primarily intended for use by higher level checkpoint + management utilities that use `write()` instead of `save()` and have their + own mechanisms to number and track checkpoints. 
+ + Example usage: + + ```python + # Create a checkpoint with write() + ckpt = tf.train.Checkpoint(v=tf.Variable(1.)) + path = ckpt.write('/tmp/my_checkpoint') + + # Later, load the checkpoint with read() + # With restore() assert_consumed() would have failed. + checkpoint.read(path).assert_consumed() + ``` + + Args: + save_path: The path to the checkpoint as returned by `write`. + + Returns: + A load status object, which can be used to make assertions about the + status of a checkpoint restoration. See `restore` for details. + """ + return self._saver.restore(save_path=save_path) + def restore(self, save_path): """Restore a training checkpoint. Restores this `Checkpoint` and any objects it depends on. - Either assigns values immediately if variables to restore have been created - already, or defers restoration until the variables are created. Dependencies - added after this call will be matched if they have a corresponding object in - the checkpoint (the restore request will queue in any trackable object - waiting for the expected dependency to be added). + This method is intended to be used to load checkpoints created by `save()`. + For checkpoints created by `write()` use the `read()` method which does not + expect the `save_counter` variable added by `save()`. + + `restore()` either assigns values immediately if variables to restore have + been created already, or defers restoration until the variables are + created. Dependencies added after this call will be matched if they have a + corresponding object in the checkpoint (the restore request will queue in + any trackable object waiting for the expected dependency to be added). To ensure that loading is complete and no more assignments will take place, use the `assert_consumed()` method of the status object returned by - `restore`: + `restore()`: ```python checkpoint = tf.train.Checkpoint( ... ) @@ -2006,7 +2046,7 @@ class Checkpoint(tracking.AutoTrackable): checkpoint file or object when the `Checkpoint` object is deleted (often at program shutdown). """ - status = self._saver.restore(save_path=save_path) + status = self.read(save_path) # Create the save counter now so it gets initialized with other variables # when graph building. Creating it earlier would lead to errors when using, # say, train.Saver() to save the model before initializing it. diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py index 6e57d690726..e63baa60003 100644 --- a/tensorflow/python/training/tracking/util_test.py +++ b/tensorflow/python/training/tracking/util_test.py @@ -1376,8 +1376,7 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase): @test_util.run_in_graph_and_eager_modes def test_write_checkpoint_from_function(self): checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt") - save_checkpoint = trackable_utils.Checkpoint( - v=variables_lib.Variable(1.)) + save_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(1.)) @def_function.function def _write_checkpoint(): @@ -1386,14 +1385,21 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase): self.evaluate([save_checkpoint.v.initializer]) self.evaluate(_write_checkpoint()) - load_checkpoint = trackable_utils.Checkpoint( - v=variables_lib.Variable(0.)) - load_checkpoint.restore(checkpoint_prefix).run_restore_ops() + load_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(0.)) + # Use read() instead of restore() which allows us to check that all + # existing objects were loaded. 
+ status = load_checkpoint.read(checkpoint_prefix) + status.assert_existing_objects_matched() + status.assert_consumed() + status.run_restore_ops() self.assertEqual(1., self.evaluate(load_checkpoint.v)) self.evaluate(save_checkpoint.v.assign(3.)) self.evaluate(_write_checkpoint()) self.evaluate(save_checkpoint.v.assign(0.)) - load_checkpoint.restore(checkpoint_prefix).run_restore_ops() + status = load_checkpoint.read(checkpoint_prefix) + status.assert_existing_objects_matched() + status.assert_consumed() + status.run_restore_ops() self.assertEqual(3., self.evaluate(load_checkpoint.v)) def test_inititialize_with_data_structures(self): diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt index deb93d7adca..d7e93a0f937 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" } + member_method { + name: "read" + argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "restore" argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None" From 8b30b821904f0a87ea5bacf5d88c9fb255b32c1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 15:20:27 -0700 Subject: [PATCH 176/492] Internal change PiperOrigin-RevId: 301680486 Change-Id: If79ff677cfe72fa72ac02996b71047f3caa8b7d2 --- .../bucket_by_sequence_length_test.py | 7 +- .../data/kernel_tests/from_generator_test.py | 63 +----- .../python/data/kernel_tests/iterator_test.py | 4 +- tensorflow/python/data/ops/dataset_ops.py | 183 +++++++----------- tensorflow/python/data/util/structure.py | 16 +- tensorflow/python/ops/ragged/ragged_tensor.py | 5 - tensorflow/python/ops/script_ops.py | 49 +---- .../v1/tensorflow.-ragged-tensor-spec.pbtxt | 4 - .../golden/v1/tensorflow.data.-dataset.pbtxt | 2 +- ...ow.data.-fixed-length-record-dataset.pbtxt | 2 +- .../tensorflow.data.-t-f-record-dataset.pbtxt | 2 +- .../tensorflow.data.-text-line-dataset.pbtxt | 2 +- ...rflow.data.experimental.-csv-dataset.pbtxt | 2 +- ...ow.data.experimental.-random-dataset.pbtxt | 2 +- ...rflow.data.experimental.-sql-dataset.pbtxt | 2 +- .../v2/tensorflow.-ragged-tensor-spec.pbtxt | 4 - .../golden/v2/tensorflow.data.-dataset.pbtxt | 2 +- ...ow.data.-fixed-length-record-dataset.pbtxt | 2 +- .../tensorflow.data.-t-f-record-dataset.pbtxt | 2 +- .../tensorflow.data.-text-line-dataset.pbtxt | 2 +- ...rflow.data.experimental.-csv-dataset.pbtxt | 2 +- ...ow.data.experimental.-random-dataset.pbtxt | 2 +- ...rflow.data.experimental.-sql-dataset.pbtxt | 2 +- 23 files changed, 110 insertions(+), 253 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py index d23bbbe615a..0dd7ae1f083 100644 --- a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py @@ -48,7 +48,7 @@ def _format_record(array, sparse): return { "values": array, "indices": [[i] for i in range(len(array))], - "dense_shape": [len(array),] + "dense_shape": (len(array),) } return array @@ -402,16 +402,13 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase, 
bucket_size = 10 def _build_dataset(): - input_data = [list(range(i + 1)) for i in range(min_len, max_len)] - + input_data = [range(i+1) for i in range(min_len, max_len)] def generator_fn(): for record in input_data: yield _format_record(record, sparse=True) - dataset = dataset_ops.Dataset.from_generator( generator=generator_fn, output_types=_get_record_type(sparse=True)) - dataset = dataset.map(_to_sparse_tensor) return dataset diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py index 288d0e694f2..d320b281136 100644 --- a/tensorflow/python/data/kernel_tests/from_generator_test.py +++ b/tensorflow/python/data/kernel_tests/from_generator_test.py @@ -28,12 +28,7 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import script_ops -from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops.ragged import ragged_factory_ops -from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test @@ -246,7 +241,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaisesOpError("The expected type was int64"): self.evaluate(get_next()) self.assertAllEqual([7, 8, 9], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -266,7 +261,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaisesOpError(r"element of shape \(3,\) was expected"): self.evaluate(get_next()) self.assertAllEqual([11, 12, 13], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -287,9 +282,11 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual((1, 2), self.evaluate(get_next())) self.assertEqual((3, 4), self.evaluate(get_next())) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaisesOpError( + r"The expected structure was \(tf\.int64, tf\.int64\)"): self.evaluate(get_next()) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaisesOpError( + r"The expected structure was \(tf\.int64, tf\.int64\)"): self.evaluate(get_next()) self.assertEqual((9, 10), self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -408,12 +405,8 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): stateful=True) dummy = constant_op.constant(37) - - dataset = dataset_ops._GeneratorDataset( - dummy, lambda x: x, lambda x: x, finalize_fn, - tensor_spec.TensorSpec((), dtypes.int32)) - - dataset = dataset.take(2) + dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x, + finalize_fn).take(2) get_next = self.getNext(dataset) self.assertAllEqual(37, self.evaluate(get_next())) @@ -435,46 +428,6 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([20], self.evaluate(get_next())) - 
@combinations.generate(test_base.default_test_combinations()) - def testFromGeneratorRaggedTensor(self): - - def generator(): - yield ragged_factory_ops.constant([[1, 2], [3]], - dtype=dtypes.int64, - ragged_rank=1) - - dataset = dataset_ops.Dataset.from_generator( - generator, - output_signature=ragged_tensor.RaggedTensorSpec( - shape=(2, None), dtype=dtypes.int64)) - get_next = self.getNext(dataset) - - ret = get_next() - - self.assertIsInstance(ret, ragged_tensor.RaggedTensor) - self.assertAllEqual([1, 2, 3], ret.values) - - @combinations.generate(test_base.default_test_combinations()) - def testFromGeneratorSparseTensor(self): - - def generator(): - yield sparse_tensor.SparseTensor( - indices=[[0, 0], [1, 2]], - values=constant_op.constant([1, 2], dtype=dtypes.int64), - dense_shape=[3, 4]) - - dataset = dataset_ops.Dataset.from_generator( - generator, - output_signature=sparse_tensor.SparseTensorSpec([3, 4], dtypes.int64)) - - get_next = self.getNext(dataset) - - ret = get_next() - - self.assertIsInstance(ret, sparse_tensor.SparseTensor) - self.assertAllEqual([[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]], - sparse_ops.sparse_tensor_to_dense(ret)) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 94b50a7864d..36689ed75fb 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -946,9 +946,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): @def_function.function def fn(): - output_spec = tensor_spec.TensorSpec((), dtypes.int64) - dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn, - output_spec) + dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn) iterator = iter(dataset) next(iterator) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 9eb38bfc0d1..32ab469363e 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -408,7 +408,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): def element_spec(self): """The type specification of an element of this dataset. - >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).element_spec + >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) + >>> dataset.element_spec TensorSpec(shape=(), dtype=tf.int32, name=None) Returns: @@ -674,48 +675,27 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): del self._iterators[iterator_id] @staticmethod - @deprecation.deprecated_args(None, "Use output_signature instead", - "output_types", "output_shapes") - def from_generator(generator, - output_types=None, - output_shapes=None, - args=None, - output_signature=None): + def from_generator(generator, output_types, output_shapes=None, args=None): """Creates a `Dataset` whose elements are generated by `generator`. The `generator` argument must be a callable object that returns an object that supports the `iter()` protocol (e.g. a generator function). + The elements generated by `generator` must be compatible with the given + `output_types` and (optional) `output_shapes` arguments. - The elements generated by `generator` must be compatible with either the - given `output_signature` argument or with the given `output_types` and - (optionally) `output_shapes` arguments whichiver was specified. 
- - The recommended way to call `from_generator` is to use the - `output_signature` argument. In this case the output will be assumed to - consist of objects with the classes, shapes and types defined by - `tf.TypeSpec` objects from `output_signature` argument: - + >>> import itertools + >>> >>> def gen(): - ... ragged_tensor = tf.ragged.constant([[1, 2], [3]], - ... ragged_rank=1, - ... dtype=tf.int64) - ... yield 42, ragged_tensor + ... for i in itertools.count(1): + ... yield (i, [1] * i) >>> >>> dataset = tf.data.Dataset.from_generator( ... gen, - ... output_signature=( - ... tf.TensorSpec(shape=(), dtype=tf.int64), - ... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int64))) + ... (tf.int64, tf.int64), + ... (tf.TensorShape([]), tf.TensorShape([None]))) >>> - >>> list(dataset.take(1)) - [(, - )] - - There is also a deprecated way to call `from_generator` by either with - `output_types` argument alone or together with `output_shapes` argument. - In this case the output of the function will be assumed to consist of - `tf.Tensor` objects with with the types defined by `output_types` and with - the shapes which are either unknown or defined by `output_shapes`. + >>> list(dataset.take(3).as_numpy_iterator()) + [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))] Note: The current implementation of `Dataset.from_generator()` uses `tf.numpy_function` and inherits the same constraints. In particular, it @@ -739,56 +719,31 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): `iter()` protocol. If `args` is not specified, `generator` must take no arguments; otherwise it must take as many arguments as there are values in `args`. - output_types: (Optional.) A nested structure of `tf.DType` objects - corresponding to each component of an element yielded by `generator`. + output_types: A nested structure of `tf.DType` objects corresponding to + each component of an element yielded by `generator`. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element yielded by `generator`. args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated and passed to `generator` as NumPy-array arguments. - output_signature: (Optional.) A nested structure of `tf.TypeSpec` objects - corresponding to each component of an element yielded by `generator`. Returns: Dataset: A `Dataset`. 
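    Editor's note (illustrative sketch, not part of the original docstring):
    the `args` parameter described above is easiest to see in a small
    example. Assuming eager execution, the values in `args` are evaluated
    and handed to `generator` as NumPy values:

      def gen(n):
        # `n` arrives as a NumPy value converted from the `args` tensor.
        for i in range(n):
          yield i

      dataset = tf.data.Dataset.from_generator(
          gen, output_types=tf.int64, args=(5,))
      print(list(dataset.as_numpy_iterator()))  # [0, 1, 2, 3, 4]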
""" if not callable(generator): raise TypeError("`generator` must be callable.") - - if output_signature is not None: - if output_types is not None: - raise TypeError("`output_types` can not be used together with " - "`output_signature`") - if output_shapes is not None: - raise TypeError("`output_shapes` can not be used together with " - "`output_signature`") - if not all( - isinstance(_, type_spec.TypeSpec) - for _ in nest.flatten(output_signature)): - raise TypeError("All the elements of `output_siganture` must be " - "a `tf.TypeSpec` objects.") + if output_shapes is None: + output_shapes = nest.map_structure( + lambda _: tensor_shape.TensorShape(None), output_types) else: - if output_types is None and output_shapes is not None: - raise TypeError("`output_shapes` can not be used alone without " - "`output_types`") - - if output_signature is None: - if output_shapes is None: - output_shapes = nest.map_structure( - lambda _: tensor_shape.TensorShape(None), output_types) - else: - output_shapes = nest.map_structure_up_to(output_types, - tensor_shape.as_shape, - output_shapes) - output_signature = nest.map_structure_up_to(output_types, - tensor_spec.TensorSpec, - output_shapes, output_types) - + output_shapes = nest.map_structure_up_to( + output_types, tensor_shape.as_shape, output_shapes) if args is None: args = () else: args = tuple(ops.convert_n_to_tensor(args, name="args")) - flat_output_types = structure.get_flat_tensor_types(output_signature) + flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)] + flattened_shapes = nest.flatten(output_shapes) generator_state = DatasetV2._GeneratorState(generator) @@ -826,41 +781,56 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): """A `py_func` that will be called to invoke the iterator.""" # `next()` raises `StopIteration` when there are no more # elements remaining to be generated. - values = next(generator_state.get_iterator(iterator_id.numpy())) - - def serialize_structure(s): - return nest.map_structure(lambda ts: ts._serialize(), s) # pylint: disable=protected-access + values = next(generator_state.get_iterator(iterator_id)) + # Use the same _convert function from the py_func() implementation to + # convert the returned values to arrays early, so that we can inspect + # their values. try: - output_dtypes = nest.map_structure(lambda t: t.dtype, - output_signature) - values = structure.normalize_element(values, dtypes=output_dtypes) + flattened_values = nest.flatten_up_to(output_types, values) except (TypeError, ValueError): - six.reraise( - TypeError, - TypeError( - "`generator` yielded an element that did not match the " - "expected structure. The expected structure was %s, but the " - "yielded element was %s." % - (serialize_structure(output_signature), values)), - sys.exc_info()[2]) + six.reraise(TypeError, TypeError( + "`generator` yielded an element that did not match the expected " + "structure. The expected structure was %s, but the yielded " + "element was %s." % (output_types, values)), sys.exc_info()[2]) + ret_arrays = [] + for ret, dtype in zip(flattened_values, flattened_types): + try: + ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access + ret, dtype=dtype.as_numpy_dtype)) + except (TypeError, ValueError): + six.reraise(TypeError, TypeError( + "`generator` yielded an element that could not be converted to " + "the expected type. The expected type was %s, but the yielded " + "element was %s." 
% (dtype.name, ret)), sys.exc_info()[2]) - values_spec = structure.type_spec_from_value(values) + # Additional type and shape checking to ensure that the components + # of the generated element match the `output_types` and `output_shapes` + # arguments. + for (ret_array, expected_dtype, expected_shape) in zip( + ret_arrays, flattened_types, flattened_shapes): + if ret_array.dtype != expected_dtype.as_numpy_dtype: + raise TypeError( + "`generator` yielded an element of type %s where an element " + "of type %s was expected." % (ret_array.dtype, + expected_dtype.as_numpy_dtype)) + if not expected_shape.is_compatible_with(ret_array.shape): + raise ValueError( + "`generator` yielded an element of shape %s where an element " + "of shape %s was expected." % (ret_array.shape, expected_shape)) - if not structure.are_compatible(values_spec, output_signature): - raise TypeError( - "`generator` yielded an element of TypeSpec%s where an element " - "of TypeSpec%s was expected." % - (serialize_structure(values_spec), - serialize_structure(output_signature))) + return ret_arrays - return structure.to_tensor_list(output_signature, values) + flat_values = script_ops.numpy_function(generator_py_func, + [iterator_id_t], flattened_types) - return script_ops._eager_py_func( # pylint: disable=protected-access - generator_py_func, - inp=[iterator_id_t], - Tout=flat_output_types, - use_tape_cache=False) + # The `py_func()` op drops the inferred shapes, so we add them back in + # here. + if output_shapes is not None: + for ret_t, shape in zip(flat_values, flattened_shapes): + ret_t.set_shape(shape) + + return nest.pack_sequence_as(output_types, flat_values) def finalize_fn(iterator_id_t): """Releases host-side state for the iterator with ID `iterator_id_t`.""" @@ -886,7 +856,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): # given ID, and raises StopIteration when that iterator contains no # more elements. return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn, - finalize_fn, output_signature) + finalize_fn) # A single-element dataset that, each time it is evaluated, contains a # freshly-generated and unique (for the returned dataset) int64 @@ -2308,14 +2278,9 @@ class DatasetV1(DatasetV2): @staticmethod @functools.wraps(DatasetV2.from_generator) - def from_generator(generator, - output_types=None, - output_shapes=None, - args=None, - output_signature=None): - return DatasetV1Adapter( - DatasetV2.from_generator(generator, output_types, output_shapes, args, - output_signature)) + def from_generator(generator, output_types, output_shapes=None, args=None): + return DatasetV1Adapter(DatasetV2.from_generator( + generator, output_types, output_shapes, args)) @staticmethod @functools.wraps(DatasetV2.range) @@ -3296,8 +3261,7 @@ class StructuredFunctionWrapper(object): class _GeneratorDataset(DatasetSource): """A `Dataset` that generates elements by invoking a function.""" - def __init__(self, init_args, init_func, next_func, finalize_func, - output_signature): + def __init__(self, init_args, init_func, next_func, finalize_func): """Constructs a `_GeneratorDataset`. Args: @@ -3311,8 +3275,6 @@ class _GeneratorDataset(DatasetSource): finalize_func: A TensorFlow function that will be called on the result of `init_func` immediately before a C++ iterator over this dataset is destroyed. The return value is ignored. - output_signature: A nested structure of `tf.TypeSpec` objects describing - the output of `next_func`. 
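    Editor's note (illustrative sketch, not part of the original docstring):
    `_GeneratorDataset` is private API, but the way the three functions
    cooperate can be seen in a condensed form of the unit test above,
    assuming eager execution and the four-argument constructor used after
    this change:

      dummy = constant_op.constant(37)
      dataset = dataset_ops._GeneratorDataset(
          dummy,          # init_args, passed to init_func for each iterator
          lambda x: x,    # init_func: produces the per-iterator state
          lambda x: x,    # next_func: maps that state to the next element
          lambda x: x).take(2)  # finalize_func: runs when the iterator is destroyed

    Every element is 37; `take(2)` bounds the otherwise endless stream. A
    real `next_func`, such as the one built by `from_generator`, ends the
    stream by signalling out-of-range once the underlying Python generator
    is exhausted.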
""" self._init_args = init_args @@ -3332,9 +3294,6 @@ class _GeneratorDataset(DatasetSource): finalize_func, self._transformation_name(), input_structure=self._init_func.output_structure) - - self._output_signature = output_signature - variant_tensor = gen_dataset_ops.generator_dataset( structure.to_tensor_list(self._init_structure, self._init_args) + self._init_func.function.captured_inputs, @@ -3348,7 +3307,7 @@ class _GeneratorDataset(DatasetSource): @property def element_spec(self): - return self._output_signature + return self._next_func.output_structure def _transformation_name(self): return "Dataset.from_generator()" diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index ee6151742f6..87825005069 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -67,7 +67,7 @@ def _RaggedTensorStructure(dtype, shape, ragged_rank): # TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once # it is a subclass of `CompositeTensor`. -def normalize_element(element, dtypes=None): +def normalize_element(element): """Normalizes a nested structure of element components. * Components matching `SparseTensorSpec` are converted to `SparseTensor`. @@ -78,10 +78,6 @@ def normalize_element(element, dtypes=None): Args: element: A nested structure of individual components. - dtypes: (Optional.) A nested structure of `tf.DType` objects corresponding - to each component of `element`. If specified, it will be used to set the - exact type of output tensor when converting input components which - are not tensors themselves (e.g. numpy arrays, native python types, etc.) Returns: A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`, @@ -89,21 +85,17 @@ def normalize_element(element, dtypes=None): """ components = nest.flatten(element) normalized_components = [] - if dtypes is None: - flattened_dtypes = [None] * len(components) - else: - flattened_dtypes = nest.flatten(dtypes) with ops.name_scope("normalize_element"): # Imported here to avoid circular dependency. from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top - for i, (t, dtype) in enumerate(zip(components, flattened_dtypes)): + for i, t in enumerate(components): try: spec = type_spec_from_value(t, use_fallback=False) except TypeError: # TypeError indicates it was not possible to compute a `TypeSpec` for # the value. As a fallback try converting the value to a tensor. 
normalized_components.append( - ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) + ops.convert_to_tensor(t, name="component_%d" % i)) else: if isinstance(spec, sparse_tensor.SparseTensorSpec): normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) @@ -120,7 +112,7 @@ def normalize_element(element, dtypes=None): normalized_components.append(t) else: normalized_components.append( - ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) + ops.convert_to_tensor(t, name="component_%d" % i)) return nest.pack_sequence_as(element, normalized_components) diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 6d365210308..78be28b7ec6 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -2085,11 +2085,6 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): else: return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value) - @property - def dtype(self): - """The `tf.dtypes.DType` specified by this type for the RaggedTensor.""" - return self._dtype - def _serialize(self): return (self._shape, self._dtype, self._ragged_rank, self._row_splits_dtype) diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index dd53b388bd4..bee85dc4a5b 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -70,7 +70,7 @@ def _maybe_copy_to_context_device(tensor, device_name): class EagerFunc(object): """A wrapper for a function owned by an EagerPyFunc.""" - def __init__(self, func, Tout, is_grad_func, use_tape_cache=True): + def __init__(self, func, Tout, is_grad_func): """Constructs an EagerFunc. Args: @@ -79,12 +79,10 @@ class EagerFunc(object): None. is_grad_func: Whether this EagerFunc is the gradient of another EagerPyFunc. - use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`. """ self._func = func self._out_dtypes = Tout self._is_grad_func = is_grad_func - self._use_tape_cache = use_tape_cache def _convert(self, value, dtype): """Converts `value` to a tensor of type `dtype`, with error checking. @@ -148,8 +146,7 @@ class EagerFunc(object): else: outputs = _maybe_copy_to_context_device( self._convert(ret, dtype=self._out_dtypes[0]), device_name) - if self._use_tape_cache: - tape_cache[compat.as_bytes(token)] = (tape, args, outputs) + tape_cache[compat.as_bytes(token)] = (tape, args, outputs) return outputs @@ -279,8 +276,7 @@ def _internal_py_func(func, stateful=None, eager=False, is_grad_func=False, - name=None, - use_tape_cache=True): + name=None): """See documentation for py_func and eager_py_func.""" if not callable(func): raise ValueError("Expected func to be callable, got func of type {}".format( @@ -296,7 +292,7 @@ def _internal_py_func(func, Tout = [Tout] if eager: - func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache) + func = EagerFunc(func, Tout, is_grad_func) # Tying the registered function's lifetime with the current default graph is # not reliable. For example, Estimator-based binaries may switch graphs in @@ -373,35 +369,6 @@ def _EagerPyFuncGrad(op, *dy): is_grad_func=True) -# NOTE(lithuak): this function as a layer of indirection was added with one -# specific purpose: as a workaround for github issue #35084. -# It does all the same as `eager_py_func` used to do with one difference: -# it can be used to instruct underlying EagerFunc not to use `tape_cache` -# to avoid memory leak. 
When the issue #35084 is fixed - this function should -# be removed, its body should be moved back to become the body of -# `eager_py_func` and all the call sites should be reverted to -# using `eager_py_func` without `use_tape_cache` argument of any value. -def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True): - """Wraps a python function into a TensorFlow op that executes it eagerly.""" - if ops.executing_eagerly_outside_functions(): - with ops.device(context.context().host_address_space()): - return _internal_py_func( - func=func, - inp=inp, - Tout=Tout, - eager=True, - name=name, - use_tape_cache=use_tape_cache) - - return _internal_py_func( - func=func, - inp=inp, - Tout=Tout, - eager=True, - name=name, - use_tape_cache=use_tape_cache) - - @tf_export("py_function") def eager_py_func(func, inp, Tout, name=None): """Wraps a python function into a TensorFlow op that executes it eagerly. @@ -482,8 +449,12 @@ def eager_py_func(func, inp, Tout, name=None): A list of `Tensor` or a single `Tensor` which `func` computes; an empty list if `func` returns None. """ - return _eager_py_func( - func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True) + if ops.executing_eagerly_outside_functions(): + with ops.device(context.context().host_address_space()): + return _internal_py_func( + func=func, inp=inp, Tout=Tout, eager=True, name=name) + + return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name) def py_func_common(func, inp, Tout, stateful=True, name=None): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt index 029d04fee9b..2ec5bb46ed1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt @@ -4,10 +4,6 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - member { - name: "dtype" - mtype: "" - } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 841b142c082..872d03770ed 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -63,7 +63,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index 42225d3f566..a84c5aa3caf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', 
\'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index 81a1c7fbd9c..a3862ae2a19 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index e9e3962a498..baaaf7ea7be 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt index 20712fb14a7..afdeea5d018 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt index c139c6b9cc8..76113c5e01d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt index 41a67db62dc..1a11026fd19 
100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt index 029d04fee9b..2ec5bb46ed1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt @@ -4,10 +4,6 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - member { - name: "dtype" - mtype: "" - } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index 3cb50feac2d..d9414c31e7d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -46,7 +46,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 9e2fa7255fd..28efdb6e855 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index 1bd43d28bc4..c9553efb58c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -47,7 +47,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 2e295c44b5f..16a878144ae 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index 91175909f77..d1d2db041e0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index 09ed74d3460..18a6b8cbd1b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index c245d563e9e..0cf3d94ba68 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" From d9d7b0cb78e3e1f763ecef62abffbf623c63d8c0 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Wed, 18 Mar 2020 15:20:29 -0700 Subject: [PATCH 177/492] Adds TFLRT dependency to `framework` target. 
PiperOrigin-RevId: 301680495 Change-Id: I1d10bb6f66a0718682fc7db007e963426c64785e --- tensorflow/lite/BUILD | 9 +++++++-- tensorflow/lite/experimental/tflite_api_dispatcher/BUILD | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 9c4740b8c0a..e6164c395e3 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test") -load("//tensorflow/lite:build_def.bzl", "tflite_cc_shared_object", "tflite_copts") +load("//tensorflow/lite:build_def.bzl", "if_tflite_experimental_runtime", "tflite_cc_shared_object", "tflite_copts", "tflite_experimental_runtime_linkopts") load("//tensorflow/lite/micro:build_def.bzl", "cc_library") load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") @@ -261,6 +261,11 @@ cc_library( ], hdrs = FRAMEWORK_LIB_HDRS, copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + defines = if_tflite_experimental_runtime( + if_eager = ["TFLITE_EXPERIMENTAL_RUNTIME_EAGER"], + if_non_eager = ["TFLITE_EXPERIMENTAL_RUNTIME_NON_EAGER"], + if_none = [], + ), deps = [ ":framework_lib", ":allocation", @@ -285,7 +290,7 @@ cc_library( "//tensorflow/lite/profiling:platform_profiler", ], "//conditions:default": [], - }), + }) + tflite_experimental_runtime_linkopts(), alwayslink = 1, ) diff --git a/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD b/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD index 4419f84a972..f4ade28eab8 100644 --- a/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD +++ b/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD @@ -14,7 +14,7 @@ cc_library( if_none = [], ), deps = [ - "//tensorflow/lite:framework", + "//tensorflow/lite:framework_lib", ] + tflite_experimental_runtime_linkopts(), ) @@ -23,7 +23,7 @@ cc_library( hdrs = ["tflite_api_dispatcher.h"], deps = [ ":tflite_api_dispatcher", - "//tensorflow/lite:framework", + "//tensorflow/lite:framework_lib", ] + tflite_experimental_runtime_linkopts( if_eager = [ # "//tensorflow/lite/experimental/tf_runtime/opdef:tflrt_opdefs", From c7fbc76d4a3f48ab3d243e412403f83e7bde9e15 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 15:31:22 -0700 Subject: [PATCH 178/492] Remove the trailing 'm' , since the tag for py38 is not cp38m but cp38. PiperOrigin-RevId: 301682546 Change-Id: I86e3632702ee77f82e3a59ba489a298930588ab9 --- .../release/windows/gpu_py38_full/release_pip_rename.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py38_full/release_pip_rename.sh b/tensorflow/tools/ci_build/release/windows/gpu_py38_full/release_pip_rename.sh index 039f9516d86..11744ea734d 100644 --- a/tensorflow/tools/ci_build/release/windows/gpu_py38_full/release_pip_rename.sh +++ b/tensorflow/tools/ci_build/release/windows/gpu_py38_full/release_pip_rename.sh @@ -19,6 +19,6 @@ set -x source tensorflow/tools/ci_build/release/common.sh # Copy and rename to tensorflow -for f in $(ls py_test_dir/tensorflow-*cp3*-cp3*m-win_amd64.whl); do +for f in $(ls py_test_dir/tensorflow-*cp3*-cp3*-win_amd64.whl); do copy_to_new_project_name "${f}" tensorflow_gpu done From 7034c2cc0a9122d508c494d7738ed486a95fb140 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 16:07:02 -0700 Subject: [PATCH 179/492] Go: Update generated wrapper functions for TensorFlow ops. 
PiperOrigin-RevId: 301689315 Change-Id: I80babdfe5d0e9d5c70ff6f25e3baf1964a380796 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 7be0c66548c..3d05bb08fa3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11815,7 +11815,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12072,7 +12072,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12083,7 +12083,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12301,7 +12301,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12312,7 +12312,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19153,7 +19153,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20224,7 +20224,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21396,7 +21396,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22104,7 +22104,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22300,7 +22300,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22484,7 +22484,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22543,7 +22543,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22717,7 +22717,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23098,7 +23098,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25441,7 +25441,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25504,7 +25504,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25747,7 +25747,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26370,7 +26370,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45499,7 +45499,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46287,7 +46287,7 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46350,7 +46350,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From c30d6a30d5dbf46526712f7108677926fcfd98f1 Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Wed, 18 Mar 2020 16:13:50 -0700 Subject: [PATCH 180/492] Update TF Lite schema to correspond to Flatbuffers update. PiperOrigin-RevId: 301690401 Change-Id: Ifcf08e053b26e4bf2b704896655efae5740669a5 --- tensorflow/lite/schema/schema_generated.h | 3730 +++++++++++---------- 1 file changed, 1868 insertions(+), 1862 deletions(-) diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index d884d7f865b..8caf2409b96 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -396,7 +396,7 @@ inline const TensorType (&EnumValuesTensorType())[10] { } inline const char * const *EnumNamesTensorType() { - static const char * const names[] = { + static const char * const names[11] = { "FLOAT32", "FLOAT16", "INT32", @@ -413,7 +413,7 @@ inline const char * const *EnumNamesTensorType() { } inline const char *EnumNameTensorType(TensorType e) { - if (e < TensorType_FLOAT32 || e > TensorType_INT8) return ""; + if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT8)) return ""; const size_t index = static_cast(e); return EnumNamesTensorType()[index]; } @@ -434,7 +434,7 @@ inline const QuantizationDetails (&EnumValuesQuantizationDetails())[2] { } inline const char * const *EnumNamesQuantizationDetails() { - static const char * const names[] = { + static const char * const names[3] = { "NONE", "CustomQuantization", nullptr @@ -443,7 +443,7 @@ inline const char * const *EnumNamesQuantizationDetails() { } inline const char *EnumNameQuantizationDetails(QuantizationDetails e) { - if (e < QuantizationDetails_NONE || e > QuantizationDetails_CustomQuantization) return ""; + if (flatbuffers::IsOutRange(e, QuantizationDetails_NONE, QuantizationDetails_CustomQuantization)) return ""; const size_t index = static_cast(e); return EnumNamesQuantizationDetails()[index]; } @@ -452,7 +452,7 @@ template struct QuantizationDetailsTraits { static const QuantizationDetails enum_value = QuantizationDetails_NONE; }; -template<> struct QuantizationDetailsTraits { +template<> struct QuantizationDetailsTraits { static const QuantizationDetails enum_value = QuantizationDetails_CustomQuantization; }; @@ -488,13 +488,13 @@ struct 
QuantizationDetailsUnion { static void *UnPack(const void *obj, QuantizationDetails type, const flatbuffers::resolver_function_t *resolver); flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; - CustomQuantizationT *AsCustomQuantization() { + tflite::CustomQuantizationT *AsCustomQuantization() { return type == QuantizationDetails_CustomQuantization ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const CustomQuantizationT *AsCustomQuantization() const { + const tflite::CustomQuantizationT *AsCustomQuantization() const { return type == QuantizationDetails_CustomQuantization ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } }; @@ -517,7 +517,7 @@ inline const DimensionType (&EnumValuesDimensionType())[2] { } inline const char * const *EnumNamesDimensionType() { - static const char * const names[] = { + static const char * const names[3] = { "DENSE", "SPARSE_CSR", nullptr @@ -526,7 +526,7 @@ inline const char * const *EnumNamesDimensionType() { } inline const char *EnumNameDimensionType(DimensionType e) { - if (e < DimensionType_DENSE || e > DimensionType_SPARSE_CSR) return ""; + if (flatbuffers::IsOutRange(e, DimensionType_DENSE, DimensionType_SPARSE_CSR)) return ""; const size_t index = static_cast(e); return EnumNamesDimensionType()[index]; } @@ -551,7 +551,7 @@ inline const SparseIndexVector (&EnumValuesSparseIndexVector())[4] { } inline const char * const *EnumNamesSparseIndexVector() { - static const char * const names[] = { + static const char * const names[5] = { "NONE", "Int32Vector", "Uint16Vector", @@ -562,7 +562,7 @@ inline const char * const *EnumNamesSparseIndexVector() { } inline const char *EnumNameSparseIndexVector(SparseIndexVector e) { - if (e < SparseIndexVector_NONE || e > SparseIndexVector_Uint8Vector) return ""; + if (flatbuffers::IsOutRange(e, SparseIndexVector_NONE, SparseIndexVector_Uint8Vector)) return ""; const size_t index = static_cast(e); return EnumNamesSparseIndexVector()[index]; } @@ -571,15 +571,15 @@ template struct SparseIndexVectorTraits { static const SparseIndexVector enum_value = SparseIndexVector_NONE; }; -template<> struct SparseIndexVectorTraits { +template<> struct SparseIndexVectorTraits { static const SparseIndexVector enum_value = SparseIndexVector_Int32Vector; }; -template<> struct SparseIndexVectorTraits { +template<> struct SparseIndexVectorTraits { static const SparseIndexVector enum_value = SparseIndexVector_Uint16Vector; }; -template<> struct SparseIndexVectorTraits { +template<> struct SparseIndexVectorTraits { static const SparseIndexVector enum_value = SparseIndexVector_Uint8Vector; }; @@ -615,29 +615,29 @@ struct SparseIndexVectorUnion { static void *UnPack(const void *obj, SparseIndexVector type, const flatbuffers::resolver_function_t *resolver); flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; - Int32VectorT *AsInt32Vector() { + tflite::Int32VectorT *AsInt32Vector() { return type == SparseIndexVector_Int32Vector ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const Int32VectorT *AsInt32Vector() const { + const tflite::Int32VectorT *AsInt32Vector() const { return type == SparseIndexVector_Int32Vector ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - Uint16VectorT *AsUint16Vector() { + tflite::Uint16VectorT *AsUint16Vector() { return type == SparseIndexVector_Uint16Vector ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const Uint16VectorT *AsUint16Vector() const { + const tflite::Uint16VectorT *AsUint16Vector() const { return type == SparseIndexVector_Uint16Vector ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - Uint8VectorT *AsUint8Vector() { + tflite::Uint8VectorT *AsUint8Vector() { return type == SparseIndexVector_Uint8Vector ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const Uint8VectorT *AsUint8Vector() const { + const tflite::Uint8VectorT *AsUint8Vector() const { return type == SparseIndexVector_Uint8Vector ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } }; @@ -908,7 +908,7 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] { } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[] = { + static const char * const names[127] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1041,7 +1041,7 @@ inline const char * const *EnumNamesBuiltinOperator() { } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (e < BuiltinOperator_ADD || e > BuiltinOperator_SEGMENT_SUM) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_SEGMENT_SUM)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -1260,7 +1260,7 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] { } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[] = { + static const char * const names[102] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1368,7 +1368,7 @@ inline const char * const *EnumNamesBuiltinOptions() { } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (e < BuiltinOptions_NONE || e > BuiltinOptions_SegmentSumOptions) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_SegmentSumOptions)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; } @@ -1377,403 +1377,403 @@ template struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NONE; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DepthwiseConv2DOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ConcatEmbeddingsOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions; }; -template<> 
struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_AddOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LocalResponseNormalizationOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_CallOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_EmbeddingLookupSparseOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MulOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_PadOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SubOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DivOptions; }; -template<> struct 
BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ExpOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TopKV2Options; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SplitOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LogSoftmaxOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_CastOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DequantizeOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MaximumMinimumOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LessOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NegOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_PadV2Options; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SliceOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TileOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct 
BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_PackOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ZerosLikeOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_FillOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceLSTMOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_FloorModOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_RangeOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct 
BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ResizeNearestNeighborOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LeakyReluOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SquaredDifferenceOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MirrorPadOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_AbsOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SplitVOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_AddNOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_CosOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_RankOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_IfOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_WhileOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DepthToSpaceOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV4Options; }; -template<> struct BuiltinOptionsTraits { +template<> struct 
BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ScatterNdOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SelectV2Options; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DensifyOptions; }; -template<> struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions; }; @@ -1809,805 +1809,805 @@ struct BuiltinOptionsUnion { static void *UnPack(const void *obj, BuiltinOptions type, const flatbuffers::resolver_function_t *resolver); flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; - Conv2DOptionsT *AsConv2DOptions() { + tflite::Conv2DOptionsT *AsConv2DOptions() { return type == BuiltinOptions_Conv2DOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const Conv2DOptionsT *AsConv2DOptions() const { + const tflite::Conv2DOptionsT *AsConv2DOptions() const { return type == BuiltinOptions_Conv2DOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() { + tflite::DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() { return type == BuiltinOptions_DepthwiseConv2DOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() const { + const tflite::DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() const { return type == BuiltinOptions_DepthwiseConv2DOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() { + tflite::ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() { return type == BuiltinOptions_ConcatEmbeddingsOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() const { + const tflite::ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() const { return type == BuiltinOptions_ConcatEmbeddingsOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LSHProjectionOptionsT *AsLSHProjectionOptions() { + tflite::LSHProjectionOptionsT *AsLSHProjectionOptions() { return type == BuiltinOptions_LSHProjectionOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LSHProjectionOptionsT *AsLSHProjectionOptions() const { + const tflite::LSHProjectionOptionsT *AsLSHProjectionOptions() const { return type == BuiltinOptions_LSHProjectionOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - Pool2DOptionsT *AsPool2DOptions() { + tflite::Pool2DOptionsT *AsPool2DOptions() { return type == BuiltinOptions_Pool2DOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const Pool2DOptionsT *AsPool2DOptions() const { + const tflite::Pool2DOptionsT *AsPool2DOptions() const { return type == BuiltinOptions_Pool2DOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SVDFOptionsT *AsSVDFOptions() { + tflite::SVDFOptionsT *AsSVDFOptions() { return type == BuiltinOptions_SVDFOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SVDFOptionsT *AsSVDFOptions() const { + const tflite::SVDFOptionsT *AsSVDFOptions() const { return type == BuiltinOptions_SVDFOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - RNNOptionsT *AsRNNOptions() { + tflite::RNNOptionsT *AsRNNOptions() { return type == BuiltinOptions_RNNOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const RNNOptionsT *AsRNNOptions() const { + const tflite::RNNOptionsT *AsRNNOptions() const { return type == BuiltinOptions_RNNOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - FullyConnectedOptionsT *AsFullyConnectedOptions() { + tflite::FullyConnectedOptionsT *AsFullyConnectedOptions() { return type == BuiltinOptions_FullyConnectedOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const FullyConnectedOptionsT *AsFullyConnectedOptions() const { + const tflite::FullyConnectedOptionsT *AsFullyConnectedOptions() const { return type == BuiltinOptions_FullyConnectedOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SoftmaxOptionsT *AsSoftmaxOptions() { + tflite::SoftmaxOptionsT *AsSoftmaxOptions() { return type == BuiltinOptions_SoftmaxOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SoftmaxOptionsT *AsSoftmaxOptions() const { + const tflite::SoftmaxOptionsT *AsSoftmaxOptions() const { return type == BuiltinOptions_SoftmaxOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ConcatenationOptionsT *AsConcatenationOptions() { + tflite::ConcatenationOptionsT *AsConcatenationOptions() { return type == BuiltinOptions_ConcatenationOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ConcatenationOptionsT *AsConcatenationOptions() const { + const tflite::ConcatenationOptionsT *AsConcatenationOptions() const { return type == BuiltinOptions_ConcatenationOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - AddOptionsT *AsAddOptions() { + tflite::AddOptionsT *AsAddOptions() { return type == BuiltinOptions_AddOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const AddOptionsT *AsAddOptions() const { + const tflite::AddOptionsT *AsAddOptions() const { return type == BuiltinOptions_AddOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - L2NormOptionsT *AsL2NormOptions() { + tflite::L2NormOptionsT *AsL2NormOptions() { return type == BuiltinOptions_L2NormOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const L2NormOptionsT *AsL2NormOptions() const { + const tflite::L2NormOptionsT *AsL2NormOptions() const { return type == BuiltinOptions_L2NormOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() { + tflite::LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() { return type == BuiltinOptions_LocalResponseNormalizationOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() const { + const tflite::LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() const { return type == BuiltinOptions_LocalResponseNormalizationOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LSTMOptionsT *AsLSTMOptions() { + tflite::LSTMOptionsT *AsLSTMOptions() { return type == BuiltinOptions_LSTMOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LSTMOptionsT *AsLSTMOptions() const { + const tflite::LSTMOptionsT *AsLSTMOptions() const { return type == BuiltinOptions_LSTMOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ResizeBilinearOptionsT *AsResizeBilinearOptions() { + tflite::ResizeBilinearOptionsT *AsResizeBilinearOptions() { return type == BuiltinOptions_ResizeBilinearOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ResizeBilinearOptionsT *AsResizeBilinearOptions() const { + const tflite::ResizeBilinearOptionsT *AsResizeBilinearOptions() const { return type == BuiltinOptions_ResizeBilinearOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - CallOptionsT *AsCallOptions() { + tflite::CallOptionsT *AsCallOptions() { return type == BuiltinOptions_CallOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const CallOptionsT *AsCallOptions() const { + const tflite::CallOptionsT *AsCallOptions() const { return type == BuiltinOptions_CallOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ReshapeOptionsT *AsReshapeOptions() { + tflite::ReshapeOptionsT *AsReshapeOptions() { return type == BuiltinOptions_ReshapeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ReshapeOptionsT *AsReshapeOptions() const { + const tflite::ReshapeOptionsT *AsReshapeOptions() const { return type == BuiltinOptions_ReshapeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SkipGramOptionsT *AsSkipGramOptions() { + tflite::SkipGramOptionsT *AsSkipGramOptions() { return type == BuiltinOptions_SkipGramOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SkipGramOptionsT *AsSkipGramOptions() const { + const tflite::SkipGramOptionsT *AsSkipGramOptions() const { return type == BuiltinOptions_SkipGramOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SpaceToDepthOptionsT *AsSpaceToDepthOptions() { + tflite::SpaceToDepthOptionsT *AsSpaceToDepthOptions() { return type == BuiltinOptions_SpaceToDepthOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SpaceToDepthOptionsT *AsSpaceToDepthOptions() const { + const tflite::SpaceToDepthOptionsT *AsSpaceToDepthOptions() const { return type == BuiltinOptions_SpaceToDepthOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() { + tflite::EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() { return type == BuiltinOptions_EmbeddingLookupSparseOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() const { + const tflite::EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() const { return type == BuiltinOptions_EmbeddingLookupSparseOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - MulOptionsT *AsMulOptions() { + tflite::MulOptionsT *AsMulOptions() { return type == BuiltinOptions_MulOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const MulOptionsT *AsMulOptions() const { + const tflite::MulOptionsT *AsMulOptions() const { return type == BuiltinOptions_MulOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - PadOptionsT *AsPadOptions() { + tflite::PadOptionsT *AsPadOptions() { return type == BuiltinOptions_PadOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const PadOptionsT *AsPadOptions() const { + const tflite::PadOptionsT *AsPadOptions() const { return type == BuiltinOptions_PadOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - GatherOptionsT *AsGatherOptions() { + tflite::GatherOptionsT *AsGatherOptions() { return type == BuiltinOptions_GatherOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const GatherOptionsT *AsGatherOptions() const { + const tflite::GatherOptionsT *AsGatherOptions() const { return type == BuiltinOptions_GatherOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() { + tflite::BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() { return type == BuiltinOptions_BatchToSpaceNDOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() const { + const tflite::BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() const { return type == BuiltinOptions_BatchToSpaceNDOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() { + tflite::SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() { return type == BuiltinOptions_SpaceToBatchNDOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() const { + const tflite::SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() const { return type == BuiltinOptions_SpaceToBatchNDOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - TransposeOptionsT *AsTransposeOptions() { + tflite::TransposeOptionsT *AsTransposeOptions() { return type == BuiltinOptions_TransposeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const TransposeOptionsT *AsTransposeOptions() const { + const tflite::TransposeOptionsT *AsTransposeOptions() const { return type == BuiltinOptions_TransposeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ReducerOptionsT *AsReducerOptions() { + tflite::ReducerOptionsT *AsReducerOptions() { return type == BuiltinOptions_ReducerOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ReducerOptionsT *AsReducerOptions() const { + const tflite::ReducerOptionsT *AsReducerOptions() const { return type == BuiltinOptions_ReducerOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SubOptionsT *AsSubOptions() { + tflite::SubOptionsT *AsSubOptions() { return type == BuiltinOptions_SubOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SubOptionsT *AsSubOptions() const { + const tflite::SubOptionsT *AsSubOptions() const { return type == BuiltinOptions_SubOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - DivOptionsT *AsDivOptions() { + tflite::DivOptionsT *AsDivOptions() { return type == BuiltinOptions_DivOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const DivOptionsT *AsDivOptions() const { + const tflite::DivOptionsT *AsDivOptions() const { return type == BuiltinOptions_DivOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SqueezeOptionsT *AsSqueezeOptions() { + tflite::SqueezeOptionsT *AsSqueezeOptions() { return type == BuiltinOptions_SqueezeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SqueezeOptionsT *AsSqueezeOptions() const { + const tflite::SqueezeOptionsT *AsSqueezeOptions() const { return type == BuiltinOptions_SqueezeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SequenceRNNOptionsT *AsSequenceRNNOptions() { + tflite::SequenceRNNOptionsT *AsSequenceRNNOptions() { return type == BuiltinOptions_SequenceRNNOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SequenceRNNOptionsT *AsSequenceRNNOptions() const { + const tflite::SequenceRNNOptionsT *AsSequenceRNNOptions() const { return type == BuiltinOptions_SequenceRNNOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - StridedSliceOptionsT *AsStridedSliceOptions() { + tflite::StridedSliceOptionsT *AsStridedSliceOptions() { return type == BuiltinOptions_StridedSliceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const StridedSliceOptionsT *AsStridedSliceOptions() const { + const tflite::StridedSliceOptionsT *AsStridedSliceOptions() const { return type == BuiltinOptions_StridedSliceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ExpOptionsT *AsExpOptions() { + tflite::ExpOptionsT *AsExpOptions() { return type == BuiltinOptions_ExpOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ExpOptionsT *AsExpOptions() const { + const tflite::ExpOptionsT *AsExpOptions() const { return type == BuiltinOptions_ExpOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - TopKV2OptionsT *AsTopKV2Options() { + tflite::TopKV2OptionsT *AsTopKV2Options() { return type == BuiltinOptions_TopKV2Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const TopKV2OptionsT *AsTopKV2Options() const { + const tflite::TopKV2OptionsT *AsTopKV2Options() const { return type == BuiltinOptions_TopKV2Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SplitOptionsT *AsSplitOptions() { + tflite::SplitOptionsT *AsSplitOptions() { return type == BuiltinOptions_SplitOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SplitOptionsT *AsSplitOptions() const { + const tflite::SplitOptionsT *AsSplitOptions() const { return type == BuiltinOptions_SplitOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LogSoftmaxOptionsT *AsLogSoftmaxOptions() { + tflite::LogSoftmaxOptionsT *AsLogSoftmaxOptions() { return type == BuiltinOptions_LogSoftmaxOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LogSoftmaxOptionsT *AsLogSoftmaxOptions() const { + const tflite::LogSoftmaxOptionsT *AsLogSoftmaxOptions() const { return type == BuiltinOptions_LogSoftmaxOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - CastOptionsT *AsCastOptions() { + tflite::CastOptionsT *AsCastOptions() { return type == BuiltinOptions_CastOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const CastOptionsT *AsCastOptions() const { + const tflite::CastOptionsT *AsCastOptions() const { return type == BuiltinOptions_CastOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - DequantizeOptionsT *AsDequantizeOptions() { + tflite::DequantizeOptionsT *AsDequantizeOptions() { return type == BuiltinOptions_DequantizeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const DequantizeOptionsT *AsDequantizeOptions() const { + const tflite::DequantizeOptionsT *AsDequantizeOptions() const { return type == BuiltinOptions_DequantizeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - MaximumMinimumOptionsT *AsMaximumMinimumOptions() { + tflite::MaximumMinimumOptionsT *AsMaximumMinimumOptions() { return type == BuiltinOptions_MaximumMinimumOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const MaximumMinimumOptionsT *AsMaximumMinimumOptions() const { + const tflite::MaximumMinimumOptionsT *AsMaximumMinimumOptions() const { return type == BuiltinOptions_MaximumMinimumOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ArgMaxOptionsT *AsArgMaxOptions() { + tflite::ArgMaxOptionsT *AsArgMaxOptions() { return type == BuiltinOptions_ArgMaxOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ArgMaxOptionsT *AsArgMaxOptions() const { + const tflite::ArgMaxOptionsT *AsArgMaxOptions() const { return type == BuiltinOptions_ArgMaxOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LessOptionsT *AsLessOptions() { + tflite::LessOptionsT *AsLessOptions() { return type == BuiltinOptions_LessOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LessOptionsT *AsLessOptions() const { + const tflite::LessOptionsT *AsLessOptions() const { return type == BuiltinOptions_LessOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - NegOptionsT *AsNegOptions() { + tflite::NegOptionsT *AsNegOptions() { return type == BuiltinOptions_NegOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const NegOptionsT *AsNegOptions() const { + const tflite::NegOptionsT *AsNegOptions() const { return type == BuiltinOptions_NegOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - PadV2OptionsT *AsPadV2Options() { + tflite::PadV2OptionsT *AsPadV2Options() { return type == BuiltinOptions_PadV2Options ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const PadV2OptionsT *AsPadV2Options() const { + const tflite::PadV2OptionsT *AsPadV2Options() const { return type == BuiltinOptions_PadV2Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - GreaterOptionsT *AsGreaterOptions() { + tflite::GreaterOptionsT *AsGreaterOptions() { return type == BuiltinOptions_GreaterOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const GreaterOptionsT *AsGreaterOptions() const { + const tflite::GreaterOptionsT *AsGreaterOptions() const { return type == BuiltinOptions_GreaterOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - GreaterEqualOptionsT *AsGreaterEqualOptions() { + tflite::GreaterEqualOptionsT *AsGreaterEqualOptions() { return type == BuiltinOptions_GreaterEqualOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const GreaterEqualOptionsT *AsGreaterEqualOptions() const { + const tflite::GreaterEqualOptionsT *AsGreaterEqualOptions() const { return type == BuiltinOptions_GreaterEqualOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LessEqualOptionsT *AsLessEqualOptions() { + tflite::LessEqualOptionsT *AsLessEqualOptions() { return type == BuiltinOptions_LessEqualOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LessEqualOptionsT *AsLessEqualOptions() const { + const tflite::LessEqualOptionsT *AsLessEqualOptions() const { return type == BuiltinOptions_LessEqualOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SelectOptionsT *AsSelectOptions() { + tflite::SelectOptionsT *AsSelectOptions() { return type == BuiltinOptions_SelectOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SelectOptionsT *AsSelectOptions() const { + const tflite::SelectOptionsT *AsSelectOptions() const { return type == BuiltinOptions_SelectOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SliceOptionsT *AsSliceOptions() { + tflite::SliceOptionsT *AsSliceOptions() { return type == BuiltinOptions_SliceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SliceOptionsT *AsSliceOptions() const { + const tflite::SliceOptionsT *AsSliceOptions() const { return type == BuiltinOptions_SliceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - TransposeConvOptionsT *AsTransposeConvOptions() { + tflite::TransposeConvOptionsT *AsTransposeConvOptions() { return type == BuiltinOptions_TransposeConvOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const TransposeConvOptionsT *AsTransposeConvOptions() const { + const tflite::TransposeConvOptionsT *AsTransposeConvOptions() const { return type == BuiltinOptions_TransposeConvOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SparseToDenseOptionsT *AsSparseToDenseOptions() { + tflite::SparseToDenseOptionsT *AsSparseToDenseOptions() { return type == BuiltinOptions_SparseToDenseOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SparseToDenseOptionsT *AsSparseToDenseOptions() const { + const tflite::SparseToDenseOptionsT *AsSparseToDenseOptions() const { return type == BuiltinOptions_SparseToDenseOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - TileOptionsT *AsTileOptions() { + tflite::TileOptionsT *AsTileOptions() { return type == BuiltinOptions_TileOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const TileOptionsT *AsTileOptions() const { + const tflite::TileOptionsT *AsTileOptions() const { return type == BuiltinOptions_TileOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ExpandDimsOptionsT *AsExpandDimsOptions() { + tflite::ExpandDimsOptionsT *AsExpandDimsOptions() { return type == BuiltinOptions_ExpandDimsOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ExpandDimsOptionsT *AsExpandDimsOptions() const { + const tflite::ExpandDimsOptionsT *AsExpandDimsOptions() const { return type == BuiltinOptions_ExpandDimsOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - EqualOptionsT *AsEqualOptions() { + tflite::EqualOptionsT *AsEqualOptions() { return type == BuiltinOptions_EqualOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const EqualOptionsT *AsEqualOptions() const { + const tflite::EqualOptionsT *AsEqualOptions() const { return type == BuiltinOptions_EqualOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - NotEqualOptionsT *AsNotEqualOptions() { + tflite::NotEqualOptionsT *AsNotEqualOptions() { return type == BuiltinOptions_NotEqualOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const NotEqualOptionsT *AsNotEqualOptions() const { + const tflite::NotEqualOptionsT *AsNotEqualOptions() const { return type == BuiltinOptions_NotEqualOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ShapeOptionsT *AsShapeOptions() { + tflite::ShapeOptionsT *AsShapeOptions() { return type == BuiltinOptions_ShapeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ShapeOptionsT *AsShapeOptions() const { + const tflite::ShapeOptionsT *AsShapeOptions() const { return type == BuiltinOptions_ShapeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - PowOptionsT *AsPowOptions() { + tflite::PowOptionsT *AsPowOptions() { return type == BuiltinOptions_PowOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const PowOptionsT *AsPowOptions() const { + const tflite::PowOptionsT *AsPowOptions() const { return type == BuiltinOptions_PowOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ArgMinOptionsT *AsArgMinOptions() { + tflite::ArgMinOptionsT *AsArgMinOptions() { return type == BuiltinOptions_ArgMinOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ArgMinOptionsT *AsArgMinOptions() const { + const tflite::ArgMinOptionsT *AsArgMinOptions() const { return type == BuiltinOptions_ArgMinOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - FakeQuantOptionsT *AsFakeQuantOptions() { + tflite::FakeQuantOptionsT *AsFakeQuantOptions() { return type == BuiltinOptions_FakeQuantOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const FakeQuantOptionsT *AsFakeQuantOptions() const { + const tflite::FakeQuantOptionsT *AsFakeQuantOptions() const { return type == BuiltinOptions_FakeQuantOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - PackOptionsT *AsPackOptions() { + tflite::PackOptionsT *AsPackOptions() { return type == BuiltinOptions_PackOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const PackOptionsT *AsPackOptions() const { + const tflite::PackOptionsT *AsPackOptions() const { return type == BuiltinOptions_PackOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LogicalOrOptionsT *AsLogicalOrOptions() { + tflite::LogicalOrOptionsT *AsLogicalOrOptions() { return type == BuiltinOptions_LogicalOrOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LogicalOrOptionsT *AsLogicalOrOptions() const { + const tflite::LogicalOrOptionsT *AsLogicalOrOptions() const { return type == BuiltinOptions_LogicalOrOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - OneHotOptionsT *AsOneHotOptions() { + tflite::OneHotOptionsT *AsOneHotOptions() { return type == BuiltinOptions_OneHotOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const OneHotOptionsT *AsOneHotOptions() const { + const tflite::OneHotOptionsT *AsOneHotOptions() const { return type == BuiltinOptions_OneHotOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LogicalAndOptionsT *AsLogicalAndOptions() { + tflite::LogicalAndOptionsT *AsLogicalAndOptions() { return type == BuiltinOptions_LogicalAndOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LogicalAndOptionsT *AsLogicalAndOptions() const { + const tflite::LogicalAndOptionsT *AsLogicalAndOptions() const { return type == BuiltinOptions_LogicalAndOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LogicalNotOptionsT *AsLogicalNotOptions() { + tflite::LogicalNotOptionsT *AsLogicalNotOptions() { return type == BuiltinOptions_LogicalNotOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LogicalNotOptionsT *AsLogicalNotOptions() const { + const tflite::LogicalNotOptionsT *AsLogicalNotOptions() const { return type == BuiltinOptions_LogicalNotOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - UnpackOptionsT *AsUnpackOptions() { + tflite::UnpackOptionsT *AsUnpackOptions() { return type == BuiltinOptions_UnpackOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const UnpackOptionsT *AsUnpackOptions() const { + const tflite::UnpackOptionsT *AsUnpackOptions() const { return type == BuiltinOptions_UnpackOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - FloorDivOptionsT *AsFloorDivOptions() { + tflite::FloorDivOptionsT *AsFloorDivOptions() { return type == BuiltinOptions_FloorDivOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const FloorDivOptionsT *AsFloorDivOptions() const { + const tflite::FloorDivOptionsT *AsFloorDivOptions() const { return type == BuiltinOptions_FloorDivOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SquareOptionsT *AsSquareOptions() { + tflite::SquareOptionsT *AsSquareOptions() { return type == BuiltinOptions_SquareOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SquareOptionsT *AsSquareOptions() const { + const tflite::SquareOptionsT *AsSquareOptions() const { return type == BuiltinOptions_SquareOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ZerosLikeOptionsT *AsZerosLikeOptions() { + tflite::ZerosLikeOptionsT *AsZerosLikeOptions() { return type == BuiltinOptions_ZerosLikeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ZerosLikeOptionsT *AsZerosLikeOptions() const { + const tflite::ZerosLikeOptionsT *AsZerosLikeOptions() const { return type == BuiltinOptions_ZerosLikeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - FillOptionsT *AsFillOptions() { + tflite::FillOptionsT *AsFillOptions() { return type == BuiltinOptions_FillOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const FillOptionsT *AsFillOptions() const { + const tflite::FillOptionsT *AsFillOptions() const { return type == BuiltinOptions_FillOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() { + tflite::BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() { return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() const { + const tflite::BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() const { return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() { + tflite::BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() { return type == BuiltinOptions_BidirectionalSequenceRNNOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() const { + const tflite::BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() const { return type == BuiltinOptions_BidirectionalSequenceRNNOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() { + tflite::UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() { return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() const { + const tflite::UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() const { return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - FloorModOptionsT *AsFloorModOptions() { + tflite::FloorModOptionsT *AsFloorModOptions() { return type == BuiltinOptions_FloorModOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const FloorModOptionsT *AsFloorModOptions() const { + const tflite::FloorModOptionsT *AsFloorModOptions() const { return type == BuiltinOptions_FloorModOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - RangeOptionsT *AsRangeOptions() { + tflite::RangeOptionsT *AsRangeOptions() { return type == BuiltinOptions_RangeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const RangeOptionsT *AsRangeOptions() const { + const tflite::RangeOptionsT *AsRangeOptions() const { return type == BuiltinOptions_RangeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ResizeNearestNeighborOptionsT *AsResizeNearestNeighborOptions() { + tflite::ResizeNearestNeighborOptionsT *AsResizeNearestNeighborOptions() { return type == BuiltinOptions_ResizeNearestNeighborOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ResizeNearestNeighborOptionsT *AsResizeNearestNeighborOptions() const { + const tflite::ResizeNearestNeighborOptionsT *AsResizeNearestNeighborOptions() const { return type == BuiltinOptions_ResizeNearestNeighborOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - LeakyReluOptionsT *AsLeakyReluOptions() { + tflite::LeakyReluOptionsT *AsLeakyReluOptions() { return type == BuiltinOptions_LeakyReluOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const LeakyReluOptionsT *AsLeakyReluOptions() const { + const tflite::LeakyReluOptionsT *AsLeakyReluOptions() const { return type == BuiltinOptions_LeakyReluOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SquaredDifferenceOptionsT *AsSquaredDifferenceOptions() { + tflite::SquaredDifferenceOptionsT *AsSquaredDifferenceOptions() { return type == BuiltinOptions_SquaredDifferenceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SquaredDifferenceOptionsT *AsSquaredDifferenceOptions() const { + const tflite::SquaredDifferenceOptionsT *AsSquaredDifferenceOptions() const { return type == BuiltinOptions_SquaredDifferenceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - MirrorPadOptionsT *AsMirrorPadOptions() { + tflite::MirrorPadOptionsT *AsMirrorPadOptions() { return type == BuiltinOptions_MirrorPadOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const MirrorPadOptionsT *AsMirrorPadOptions() const { + const tflite::MirrorPadOptionsT *AsMirrorPadOptions() const { return type == BuiltinOptions_MirrorPadOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - AbsOptionsT *AsAbsOptions() { + tflite::AbsOptionsT *AsAbsOptions() { return type == BuiltinOptions_AbsOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const AbsOptionsT *AsAbsOptions() const { + const tflite::AbsOptionsT *AsAbsOptions() const { return type == BuiltinOptions_AbsOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SplitVOptionsT *AsSplitVOptions() { + tflite::SplitVOptionsT *AsSplitVOptions() { return type == BuiltinOptions_SplitVOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SplitVOptionsT *AsSplitVOptions() const { + const tflite::SplitVOptionsT *AsSplitVOptions() const { return type == BuiltinOptions_SplitVOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - UniqueOptionsT *AsUniqueOptions() { + tflite::UniqueOptionsT *AsUniqueOptions() { return type == BuiltinOptions_UniqueOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const UniqueOptionsT *AsUniqueOptions() const { + const tflite::UniqueOptionsT *AsUniqueOptions() const { return type == BuiltinOptions_UniqueOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ReverseV2OptionsT *AsReverseV2Options() { + tflite::ReverseV2OptionsT *AsReverseV2Options() { return type == BuiltinOptions_ReverseV2Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ReverseV2OptionsT *AsReverseV2Options() const { + const tflite::ReverseV2OptionsT *AsReverseV2Options() const { return type == BuiltinOptions_ReverseV2Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - AddNOptionsT *AsAddNOptions() { + tflite::AddNOptionsT *AsAddNOptions() { return type == BuiltinOptions_AddNOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const AddNOptionsT *AsAddNOptions() const { + const tflite::AddNOptionsT *AsAddNOptions() const { return type == BuiltinOptions_AddNOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - GatherNdOptionsT *AsGatherNdOptions() { + tflite::GatherNdOptionsT *AsGatherNdOptions() { return type == BuiltinOptions_GatherNdOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const GatherNdOptionsT *AsGatherNdOptions() const { + const tflite::GatherNdOptionsT *AsGatherNdOptions() const { return type == BuiltinOptions_GatherNdOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - CosOptionsT *AsCosOptions() { + tflite::CosOptionsT *AsCosOptions() { return type == BuiltinOptions_CosOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const CosOptionsT *AsCosOptions() const { + const tflite::CosOptionsT *AsCosOptions() const { return type == BuiltinOptions_CosOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - WhereOptionsT *AsWhereOptions() { + tflite::WhereOptionsT *AsWhereOptions() { return type == BuiltinOptions_WhereOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const WhereOptionsT *AsWhereOptions() const { + const tflite::WhereOptionsT *AsWhereOptions() const { return type == BuiltinOptions_WhereOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - RankOptionsT *AsRankOptions() { + tflite::RankOptionsT *AsRankOptions() { return type == BuiltinOptions_RankOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const RankOptionsT *AsRankOptions() const { + const tflite::RankOptionsT *AsRankOptions() const { return type == BuiltinOptions_RankOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ReverseSequenceOptionsT *AsReverseSequenceOptions() { + tflite::ReverseSequenceOptionsT *AsReverseSequenceOptions() { return type == BuiltinOptions_ReverseSequenceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ReverseSequenceOptionsT *AsReverseSequenceOptions() const { + const tflite::ReverseSequenceOptionsT *AsReverseSequenceOptions() const { return type == BuiltinOptions_ReverseSequenceOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - MatrixDiagOptionsT *AsMatrixDiagOptions() { + tflite::MatrixDiagOptionsT *AsMatrixDiagOptions() { return type == BuiltinOptions_MatrixDiagOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const MatrixDiagOptionsT *AsMatrixDiagOptions() const { + const tflite::MatrixDiagOptionsT *AsMatrixDiagOptions() const { return type == BuiltinOptions_MatrixDiagOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - QuantizeOptionsT *AsQuantizeOptions() { + tflite::QuantizeOptionsT *AsQuantizeOptions() { return type == BuiltinOptions_QuantizeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const QuantizeOptionsT *AsQuantizeOptions() const { + const tflite::QuantizeOptionsT *AsQuantizeOptions() const { return type == BuiltinOptions_QuantizeOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() { + tflite::MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() { return type == BuiltinOptions_MatrixSetDiagOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() const { + const tflite::MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() const { return type == BuiltinOptions_MatrixSetDiagOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - HardSwishOptionsT *AsHardSwishOptions() { + tflite::HardSwishOptionsT *AsHardSwishOptions() { return type == BuiltinOptions_HardSwishOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const HardSwishOptionsT *AsHardSwishOptions() const { + const tflite::HardSwishOptionsT *AsHardSwishOptions() const { return type == BuiltinOptions_HardSwishOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - IfOptionsT *AsIfOptions() { + tflite::IfOptionsT *AsIfOptions() { return type == BuiltinOptions_IfOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const IfOptionsT *AsIfOptions() const { + const tflite::IfOptionsT *AsIfOptions() const { return type == BuiltinOptions_IfOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - WhileOptionsT *AsWhileOptions() { + tflite::WhileOptionsT *AsWhileOptions() { return type == BuiltinOptions_WhileOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const WhileOptionsT *AsWhileOptions() const { + const tflite::WhileOptionsT *AsWhileOptions() const { return type == BuiltinOptions_WhileOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - DepthToSpaceOptionsT *AsDepthToSpaceOptions() { + tflite::DepthToSpaceOptionsT *AsDepthToSpaceOptions() { return type == BuiltinOptions_DepthToSpaceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const DepthToSpaceOptionsT *AsDepthToSpaceOptions() const { + const tflite::DepthToSpaceOptionsT *AsDepthToSpaceOptions() const { return type == BuiltinOptions_DepthToSpaceOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - NonMaxSuppressionV4OptionsT *AsNonMaxSuppressionV4Options() { + tflite::NonMaxSuppressionV4OptionsT *AsNonMaxSuppressionV4Options() { return type == BuiltinOptions_NonMaxSuppressionV4Options ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const NonMaxSuppressionV4OptionsT *AsNonMaxSuppressionV4Options() const { + const tflite::NonMaxSuppressionV4OptionsT *AsNonMaxSuppressionV4Options() const { return type == BuiltinOptions_NonMaxSuppressionV4Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - NonMaxSuppressionV5OptionsT *AsNonMaxSuppressionV5Options() { + tflite::NonMaxSuppressionV5OptionsT *AsNonMaxSuppressionV5Options() { return type == BuiltinOptions_NonMaxSuppressionV5Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const NonMaxSuppressionV5OptionsT *AsNonMaxSuppressionV5Options() const { + const tflite::NonMaxSuppressionV5OptionsT *AsNonMaxSuppressionV5Options() const { return type == BuiltinOptions_NonMaxSuppressionV5Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - ScatterNdOptionsT *AsScatterNdOptions() { + tflite::ScatterNdOptionsT *AsScatterNdOptions() { return type == BuiltinOptions_ScatterNdOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const ScatterNdOptionsT *AsScatterNdOptions() const { + const tflite::ScatterNdOptionsT *AsScatterNdOptions() const { return type == BuiltinOptions_ScatterNdOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SelectV2OptionsT *AsSelectV2Options() { + tflite::SelectV2OptionsT *AsSelectV2Options() { return type == BuiltinOptions_SelectV2Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SelectV2OptionsT *AsSelectV2Options() const { + const tflite::SelectV2OptionsT *AsSelectV2Options() const { return type == BuiltinOptions_SelectV2Options ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - DensifyOptionsT *AsDensifyOptions() { + tflite::DensifyOptionsT *AsDensifyOptions() { return type == BuiltinOptions_DensifyOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const DensifyOptionsT *AsDensifyOptions() const { + const tflite::DensifyOptionsT *AsDensifyOptions() const { return type == BuiltinOptions_DensifyOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - SegmentSumOptionsT *AsSegmentSumOptions() { + tflite::SegmentSumOptionsT *AsSegmentSumOptions() { return type == BuiltinOptions_SegmentSumOptions ? - reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } - const SegmentSumOptionsT *AsSegmentSumOptions() const { + const tflite::SegmentSumOptionsT *AsSegmentSumOptions() const { return type == BuiltinOptions_SegmentSumOptions ? 
- reinterpret_cast(value) : nullptr; + reinterpret_cast(value) : nullptr; } }; @@ -2630,7 +2630,7 @@ inline const Padding (&EnumValuesPadding())[2] { } inline const char * const *EnumNamesPadding() { - static const char * const names[] = { + static const char * const names[3] = { "SAME", "VALID", nullptr @@ -2639,7 +2639,7 @@ inline const char * const *EnumNamesPadding() { } inline const char *EnumNamePadding(Padding e) { - if (e < Padding_SAME || e > Padding_VALID) return ""; + if (flatbuffers::IsOutRange(e, Padding_SAME, Padding_VALID)) return ""; const size_t index = static_cast(e); return EnumNamesPadding()[index]; } @@ -2668,7 +2668,7 @@ inline const ActivationFunctionType (&EnumValuesActivationFunctionType())[6] { } inline const char * const *EnumNamesActivationFunctionType() { - static const char * const names[] = { + static const char * const names[7] = { "NONE", "RELU", "RELU_N1_TO_1", @@ -2681,7 +2681,7 @@ inline const char * const *EnumNamesActivationFunctionType() { } inline const char *EnumNameActivationFunctionType(ActivationFunctionType e) { - if (e < ActivationFunctionType_NONE || e > ActivationFunctionType_SIGN_BIT) return ""; + if (flatbuffers::IsOutRange(e, ActivationFunctionType_NONE, ActivationFunctionType_SIGN_BIT)) return ""; const size_t index = static_cast(e); return EnumNamesActivationFunctionType()[index]; } @@ -2704,7 +2704,7 @@ inline const LSHProjectionType (&EnumValuesLSHProjectionType())[3] { } inline const char * const *EnumNamesLSHProjectionType() { - static const char * const names[] = { + static const char * const names[4] = { "UNKNOWN", "SPARSE", "DENSE", @@ -2714,7 +2714,7 @@ inline const char * const *EnumNamesLSHProjectionType() { } inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { - if (e < LSHProjectionType_UNKNOWN || e > LSHProjectionType_DENSE) return ""; + if (flatbuffers::IsOutRange(e, LSHProjectionType_UNKNOWN, LSHProjectionType_DENSE)) return ""; const size_t index = static_cast(e); return EnumNamesLSHProjectionType()[index]; } @@ -2735,7 +2735,7 @@ inline const FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOption } inline const char * const *EnumNamesFullyConnectedOptionsWeightsFormat() { - static const char * const names[] = { + static const char * const names[3] = { "DEFAULT", "SHUFFLED4x16INT8", nullptr @@ -2744,7 +2744,7 @@ inline const char * const *EnumNamesFullyConnectedOptionsWeightsFormat() { } inline const char *EnumNameFullyConnectedOptionsWeightsFormat(FullyConnectedOptionsWeightsFormat e) { - if (e < FullyConnectedOptionsWeightsFormat_DEFAULT || e > FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8) return ""; + if (flatbuffers::IsOutRange(e, FullyConnectedOptionsWeightsFormat_DEFAULT, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8)) return ""; const size_t index = static_cast(e); return EnumNamesFullyConnectedOptionsWeightsFormat()[index]; } @@ -2765,7 +2765,7 @@ inline const LSTMKernelType (&EnumValuesLSTMKernelType())[2] { } inline const char * const *EnumNamesLSTMKernelType() { - static const char * const names[] = { + static const char * const names[3] = { "FULL", "BASIC", nullptr @@ -2774,7 +2774,7 @@ inline const char * const *EnumNamesLSTMKernelType() { } inline const char *EnumNameLSTMKernelType(LSTMKernelType e) { - if (e < LSTMKernelType_FULL || e > LSTMKernelType_BASIC) return ""; + if (flatbuffers::IsOutRange(e, LSTMKernelType_FULL, LSTMKernelType_BASIC)) return ""; const size_t index = static_cast(e); return EnumNamesLSTMKernelType()[index]; } @@ -2797,7 +2797,7 @@ inline 
const CombinerType (&EnumValuesCombinerType())[3] { } inline const char * const *EnumNamesCombinerType() { - static const char * const names[] = { + static const char * const names[4] = { "SUM", "MEAN", "SQRTN", @@ -2807,7 +2807,7 @@ inline const char * const *EnumNamesCombinerType() { } inline const char *EnumNameCombinerType(CombinerType e) { - if (e < CombinerType_SUM || e > CombinerType_SQRTN) return ""; + if (flatbuffers::IsOutRange(e, CombinerType_SUM, CombinerType_SQRTN)) return ""; const size_t index = static_cast(e); return EnumNamesCombinerType()[index]; } @@ -2828,7 +2828,7 @@ inline const MirrorPadMode (&EnumValuesMirrorPadMode())[2] { } inline const char * const *EnumNamesMirrorPadMode() { - static const char * const names[] = { + static const char * const names[3] = { "REFLECT", "SYMMETRIC", nullptr @@ -2837,7 +2837,7 @@ inline const char * const *EnumNamesMirrorPadMode() { } inline const char *EnumNameMirrorPadMode(MirrorPadMode e) { - if (e < MirrorPadMode_REFLECT || e > MirrorPadMode_SYMMETRIC) return ""; + if (flatbuffers::IsOutRange(e, MirrorPadMode_REFLECT, MirrorPadMode_SYMMETRIC)) return ""; const size_t index = static_cast(e); return EnumNamesMirrorPadMode()[index]; } @@ -2856,7 +2856,7 @@ inline const CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] { } inline const char * const *EnumNamesCustomOptionsFormat() { - static const char * const names[] = { + static const char * const names[2] = { "FLEXBUFFERS", nullptr }; @@ -2864,7 +2864,7 @@ inline const char * const *EnumNamesCustomOptionsFormat() { } inline const char *EnumNameCustomOptionsFormat(CustomOptionsFormat e) { - if (e < CustomOptionsFormat_FLEXBUFFERS || e > CustomOptionsFormat_FLEXBUFFERS) return ""; + if (flatbuffers::IsOutRange(e, CustomOptionsFormat_FLEXBUFFERS, CustomOptionsFormat_FLEXBUFFERS)) return ""; const size_t index = static_cast(e); return EnumNamesCustomOptionsFormat()[index]; } @@ -2939,7 +2939,7 @@ struct QuantizationParametersT : public flatbuffers::NativeTable { std::vector max; std::vector scale; std::vector zero_point; - QuantizationDetailsUnion details; + tflite::QuantizationDetailsUnion details; int32_t quantized_dimension; QuantizationParametersT() : quantized_dimension(0) { @@ -2969,15 +2969,15 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab const flatbuffers::Vector *zero_point() const { return GetPointer *>(VT_ZERO_POINT); } - QuantizationDetails details_type() const { - return static_cast(GetField(VT_DETAILS_TYPE, 0)); + tflite::QuantizationDetails details_type() const { + return static_cast(GetField(VT_DETAILS_TYPE, 0)); } const void *details() const { return GetPointer(VT_DETAILS); } template const T *details_as() const; - const CustomQuantization *details_as_CustomQuantization() const { - return details_type() == QuantizationDetails_CustomQuantization ? static_cast(details()) : nullptr; + const tflite::CustomQuantization *details_as_CustomQuantization() const { + return details_type() == tflite::QuantizationDetails_CustomQuantization ? 
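For reference, the new range guard used throughout the EnumName* hunks above behaves roughly as follows. This is a minimal sketch, assuming the regenerated schema header is on the include path; IsOutRangeSketch and EnumNamePaddingSketch are illustrative stand-ins, not part of the patch, and the real check lives in flatbuffers::IsOutRange.

#include <cstddef>
#include "tensorflow/lite/schema/schema_generated.h"

// Roughly what flatbuffers::IsOutRange does: an inclusive two-sided bounds test.
template <typename T>
bool IsOutRangeSketch(const T &v, const T &low, const T &high) {
  return v < low || v > high;
}

// Mirrors the regenerated EnumNamePadding(): an out-of-range value (e.g. read
// from a corrupted model file) yields "" instead of indexing past the
// fixed-size names array.
inline const char *EnumNamePaddingSketch(tflite::Padding e) {
  if (IsOutRangeSketch(e, tflite::Padding_SAME, tflite::Padding_VALID)) return "";
  return tflite::EnumNamesPadding()[static_cast<size_t>(e)];
}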
static_cast(details()) : nullptr; } int32_t quantized_dimension() const { return GetField(VT_QUANTIZED_DIMENSION, 0); @@ -3003,7 +3003,7 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -template<> inline const CustomQuantization *QuantizationParameters::details_as() const { +template<> inline const tflite::CustomQuantization *QuantizationParameters::details_as() const { return details_as_CustomQuantization(); } @@ -3022,7 +3022,7 @@ struct QuantizationParametersBuilder { void add_zero_point(flatbuffers::Offset> zero_point) { fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, zero_point); } - void add_details_type(QuantizationDetails details_type) { + void add_details_type(tflite::QuantizationDetails details_type) { fbb_.AddElement(QuantizationParameters::VT_DETAILS_TYPE, static_cast(details_type), 0); } void add_details(flatbuffers::Offset details) { @@ -3049,7 +3049,7 @@ inline flatbuffers::Offset CreateQuantizationParameters( flatbuffers::Offset> max = 0, flatbuffers::Offset> scale = 0, flatbuffers::Offset> zero_point = 0, - QuantizationDetails details_type = QuantizationDetails_NONE, + tflite::QuantizationDetails details_type = tflite::QuantizationDetails_NONE, flatbuffers::Offset details = 0, int32_t quantized_dimension = 0) { QuantizationParametersBuilder builder_(_fbb); @@ -3069,7 +3069,7 @@ inline flatbuffers::Offset CreateQuantizationParametersD const std::vector *max = nullptr, const std::vector *scale = nullptr, const std::vector *zero_point = nullptr, - QuantizationDetails details_type = QuantizationDetails_NONE, + tflite::QuantizationDetails details_type = tflite::QuantizationDetails_NONE, flatbuffers::Offset details = 0, int32_t quantized_dimension = 0) { auto min__ = min ? _fbb.CreateVector(*min) : 0; @@ -3282,12 +3282,12 @@ flatbuffers::Offset CreateUint8Vector(flatbuffers::FlatBufferBuilde struct DimensionMetadataT : public flatbuffers::NativeTable { typedef DimensionMetadata TableType; - DimensionType format; + tflite::DimensionType format; int32_t dense_size; - SparseIndexVectorUnion array_segments; - SparseIndexVectorUnion array_indices; + tflite::SparseIndexVectorUnion array_segments; + tflite::SparseIndexVectorUnion array_indices; DimensionMetadataT() - : format(DimensionType_DENSE), + : format(tflite::DimensionType_DENSE), dense_size(0) { } }; @@ -3302,43 +3302,43 @@ struct DimensionMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_ARRAY_INDICES_TYPE = 12, VT_ARRAY_INDICES = 14 }; - DimensionType format() const { - return static_cast(GetField(VT_FORMAT, 0)); + tflite::DimensionType format() const { + return static_cast(GetField(VT_FORMAT, 0)); } int32_t dense_size() const { return GetField(VT_DENSE_SIZE, 0); } - SparseIndexVector array_segments_type() const { - return static_cast(GetField(VT_ARRAY_SEGMENTS_TYPE, 0)); + tflite::SparseIndexVector array_segments_type() const { + return static_cast(GetField(VT_ARRAY_SEGMENTS_TYPE, 0)); } const void *array_segments() const { return GetPointer(VT_ARRAY_SEGMENTS); } template const T *array_segments_as() const; - const Int32Vector *array_segments_as_Int32Vector() const { - return array_segments_type() == SparseIndexVector_Int32Vector ? 
static_cast(array_segments()) : nullptr; + const tflite::Int32Vector *array_segments_as_Int32Vector() const { + return array_segments_type() == tflite::SparseIndexVector_Int32Vector ? static_cast(array_segments()) : nullptr; } - const Uint16Vector *array_segments_as_Uint16Vector() const { - return array_segments_type() == SparseIndexVector_Uint16Vector ? static_cast(array_segments()) : nullptr; + const tflite::Uint16Vector *array_segments_as_Uint16Vector() const { + return array_segments_type() == tflite::SparseIndexVector_Uint16Vector ? static_cast(array_segments()) : nullptr; } - const Uint8Vector *array_segments_as_Uint8Vector() const { - return array_segments_type() == SparseIndexVector_Uint8Vector ? static_cast(array_segments()) : nullptr; + const tflite::Uint8Vector *array_segments_as_Uint8Vector() const { + return array_segments_type() == tflite::SparseIndexVector_Uint8Vector ? static_cast(array_segments()) : nullptr; } - SparseIndexVector array_indices_type() const { - return static_cast(GetField(VT_ARRAY_INDICES_TYPE, 0)); + tflite::SparseIndexVector array_indices_type() const { + return static_cast(GetField(VT_ARRAY_INDICES_TYPE, 0)); } const void *array_indices() const { return GetPointer(VT_ARRAY_INDICES); } template const T *array_indices_as() const; - const Int32Vector *array_indices_as_Int32Vector() const { - return array_indices_type() == SparseIndexVector_Int32Vector ? static_cast(array_indices()) : nullptr; + const tflite::Int32Vector *array_indices_as_Int32Vector() const { + return array_indices_type() == tflite::SparseIndexVector_Int32Vector ? static_cast(array_indices()) : nullptr; } - const Uint16Vector *array_indices_as_Uint16Vector() const { - return array_indices_type() == SparseIndexVector_Uint16Vector ? static_cast(array_indices()) : nullptr; + const tflite::Uint16Vector *array_indices_as_Uint16Vector() const { + return array_indices_type() == tflite::SparseIndexVector_Uint16Vector ? static_cast(array_indices()) : nullptr; } - const Uint8Vector *array_indices_as_Uint8Vector() const { - return array_indices_type() == SparseIndexVector_Uint8Vector ? static_cast(array_indices()) : nullptr; + const tflite::Uint8Vector *array_indices_as_Uint8Vector() const { + return array_indices_type() == tflite::SparseIndexVector_Uint8Vector ? 
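As a usage note for the qualified union accessors above: each array_segments_as_*() / array_indices_as_*() call returns nullptr unless the stored SparseIndexVector tag matches, so callers probe the alternatives in turn. A small sketch under that assumption; ReadSegments, its widening to int64_t, and the null checks are illustrative only.

#include <cstdint>
#include <vector>
#include "tensorflow/lite/schema/schema_generated.h"

// Collect a sparse dimension's segment array, whichever index-vector type the
// model happened to store.
std::vector<int64_t> ReadSegments(const tflite::DimensionMetadata *dm) {
  std::vector<int64_t> out;
  if (const auto *v32 = dm->array_segments_as_Int32Vector()) {
    if (v32->values()) for (int32_t x : *v32->values()) out.push_back(x);
  } else if (const auto *v16 = dm->array_segments_as_Uint16Vector()) {
    if (v16->values()) for (uint16_t x : *v16->values()) out.push_back(x);
  } else if (const auto *v8 = dm->array_segments_as_Uint8Vector()) {
    if (v8->values()) for (uint8_t x : *v8->values()) out.push_back(x);
  }
  return out;
}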
static_cast(array_indices()) : nullptr; } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -3357,46 +3357,46 @@ struct DimensionMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -template<> inline const Int32Vector *DimensionMetadata::array_segments_as() const { +template<> inline const tflite::Int32Vector *DimensionMetadata::array_segments_as() const { return array_segments_as_Int32Vector(); } -template<> inline const Uint16Vector *DimensionMetadata::array_segments_as() const { +template<> inline const tflite::Uint16Vector *DimensionMetadata::array_segments_as() const { return array_segments_as_Uint16Vector(); } -template<> inline const Uint8Vector *DimensionMetadata::array_segments_as() const { +template<> inline const tflite::Uint8Vector *DimensionMetadata::array_segments_as() const { return array_segments_as_Uint8Vector(); } -template<> inline const Int32Vector *DimensionMetadata::array_indices_as() const { +template<> inline const tflite::Int32Vector *DimensionMetadata::array_indices_as() const { return array_indices_as_Int32Vector(); } -template<> inline const Uint16Vector *DimensionMetadata::array_indices_as() const { +template<> inline const tflite::Uint16Vector *DimensionMetadata::array_indices_as() const { return array_indices_as_Uint16Vector(); } -template<> inline const Uint8Vector *DimensionMetadata::array_indices_as() const { +template<> inline const tflite::Uint8Vector *DimensionMetadata::array_indices_as() const { return array_indices_as_Uint8Vector(); } struct DimensionMetadataBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_format(DimensionType format) { + void add_format(tflite::DimensionType format) { fbb_.AddElement(DimensionMetadata::VT_FORMAT, static_cast(format), 0); } void add_dense_size(int32_t dense_size) { fbb_.AddElement(DimensionMetadata::VT_DENSE_SIZE, dense_size, 0); } - void add_array_segments_type(SparseIndexVector array_segments_type) { + void add_array_segments_type(tflite::SparseIndexVector array_segments_type) { fbb_.AddElement(DimensionMetadata::VT_ARRAY_SEGMENTS_TYPE, static_cast(array_segments_type), 0); } void add_array_segments(flatbuffers::Offset array_segments) { fbb_.AddOffset(DimensionMetadata::VT_ARRAY_SEGMENTS, array_segments); } - void add_array_indices_type(SparseIndexVector array_indices_type) { + void add_array_indices_type(tflite::SparseIndexVector array_indices_type) { fbb_.AddElement(DimensionMetadata::VT_ARRAY_INDICES_TYPE, static_cast(array_indices_type), 0); } void add_array_indices(flatbuffers::Offset array_indices) { @@ -3416,11 +3416,11 @@ struct DimensionMetadataBuilder { inline flatbuffers::Offset CreateDimensionMetadata( flatbuffers::FlatBufferBuilder &_fbb, - DimensionType format = DimensionType_DENSE, + tflite::DimensionType format = tflite::DimensionType_DENSE, int32_t dense_size = 0, - SparseIndexVector array_segments_type = SparseIndexVector_NONE, + tflite::SparseIndexVector array_segments_type = tflite::SparseIndexVector_NONE, flatbuffers::Offset array_segments = 0, - SparseIndexVector array_indices_type = SparseIndexVector_NONE, + tflite::SparseIndexVector array_indices_type = tflite::SparseIndexVector_NONE, flatbuffers::Offset array_indices = 0) { DimensionMetadataBuilder builder_(_fbb); builder_.add_array_indices(array_indices); @@ -3438,7 +3438,7 @@ struct 
SparsityParametersT : public flatbuffers::NativeTable { typedef SparsityParameters TableType; std::vector traversal_order; std::vector block_map; - std::vector> dim_metadata; + std::vector> dim_metadata; SparsityParametersT() { } }; @@ -3456,8 +3456,8 @@ struct SparsityParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *block_map() const { return GetPointer *>(VT_BLOCK_MAP); } - const flatbuffers::Vector> *dim_metadata() const { - return GetPointer> *>(VT_DIM_METADATA); + const flatbuffers::Vector> *dim_metadata() const { + return GetPointer> *>(VT_DIM_METADATA); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -3484,7 +3484,7 @@ struct SparsityParametersBuilder { void add_block_map(flatbuffers::Offset> block_map) { fbb_.AddOffset(SparsityParameters::VT_BLOCK_MAP, block_map); } - void add_dim_metadata(flatbuffers::Offset>> dim_metadata) { + void add_dim_metadata(flatbuffers::Offset>> dim_metadata) { fbb_.AddOffset(SparsityParameters::VT_DIM_METADATA, dim_metadata); } explicit SparsityParametersBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -3503,7 +3503,7 @@ inline flatbuffers::Offset CreateSparsityParameters( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> traversal_order = 0, flatbuffers::Offset> block_map = 0, - flatbuffers::Offset>> dim_metadata = 0) { + flatbuffers::Offset>> dim_metadata = 0) { SparsityParametersBuilder builder_(_fbb); builder_.add_dim_metadata(dim_metadata); builder_.add_block_map(block_map); @@ -3515,10 +3515,10 @@ inline flatbuffers::Offset CreateSparsityParametersDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *traversal_order = nullptr, const std::vector *block_map = nullptr, - const std::vector> *dim_metadata = nullptr) { + const std::vector> *dim_metadata = nullptr) { auto traversal_order__ = traversal_order ? _fbb.CreateVector(*traversal_order) : 0; auto block_map__ = block_map ? _fbb.CreateVector(*block_map) : 0; - auto dim_metadata__ = dim_metadata ? _fbb.CreateVector>(*dim_metadata) : 0; + auto dim_metadata__ = dim_metadata ? 
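The Create*Direct helpers in this hunk compose in the usual FlatBuffers way. A hedged sketch of building one sparse dimension and wrapping it in SparsityParameters; the function name and all example values are made up, and only calls visible in this diff plus standard FlatBuffers offsets are used.

#include <cstdint>
#include <vector>
#include "tensorflow/lite/schema/schema_generated.h"

flatbuffers::Offset<tflite::SparsityParameters> BuildExampleSparsity(
    flatbuffers::FlatBufferBuilder &fbb) {
  // Segment/index arrays go through the SparseIndexVector union, here as Int32Vector.
  std::vector<int32_t> segment_values = {0, 2, 5};
  std::vector<int32_t> index_values = {0, 3, 1, 2, 4};
  auto segments = tflite::CreateInt32VectorDirect(fbb, &segment_values);
  auto indices = tflite::CreateInt32VectorDirect(fbb, &index_values);

  auto dim = tflite::CreateDimensionMetadata(
      fbb, tflite::DimensionType_SPARSE_CSR, /*dense_size=*/0,
      tflite::SparseIndexVector_Int32Vector, segments.Union(),
      tflite::SparseIndexVector_Int32Vector, indices.Union());

  std::vector<int32_t> traversal_order = {0, 1};
  std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> dims = {dim};
  return tflite::CreateSparsityParametersDirect(fbb, &traversal_order,
                                                /*block_map=*/nullptr, &dims);
}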
_fbb.CreateVector>(*dim_metadata) : 0; return tflite::CreateSparsityParameters( _fbb, traversal_order__, @@ -3531,15 +3531,15 @@ flatbuffers::Offset CreateSparsityParameters(flatbuffers::Fl struct TensorT : public flatbuffers::NativeTable { typedef Tensor TableType; std::vector shape; - TensorType type; + tflite::TensorType type; uint32_t buffer; std::string name; - std::unique_ptr quantization; + std::unique_ptr quantization; bool is_variable; - std::unique_ptr sparsity; + std::unique_ptr sparsity; std::vector shape_signature; TensorT() - : type(TensorType_FLOAT32), + : type(tflite::TensorType_FLOAT32), buffer(0), is_variable(false) { } @@ -3560,8 +3560,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); } - TensorType type() const { - return static_cast(GetField(VT_TYPE, 0)); + tflite::TensorType type() const { + return static_cast(GetField(VT_TYPE, 0)); } uint32_t buffer() const { return GetField(VT_BUFFER, 0); @@ -3569,14 +3569,14 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::String *name() const { return GetPointer(VT_NAME); } - const QuantizationParameters *quantization() const { - return GetPointer(VT_QUANTIZATION); + const tflite::QuantizationParameters *quantization() const { + return GetPointer(VT_QUANTIZATION); } bool is_variable() const { return GetField(VT_IS_VARIABLE, 0) != 0; } - const SparsityParameters *sparsity() const { - return GetPointer(VT_SPARSITY); + const tflite::SparsityParameters *sparsity() const { + return GetPointer(VT_SPARSITY); } const flatbuffers::Vector *shape_signature() const { return GetPointer *>(VT_SHAPE_SIGNATURE); @@ -3609,7 +3609,7 @@ struct TensorBuilder { void add_shape(flatbuffers::Offset> shape) { fbb_.AddOffset(Tensor::VT_SHAPE, shape); } - void add_type(TensorType type) { + void add_type(tflite::TensorType type) { fbb_.AddElement(Tensor::VT_TYPE, static_cast(type), 0); } void add_buffer(uint32_t buffer) { @@ -3618,13 +3618,13 @@ struct TensorBuilder { void add_name(flatbuffers::Offset name) { fbb_.AddOffset(Tensor::VT_NAME, name); } - void add_quantization(flatbuffers::Offset quantization) { + void add_quantization(flatbuffers::Offset quantization) { fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); } void add_is_variable(bool is_variable) { fbb_.AddElement(Tensor::VT_IS_VARIABLE, static_cast(is_variable), 0); } - void add_sparsity(flatbuffers::Offset sparsity) { + void add_sparsity(flatbuffers::Offset sparsity) { fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity); } void add_shape_signature(flatbuffers::Offset> shape_signature) { @@ -3645,12 +3645,12 @@ struct TensorBuilder { inline flatbuffers::Offset CreateTensor( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> shape = 0, - TensorType type = TensorType_FLOAT32, + tflite::TensorType type = tflite::TensorType_FLOAT32, uint32_t buffer = 0, flatbuffers::Offset name = 0, - flatbuffers::Offset quantization = 0, + flatbuffers::Offset quantization = 0, bool is_variable = false, - flatbuffers::Offset sparsity = 0, + flatbuffers::Offset sparsity = 0, flatbuffers::Offset> shape_signature = 0) { TensorBuilder builder_(_fbb); builder_.add_shape_signature(shape_signature); @@ -3667,12 +3667,12 @@ inline flatbuffers::Offset CreateTensor( inline flatbuffers::Offset CreateTensorDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *shape = nullptr, - TensorType type = TensorType_FLOAT32, + tflite::TensorType type = tflite::TensorType_FLOAT32, 
uint32_t buffer = 0, const char *name = nullptr, - flatbuffers::Offset quantization = 0, + flatbuffers::Offset quantization = 0, bool is_variable = false, - flatbuffers::Offset sparsity = 0, + flatbuffers::Offset sparsity = 0, const std::vector *shape_signature = nullptr) { auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; auto name__ = name ? _fbb.CreateString(name) : 0; @@ -3693,17 +3693,17 @@ flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, c struct Conv2DOptionsT : public flatbuffers::NativeTable { typedef Conv2DOptions TableType; - Padding padding; + tflite::Padding padding; int32_t stride_w; int32_t stride_h; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; int32_t dilation_w_factor; int32_t dilation_h_factor; Conv2DOptionsT() - : padding(Padding_SAME), + : padding(tflite::Padding_SAME), stride_w(0), stride_h(0), - fused_activation_function(ActivationFunctionType_NONE), + fused_activation_function(tflite::ActivationFunctionType_NONE), dilation_w_factor(1), dilation_h_factor(1) { } @@ -3719,8 +3719,8 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_DILATION_W_FACTOR = 12, VT_DILATION_H_FACTOR = 14 }; - Padding padding() const { - return static_cast(GetField(VT_PADDING, 0)); + tflite::Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); } int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); @@ -3728,8 +3728,8 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } int32_t dilation_w_factor() const { return GetField(VT_DILATION_W_FACTOR, 1); @@ -3755,7 +3755,7 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct Conv2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_padding(Padding padding) { + void add_padding(tflite::Padding padding) { fbb_.AddElement(Conv2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { @@ -3764,7 +3764,7 @@ struct Conv2DOptionsBuilder { void add_stride_h(int32_t stride_h) { fbb_.AddElement(Conv2DOptions::VT_STRIDE_H, stride_h, 0); } - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } void add_dilation_w_factor(int32_t dilation_w_factor) { @@ -3787,10 +3787,10 @@ struct Conv2DOptionsBuilder { inline flatbuffers::Offset CreateConv2DOptions( flatbuffers::FlatBufferBuilder &_fbb, - Padding padding = Padding_SAME, + tflite::Padding padding = tflite::Padding_SAME, int32_t stride_w = 0, int32_t stride_h = 0, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, int32_t dilation_w_factor = 1, int32_t dilation_h_factor = 1) { Conv2DOptionsBuilder builder_(_fbb); @@ -3807,19 +3807,19 @@ flatbuffers::Offset CreateConv2DOptions(flatbuffers::FlatBufferBu struct Pool2DOptionsT : public flatbuffers::NativeTable { 
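The changes through this stretch are mechanical: every schema type the regenerated header refers to (Tensor, QuantizationParameters, Conv2DOptions, Pool2DOptions, and the rest) is now spelled with its full tflite:: qualification. A short, hedged illustration of the kind of name collision that fully qualified spellings sidestep in downstream code; the myapp namespace and its TensorType are invented for the example.

#include "tensorflow/lite/schema/schema_generated.h"

namespace myapp {
struct TensorType { int id; };  // hypothetical user type with a colliding name

// Inside myapp, a bare "TensorType" means myapp::TensorType, so code touching
// the schema spells the qualified names, exactly as the regenerated header
// itself now does everywhere.
inline tflite::TensorType SchemaFloatType() {
  return tflite::TensorType_FLOAT32;
}
}  // namespace myapp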
typedef Pool2DOptions TableType; - Padding padding; + tflite::Padding padding; int32_t stride_w; int32_t stride_h; int32_t filter_width; int32_t filter_height; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; Pool2DOptionsT() - : padding(Padding_SAME), + : padding(tflite::Padding_SAME), stride_w(0), stride_h(0), filter_width(0), filter_height(0), - fused_activation_function(ActivationFunctionType_NONE) { + fused_activation_function(tflite::ActivationFunctionType_NONE) { } }; @@ -3833,8 +3833,8 @@ struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_FILTER_HEIGHT = 12, VT_FUSED_ACTIVATION_FUNCTION = 14 }; - Padding padding() const { - return static_cast(GetField(VT_PADDING, 0)); + tflite::Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); } int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); @@ -3848,8 +3848,8 @@ struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { int32_t filter_height() const { return GetField(VT_FILTER_HEIGHT, 0); } - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -3869,7 +3869,7 @@ struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct Pool2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_padding(Padding padding) { + void add_padding(tflite::Padding padding) { fbb_.AddElement(Pool2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { @@ -3884,7 +3884,7 @@ struct Pool2DOptionsBuilder { void add_filter_height(int32_t filter_height) { fbb_.AddElement(Pool2DOptions::VT_FILTER_HEIGHT, filter_height, 0); } - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit Pool2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -3901,12 +3901,12 @@ struct Pool2DOptionsBuilder { inline flatbuffers::Offset CreatePool2DOptions( flatbuffers::FlatBufferBuilder &_fbb, - Padding padding = Padding_SAME, + tflite::Padding padding = tflite::Padding_SAME, int32_t stride_w = 0, int32_t stride_h = 0, int32_t filter_width = 0, int32_t filter_height = 0, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { Pool2DOptionsBuilder builder_(_fbb); builder_.add_filter_height(filter_height); builder_.add_filter_width(filter_width); @@ -3921,19 +3921,19 @@ flatbuffers::Offset CreatePool2DOptions(flatbuffers::FlatBufferBu struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable { typedef DepthwiseConv2DOptions TableType; - Padding padding; + tflite::Padding padding; int32_t stride_w; int32_t stride_h; int32_t depth_multiplier; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; int32_t dilation_w_factor; int32_t dilation_h_factor; DepthwiseConv2DOptionsT() - : padding(Padding_SAME), + : padding(tflite::Padding_SAME), 
stride_w(0), stride_h(0), depth_multiplier(0), - fused_activation_function(ActivationFunctionType_NONE), + fused_activation_function(tflite::ActivationFunctionType_NONE), dilation_w_factor(1), dilation_h_factor(1) { } @@ -3950,8 +3950,8 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab VT_DILATION_W_FACTOR = 14, VT_DILATION_H_FACTOR = 16 }; - Padding padding() const { - return static_cast(GetField(VT_PADDING, 0)); + tflite::Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); } int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); @@ -3962,8 +3962,8 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab int32_t depth_multiplier() const { return GetField(VT_DEPTH_MULTIPLIER, 0); } - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } int32_t dilation_w_factor() const { return GetField(VT_DILATION_W_FACTOR, 1); @@ -3990,7 +3990,7 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab struct DepthwiseConv2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_padding(Padding padding) { + void add_padding(tflite::Padding padding) { fbb_.AddElement(DepthwiseConv2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { @@ -4002,7 +4002,7 @@ struct DepthwiseConv2DOptionsBuilder { void add_depth_multiplier(int32_t depth_multiplier) { fbb_.AddElement(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, depth_multiplier, 0); } - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } void add_dilation_w_factor(int32_t dilation_w_factor) { @@ -4025,11 +4025,11 @@ struct DepthwiseConv2DOptionsBuilder { inline flatbuffers::Offset CreateDepthwiseConv2DOptions( flatbuffers::FlatBufferBuilder &_fbb, - Padding padding = Padding_SAME, + tflite::Padding padding = tflite::Padding_SAME, int32_t stride_w = 0, int32_t stride_h = 0, int32_t depth_multiplier = 0, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, int32_t dilation_w_factor = 1, int32_t dilation_h_factor = 1) { DepthwiseConv2DOptionsBuilder builder_(_fbb); @@ -4139,9 +4139,9 @@ flatbuffers::Offset CreateConcatEmbeddingsOptions(flatb struct LSHProjectionOptionsT : public flatbuffers::NativeTable { typedef LSHProjectionOptions TableType; - LSHProjectionType type; + tflite::LSHProjectionType type; LSHProjectionOptionsT() - : type(LSHProjectionType_UNKNOWN) { + : type(tflite::LSHProjectionType_UNKNOWN) { } }; @@ -4150,8 +4150,8 @@ struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_TYPE = 4 }; - LSHProjectionType type() const { - return static_cast(GetField(VT_TYPE, 0)); + tflite::LSHProjectionType type() const { + return static_cast(GetField(VT_TYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4166,7 +4166,7 @@ struct LSHProjectionOptions 
FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table struct LSHProjectionOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_type(LSHProjectionType type) { + void add_type(tflite::LSHProjectionType type) { fbb_.AddElement(LSHProjectionOptions::VT_TYPE, static_cast(type), 0); } explicit LSHProjectionOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -4183,7 +4183,7 @@ struct LSHProjectionOptionsBuilder { inline flatbuffers::Offset CreateLSHProjectionOptions( flatbuffers::FlatBufferBuilder &_fbb, - LSHProjectionType type = LSHProjectionType_UNKNOWN) { + tflite::LSHProjectionType type = tflite::LSHProjectionType_UNKNOWN) { LSHProjectionOptionsBuilder builder_(_fbb); builder_.add_type(type); return builder_.Finish(); @@ -4194,10 +4194,10 @@ flatbuffers::Offset CreateLSHProjectionOptions(flatbuffers struct SVDFOptionsT : public flatbuffers::NativeTable { typedef SVDFOptions TableType; int32_t rank; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; SVDFOptionsT() : rank(0), - fused_activation_function(ActivationFunctionType_NONE) { + fused_activation_function(tflite::ActivationFunctionType_NONE) { } }; @@ -4210,8 +4210,8 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { int32_t rank() const { return GetField(VT_RANK, 0); } - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4230,7 +4230,7 @@ struct SVDFOptionsBuilder { void add_rank(int32_t rank) { fbb_.AddElement(SVDFOptions::VT_RANK, rank, 0); } - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -4248,7 +4248,7 @@ struct SVDFOptionsBuilder { inline flatbuffers::Offset CreateSVDFOptions( flatbuffers::FlatBufferBuilder &_fbb, int32_t rank = 0, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { SVDFOptionsBuilder builder_(_fbb); builder_.add_rank(rank); builder_.add_fused_activation_function(fused_activation_function); @@ -4259,9 +4259,9 @@ flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilde struct RNNOptionsT : public flatbuffers::NativeTable { typedef RNNOptions TableType; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; RNNOptionsT() - : fused_activation_function(ActivationFunctionType_NONE) { + : fused_activation_function(tflite::ActivationFunctionType_NONE) { } }; @@ -4270,8 +4270,8 @@ struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_FUSED_ACTIVATION_FUNCTION = 4 }; - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return 
static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4286,7 +4286,7 @@ struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct RNNOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -4303,7 +4303,7 @@ struct RNNOptionsBuilder { inline flatbuffers::Offset CreateRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { RNNOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); @@ -4314,10 +4314,10 @@ flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferBuilder struct SequenceRNNOptionsT : public flatbuffers::NativeTable { typedef SequenceRNNOptions TableType; bool time_major; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; SequenceRNNOptionsT() : time_major(false), - fused_activation_function(ActivationFunctionType_NONE) { + fused_activation_function(tflite::ActivationFunctionType_NONE) { } }; @@ -4330,8 +4330,8 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; } - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4350,7 +4350,7 @@ struct SequenceRNNOptionsBuilder { void add_time_major(bool time_major) { fbb_.AddElement(SequenceRNNOptions::VT_TIME_MAJOR, static_cast(time_major), 0); } - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -4368,7 +4368,7 @@ struct SequenceRNNOptionsBuilder { inline flatbuffers::Offset CreateSequenceRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { SequenceRNNOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); @@ -4380,11 +4380,11 @@ flatbuffers::Offset CreateSequenceRNNOptions(flatbuffers::Fl struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable { typedef BidirectionalSequenceRNNOptions TableType; bool time_major; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; bool 
merge_outputs; BidirectionalSequenceRNNOptionsT() : time_major(false), - fused_activation_function(ActivationFunctionType_NONE), + fused_activation_function(tflite::ActivationFunctionType_NONE), merge_outputs(false) { } }; @@ -4399,8 +4399,8 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; } - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool merge_outputs() const { return GetField(VT_MERGE_OUTPUTS, 0) != 0; @@ -4423,7 +4423,7 @@ struct BidirectionalSequenceRNNOptionsBuilder { void add_time_major(bool time_major) { fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_TIME_MAJOR, static_cast(time_major), 0); } - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } void add_merge_outputs(bool merge_outputs) { @@ -4444,7 +4444,7 @@ struct BidirectionalSequenceRNNOptionsBuilder { inline flatbuffers::Offset CreateBidirectionalSequenceRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, bool merge_outputs = false) { BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); builder_.add_merge_outputs(merge_outputs); @@ -4457,12 +4457,12 @@ flatbuffers::Offset CreateBidirectionalSequence struct FullyConnectedOptionsT : public flatbuffers::NativeTable { typedef FullyConnectedOptions TableType; - ActivationFunctionType fused_activation_function; - FullyConnectedOptionsWeightsFormat weights_format; + tflite::ActivationFunctionType fused_activation_function; + tflite::FullyConnectedOptionsWeightsFormat weights_format; bool keep_num_dims; FullyConnectedOptionsT() - : fused_activation_function(ActivationFunctionType_NONE), - weights_format(FullyConnectedOptionsWeightsFormat_DEFAULT), + : fused_activation_function(tflite::ActivationFunctionType_NONE), + weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT), keep_num_dims(false) { } }; @@ -4474,11 +4474,11 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl VT_WEIGHTS_FORMAT = 6, VT_KEEP_NUM_DIMS = 8 }; - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } - FullyConnectedOptionsWeightsFormat weights_format() const { - return static_cast(GetField(VT_WEIGHTS_FORMAT, 0)); + tflite::FullyConnectedOptionsWeightsFormat weights_format() const { + return static_cast(GetField(VT_WEIGHTS_FORMAT, 0)); } bool keep_num_dims() const { return GetField(VT_KEEP_NUM_DIMS, 0) != 0; @@ -4498,10 +4498,10 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl struct FullyConnectedOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function(ActivationFunctionType 
fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } - void add_weights_format(FullyConnectedOptionsWeightsFormat weights_format) { + void add_weights_format(tflite::FullyConnectedOptionsWeightsFormat weights_format) { fbb_.AddElement(FullyConnectedOptions::VT_WEIGHTS_FORMAT, static_cast(weights_format), 0); } void add_keep_num_dims(bool keep_num_dims) { @@ -4521,8 +4521,8 @@ struct FullyConnectedOptionsBuilder { inline flatbuffers::Offset CreateFullyConnectedOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, - FullyConnectedOptionsWeightsFormat weights_format = FullyConnectedOptionsWeightsFormat_DEFAULT, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, bool keep_num_dims = false) { FullyConnectedOptionsBuilder builder_(_fbb); builder_.add_keep_num_dims(keep_num_dims); @@ -4590,10 +4590,10 @@ flatbuffers::Offset CreateSoftmaxOptions(flatbuffers::FlatBuffer struct ConcatenationOptionsT : public flatbuffers::NativeTable { typedef ConcatenationOptions TableType; int32_t axis; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; ConcatenationOptionsT() : axis(0), - fused_activation_function(ActivationFunctionType_NONE) { + fused_activation_function(tflite::ActivationFunctionType_NONE) { } }; @@ -4606,8 +4606,8 @@ struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table int32_t axis() const { return GetField(VT_AXIS, 0); } - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4626,7 +4626,7 @@ struct ConcatenationOptionsBuilder { void add_axis(int32_t axis) { fbb_.AddElement(ConcatenationOptions::VT_AXIS, axis, 0); } - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit ConcatenationOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -4644,7 +4644,7 @@ struct ConcatenationOptionsBuilder { inline flatbuffers::Offset CreateConcatenationOptions( flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { ConcatenationOptionsBuilder builder_(_fbb); builder_.add_axis(axis); builder_.add_fused_activation_function(fused_activation_function); @@ -4655,9 +4655,9 @@ flatbuffers::Offset CreateConcatenationOptions(flatbuffers struct AddOptionsT : public flatbuffers::NativeTable { typedef AddOptions TableType; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; AddOptionsT() - : 
fused_activation_function(ActivationFunctionType_NONE) { + : fused_activation_function(tflite::ActivationFunctionType_NONE) { } }; @@ -4666,8 +4666,8 @@ struct AddOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_FUSED_ACTIVATION_FUNCTION = 4 }; - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4682,7 +4682,7 @@ struct AddOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct AddOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit AddOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -4699,7 +4699,7 @@ struct AddOptionsBuilder { inline flatbuffers::Offset CreateAddOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { AddOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); @@ -4709,9 +4709,9 @@ flatbuffers::Offset CreateAddOptions(flatbuffers::FlatBufferBuilder struct MulOptionsT : public flatbuffers::NativeTable { typedef MulOptions TableType; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; MulOptionsT() - : fused_activation_function(ActivationFunctionType_NONE) { + : fused_activation_function(tflite::ActivationFunctionType_NONE) { } }; @@ -4720,8 +4720,8 @@ struct MulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_FUSED_ACTIVATION_FUNCTION = 4 }; - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4736,7 +4736,7 @@ struct MulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct MulOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit MulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -4753,7 +4753,7 @@ struct MulOptionsBuilder { inline flatbuffers::Offset CreateMulOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { MulOptionsBuilder builder_(_fbb); 
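For context on how these fused-activation option tables are consumed, a minimal round-trip sketch that uses only helpers visible in this hunk plus standard FlatBuffers calls; the main() wrapper is illustrative.

#include <cstdio>
#include "tensorflow/lite/schema/schema_generated.h"

int main() {
  flatbuffers::FlatBufferBuilder fbb;
  // Serialize a MulOptions table with a fused RELU6 activation.
  fbb.Finish(tflite::CreateMulOptions(fbb, tflite::ActivationFunctionType_RELU6));
  const auto *opts =
      flatbuffers::GetRoot<tflite::MulOptions>(fbb.GetBufferPointer());
  // Prints "RELU6"; an out-of-range value would print "" via the new range guard.
  std::printf("%s\n", tflite::EnumNameActivationFunctionType(
                          opts->fused_activation_function()));
  return 0;
}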
 builder_.add_fused_activation_function(fused_activation_function);
 return builder_.Finish();
@@ -4763,9 +4763,9 @@ flatbuffers::Offset CreateMulOptions(flatbuffers::FlatBufferBuilder

 struct L2NormOptionsT : public flatbuffers::NativeTable {
   typedef L2NormOptions TableType;
-  ActivationFunctionType fused_activation_function;
+  tflite::ActivationFunctionType fused_activation_function;
   L2NormOptionsT()
-      : fused_activation_function(ActivationFunctionType_NONE) {
+      : fused_activation_function(tflite::ActivationFunctionType_NONE) {
   }
 };

@@ -4774,8 +4774,8 @@ struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
     VT_FUSED_ACTIVATION_FUNCTION = 4
   };
-  ActivationFunctionType fused_activation_function() const {
-    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
+  tflite::ActivationFunctionType fused_activation_function() const {
+    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
@@ -4790,7 +4790,7 @@ struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
 struct L2NormOptionsBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
-  void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+  void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
     fbb_.AddElement(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0);
   }
   explicit L2NormOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
@@ -4807,7 +4807,7 @@ struct L2NormOptionsBuilder {

 inline flatbuffers::Offset CreateL2NormOptions(
     flatbuffers::FlatBufferBuilder &_fbb,
-    ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+    tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
   L2NormOptionsBuilder builder_(_fbb);
   builder_.add_fused_activation_function(fused_activation_function);
   return builder_.Finish();
@@ -4907,15 +4907,15 @@ flatbuffers::Offset CreateLocalResponseNormal

 struct LSTMOptionsT : public flatbuffers::NativeTable {
   typedef LSTMOptions TableType;
-  ActivationFunctionType fused_activation_function;
+  tflite::ActivationFunctionType fused_activation_function;
   float cell_clip;
   float proj_clip;
-  LSTMKernelType kernel_type;
+  tflite::LSTMKernelType kernel_type;
   LSTMOptionsT()
-      : fused_activation_function(ActivationFunctionType_NONE),
+      : fused_activation_function(tflite::ActivationFunctionType_NONE),
         cell_clip(0.0f),
         proj_clip(0.0f),
-        kernel_type(LSTMKernelType_FULL) {
+        kernel_type(tflite::LSTMKernelType_FULL) {
   }
 };

@@ -4927,8 +4927,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     VT_PROJ_CLIP = 8,
     VT_KERNEL_TYPE = 10
   };
-  ActivationFunctionType fused_activation_function() const {
-    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
+  tflite::ActivationFunctionType fused_activation_function() const {
+    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
   }
   float cell_clip() const {
     return GetField(VT_CELL_CLIP, 0.0f);
@@ -4936,8 +4936,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   float proj_clip() const {
     return GetField(VT_PROJ_CLIP, 0.0f);
   }
-  LSTMKernelType kernel_type() const {
-    return static_cast(GetField(VT_KERNEL_TYPE, 0));
+  tflite::LSTMKernelType kernel_type() const {
+    return static_cast(GetField(VT_KERNEL_TYPE, 0));
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
@@ -4955,7 +4955,7 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
 struct LSTMOptionsBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
-  void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+  void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
     fbb_.AddElement(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0);
   }
   void add_cell_clip(float cell_clip) {
@@ -4964,7 +4964,7 @@ struct LSTMOptionsBuilder {
   void add_proj_clip(float proj_clip) {
     fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
   }
-  void add_kernel_type(LSTMKernelType kernel_type) {
+  void add_kernel_type(tflite::LSTMKernelType kernel_type) {
     fbb_.AddElement(LSTMOptions::VT_KERNEL_TYPE, static_cast(kernel_type), 0);
   }
   explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
@@ -4981,10 +4981,10 @@ struct LSTMOptionsBuilder {

 inline flatbuffers::Offset CreateLSTMOptions(
     flatbuffers::FlatBufferBuilder &_fbb,
-    ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+    tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
     float cell_clip = 0.0f,
     float proj_clip = 0.0f,
-    LSTMKernelType kernel_type = LSTMKernelType_FULL) {
+    tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL) {
   LSTMOptionsBuilder builder_(_fbb);
   builder_.add_proj_clip(proj_clip);
   builder_.add_cell_clip(cell_clip);
@@ -4997,12 +4997,12 @@ flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBufferBuilde

 struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
   typedef UnidirectionalSequenceLSTMOptions TableType;
-  ActivationFunctionType fused_activation_function;
+  tflite::ActivationFunctionType fused_activation_function;
   float cell_clip;
   float proj_clip;
   bool time_major;
   UnidirectionalSequenceLSTMOptionsT()
-      : fused_activation_function(ActivationFunctionType_NONE),
+      : fused_activation_function(tflite::ActivationFunctionType_NONE),
        cell_clip(0.0f),
        proj_clip(0.0f),
        time_major(false) {
@@ -5017,8 +5017,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
     VT_PROJ_CLIP = 8,
     VT_TIME_MAJOR = 10
   };
-  ActivationFunctionType fused_activation_function() const {
-    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
+  tflite::ActivationFunctionType fused_activation_function() const {
+    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
   }
   float cell_clip() const {
     return GetField(VT_CELL_CLIP, 0.0f);
@@ -5045,7 +5045,7 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
 struct UnidirectionalSequenceLSTMOptionsBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
-  void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+  void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
     fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0);
   }
   void add_cell_clip(float cell_clip) {
@@ -5071,7 +5071,7 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {

 inline flatbuffers::Offset CreateUnidirectionalSequenceLSTMOptions(
     flatbuffers::FlatBufferBuilder &_fbb,
-    ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+    tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
     float cell_clip = 0.0f,
     float proj_clip = 0.0f,
     bool time_major = false) {
@@ -5087,13 +5087,13 @@ flatbuffers::Offset CreateUnidirectionalSeque

 struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
   typedef BidirectionalSequenceLSTMOptions TableType;
-  ActivationFunctionType fused_activation_function;
+  tflite::ActivationFunctionType fused_activation_function;
   float cell_clip;
   float proj_clip;
   bool merge_outputs;
   bool time_major;
   BidirectionalSequenceLSTMOptionsT()
-      : fused_activation_function(ActivationFunctionType_NONE),
+      : fused_activation_function(tflite::ActivationFunctionType_NONE),
        cell_clip(0.0f),
        proj_clip(0.0f),
        merge_outputs(false),
@@ -5110,8 +5110,8 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
     VT_MERGE_OUTPUTS = 10,
     VT_TIME_MAJOR = 12
   };
-  ActivationFunctionType fused_activation_function() const {
-    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
+  tflite::ActivationFunctionType fused_activation_function() const {
+    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
   }
   float cell_clip() const {
     return GetField(VT_CELL_CLIP, 0.0f);
@@ -5142,7 +5142,7 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
 struct BidirectionalSequenceLSTMOptionsBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
-  void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+  void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
     fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0);
   }
   void add_cell_clip(float cell_clip) {
@@ -5171,7 +5171,7 @@ struct BidirectionalSequenceLSTMOptionsBuilder {

 inline flatbuffers::Offset CreateBidirectionalSequenceLSTMOptions(
     flatbuffers::FlatBufferBuilder &_fbb,
-    ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+    tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
     float cell_clip = 0.0f,
     float proj_clip = 0.0f,
     bool merge_outputs = false,
@@ -5772,9 +5772,9 @@ flatbuffers::Offset CreateDepthToSpaceOptions(flatbuffers::

 struct SubOptionsT : public flatbuffers::NativeTable {
   typedef SubOptions TableType;
-  ActivationFunctionType fused_activation_function;
+  tflite::ActivationFunctionType fused_activation_function;
   SubOptionsT()
-      : fused_activation_function(ActivationFunctionType_NONE) {
+      : fused_activation_function(tflite::ActivationFunctionType_NONE) {
   }
 };

@@ -5783,8 +5783,8 @@ struct SubOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
     VT_FUSED_ACTIVATION_FUNCTION = 4
   };
-  ActivationFunctionType fused_activation_function() const {
-    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
+  tflite::ActivationFunctionType fused_activation_function() const {
+    return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0));
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
@@ -5799,7 +5799,7 @@ struct SubOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
 struct SubOptionsBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
-  void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+  void add_fused_activation_function(tflite::ActivationFunctionType
fused_activation_function) { fbb_.AddElement(SubOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SubOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -5816,7 +5816,7 @@ struct SubOptionsBuilder { inline flatbuffers::Offset CreateSubOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { SubOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); @@ -5826,9 +5826,9 @@ flatbuffers::Offset CreateSubOptions(flatbuffers::FlatBufferBuilder struct DivOptionsT : public flatbuffers::NativeTable { typedef DivOptions TableType; - ActivationFunctionType fused_activation_function; + tflite::ActivationFunctionType fused_activation_function; DivOptionsT() - : fused_activation_function(ActivationFunctionType_NONE) { + : fused_activation_function(tflite::ActivationFunctionType_NONE) { } }; @@ -5837,8 +5837,8 @@ struct DivOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_FUSED_ACTIVATION_FUNCTION = 4 }; - ActivationFunctionType fused_activation_function() const { - return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -5853,7 +5853,7 @@ struct DivOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct DivOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(DivOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit DivOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -5870,7 +5870,7 @@ struct DivOptionsBuilder { inline flatbuffers::Offset CreateDivOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { DivOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); @@ -5920,9 +5920,9 @@ flatbuffers::Offset CreateTopKV2Options(flatbuffers::FlatBufferBu struct EmbeddingLookupSparseOptionsT : public flatbuffers::NativeTable { typedef EmbeddingLookupSparseOptions TableType; - CombinerType combiner; + tflite::CombinerType combiner; EmbeddingLookupSparseOptionsT() - : combiner(CombinerType_SUM) { + : combiner(tflite::CombinerType_SUM) { } }; @@ -5931,8 +5931,8 @@ struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffer enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_COMBINER = 4 }; - CombinerType combiner() const { - return static_cast(GetField(VT_COMBINER, 0)); + tflite::CombinerType combiner() const { + return static_cast(GetField(VT_COMBINER, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -5947,7 +5947,7 @@ struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffer struct 
EmbeddingLookupSparseOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_combiner(CombinerType combiner) { + void add_combiner(tflite::CombinerType combiner) { fbb_.AddElement(EmbeddingLookupSparseOptions::VT_COMBINER, static_cast(combiner), 0); } explicit EmbeddingLookupSparseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -5964,7 +5964,7 @@ struct EmbeddingLookupSparseOptionsBuilder { inline flatbuffers::Offset CreateEmbeddingLookupSparseOptions( flatbuffers::FlatBufferBuilder &_fbb, - CombinerType combiner = CombinerType_SUM) { + tflite::CombinerType combiner = tflite::CombinerType_SUM) { EmbeddingLookupSparseOptionsBuilder builder_(_fbb); builder_.add_combiner(combiner); return builder_.Finish(); @@ -6515,11 +6515,11 @@ flatbuffers::Offset CreateLogSoftmaxOptions(flatbuffers::Flat struct CastOptionsT : public flatbuffers::NativeTable { typedef CastOptions TableType; - TensorType in_data_type; - TensorType out_data_type; + tflite::TensorType in_data_type; + tflite::TensorType out_data_type; CastOptionsT() - : in_data_type(TensorType_FLOAT32), - out_data_type(TensorType_FLOAT32) { + : in_data_type(tflite::TensorType_FLOAT32), + out_data_type(tflite::TensorType_FLOAT32) { } }; @@ -6529,11 +6529,11 @@ struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_IN_DATA_TYPE = 4, VT_OUT_DATA_TYPE = 6 }; - TensorType in_data_type() const { - return static_cast(GetField(VT_IN_DATA_TYPE, 0)); + tflite::TensorType in_data_type() const { + return static_cast(GetField(VT_IN_DATA_TYPE, 0)); } - TensorType out_data_type() const { - return static_cast(GetField(VT_OUT_DATA_TYPE, 0)); + tflite::TensorType out_data_type() const { + return static_cast(GetField(VT_OUT_DATA_TYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -6549,10 +6549,10 @@ struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct CastOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_in_data_type(TensorType in_data_type) { + void add_in_data_type(tflite::TensorType in_data_type) { fbb_.AddElement(CastOptions::VT_IN_DATA_TYPE, static_cast(in_data_type), 0); } - void add_out_data_type(TensorType out_data_type) { + void add_out_data_type(tflite::TensorType out_data_type) { fbb_.AddElement(CastOptions::VT_OUT_DATA_TYPE, static_cast(out_data_type), 0); } explicit CastOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -6569,8 +6569,8 @@ struct CastOptionsBuilder { inline flatbuffers::Offset CreateCastOptions( flatbuffers::FlatBufferBuilder &_fbb, - TensorType in_data_type = TensorType_FLOAT32, - TensorType out_data_type = TensorType_FLOAT32) { + tflite::TensorType in_data_type = tflite::TensorType_FLOAT32, + tflite::TensorType out_data_type = tflite::TensorType_FLOAT32) { CastOptionsBuilder builder_(_fbb); builder_.add_out_data_type(out_data_type); builder_.add_in_data_type(in_data_type); @@ -6701,9 +6701,9 @@ flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilde struct ArgMaxOptionsT : public flatbuffers::NativeTable { typedef ArgMaxOptions TableType; - TensorType output_type; + tflite::TensorType output_type; ArgMaxOptionsT() - : output_type(TensorType_FLOAT32) { + : output_type(tflite::TensorType_FLOAT32) { } }; @@ -6712,8 +6712,8 @@ struct ArgMaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_OUTPUT_TYPE = 4 }; - TensorType output_type() const { 
- return static_cast(GetField(VT_OUTPUT_TYPE, 0)); + tflite::TensorType output_type() const { + return static_cast(GetField(VT_OUTPUT_TYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -6728,7 +6728,7 @@ struct ArgMaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct ArgMaxOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_output_type(TensorType output_type) { + void add_output_type(tflite::TensorType output_type) { fbb_.AddElement(ArgMaxOptions::VT_OUTPUT_TYPE, static_cast(output_type), 0); } explicit ArgMaxOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -6745,7 +6745,7 @@ struct ArgMaxOptionsBuilder { inline flatbuffers::Offset CreateArgMaxOptions( flatbuffers::FlatBufferBuilder &_fbb, - TensorType output_type = TensorType_FLOAT32) { + tflite::TensorType output_type = tflite::TensorType_FLOAT32) { ArgMaxOptionsBuilder builder_(_fbb); builder_.add_output_type(output_type); return builder_.Finish(); @@ -6755,9 +6755,9 @@ flatbuffers::Offset CreateArgMaxOptions(flatbuffers::FlatBufferBu struct ArgMinOptionsT : public flatbuffers::NativeTable { typedef ArgMinOptions TableType; - TensorType output_type; + tflite::TensorType output_type; ArgMinOptionsT() - : output_type(TensorType_FLOAT32) { + : output_type(tflite::TensorType_FLOAT32) { } }; @@ -6766,8 +6766,8 @@ struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_OUTPUT_TYPE = 4 }; - TensorType output_type() const { - return static_cast(GetField(VT_OUTPUT_TYPE, 0)); + tflite::TensorType output_type() const { + return static_cast(GetField(VT_OUTPUT_TYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -6782,7 +6782,7 @@ struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct ArgMinOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_output_type(TensorType output_type) { + void add_output_type(tflite::TensorType output_type) { fbb_.AddElement(ArgMinOptions::VT_OUTPUT_TYPE, static_cast(output_type), 0); } explicit ArgMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -6799,7 +6799,7 @@ struct ArgMinOptionsBuilder { inline flatbuffers::Offset CreateArgMinOptions( flatbuffers::FlatBufferBuilder &_fbb, - TensorType output_type = TensorType_FLOAT32) { + tflite::TensorType output_type = tflite::TensorType_FLOAT32) { ArgMinOptionsBuilder builder_(_fbb); builder_.add_output_type(output_type); return builder_.Finish(); @@ -7089,11 +7089,11 @@ flatbuffers::Offset CreateSliceOptions(flatbuffers::FlatBufferBuil struct TransposeConvOptionsT : public flatbuffers::NativeTable { typedef TransposeConvOptions TableType; - Padding padding; + tflite::Padding padding; int32_t stride_w; int32_t stride_h; TransposeConvOptionsT() - : padding(Padding_SAME), + : padding(tflite::Padding_SAME), stride_w(0), stride_h(0) { } @@ -7106,8 +7106,8 @@ struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table VT_STRIDE_W = 6, VT_STRIDE_H = 8 }; - Padding padding() const { - return static_cast(GetField(VT_PADDING, 0)); + tflite::Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); } int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); @@ -7130,7 +7130,7 @@ struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table struct TransposeConvOptionsBuilder { 
flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_padding(Padding padding) { + void add_padding(tflite::Padding padding) { fbb_.AddElement(TransposeConvOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { @@ -7153,7 +7153,7 @@ struct TransposeConvOptionsBuilder { inline flatbuffers::Offset CreateTransposeConvOptions( flatbuffers::FlatBufferBuilder &_fbb, - Padding padding = Padding_SAME, + tflite::Padding padding = tflite::Padding_SAME, int32_t stride_w = 0, int32_t stride_h = 0) { TransposeConvOptionsBuilder builder_(_fbb); @@ -7341,9 +7341,9 @@ flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBuff struct ShapeOptionsT : public flatbuffers::NativeTable { typedef ShapeOptions TableType; - TensorType out_type; + tflite::TensorType out_type; ShapeOptionsT() - : out_type(TensorType_FLOAT32) { + : out_type(tflite::TensorType_FLOAT32) { } }; @@ -7352,8 +7352,8 @@ struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_OUT_TYPE = 4 }; - TensorType out_type() const { - return static_cast(GetField(VT_OUT_TYPE, 0)); + tflite::TensorType out_type() const { + return static_cast(GetField(VT_OUT_TYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -7368,7 +7368,7 @@ struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct ShapeOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_out_type(TensorType out_type) { + void add_out_type(tflite::TensorType out_type) { fbb_.AddElement(ShapeOptions::VT_OUT_TYPE, static_cast(out_type), 0); } explicit ShapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -7385,7 +7385,7 @@ struct ShapeOptionsBuilder { inline flatbuffers::Offset CreateShapeOptions( flatbuffers::FlatBufferBuilder &_fbb, - TensorType out_type = TensorType_FLOAT32) { + tflite::TensorType out_type = tflite::TensorType_FLOAT32) { ShapeOptionsBuilder builder_(_fbb); builder_.add_out_type(out_type); return builder_.Finish(); @@ -8285,9 +8285,9 @@ flatbuffers::Offset CreateSquaredDifferenceOptions(fla struct MirrorPadOptionsT : public flatbuffers::NativeTable { typedef MirrorPadOptions TableType; - MirrorPadMode mode; + tflite::MirrorPadMode mode; MirrorPadOptionsT() - : mode(MirrorPadMode_REFLECT) { + : mode(tflite::MirrorPadMode_REFLECT) { } }; @@ -8296,8 +8296,8 @@ struct MirrorPadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_MODE = 4 }; - MirrorPadMode mode() const { - return static_cast(GetField(VT_MODE, 0)); + tflite::MirrorPadMode mode() const { + return static_cast(GetField(VT_MODE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -8312,7 +8312,7 @@ struct MirrorPadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct MirrorPadOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_mode(MirrorPadMode mode) { + void add_mode(tflite::MirrorPadMode mode) { fbb_.AddElement(MirrorPadOptions::VT_MODE, static_cast(mode), 0); } explicit MirrorPadOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -8329,7 +8329,7 @@ struct MirrorPadOptionsBuilder { inline flatbuffers::Offset CreateMirrorPadOptions( flatbuffers::FlatBufferBuilder &_fbb, - MirrorPadMode mode = MirrorPadMode_REFLECT) { + tflite::MirrorPadMode mode = 
tflite::MirrorPadMode_REFLECT) { MirrorPadOptionsBuilder builder_(_fbb); builder_.add_mode(mode); return builder_.Finish(); @@ -8339,9 +8339,9 @@ flatbuffers::Offset CreateMirrorPadOptions(flatbuffers::FlatBu struct UniqueOptionsT : public flatbuffers::NativeTable { typedef UniqueOptions TableType; - TensorType idx_out_type; + tflite::TensorType idx_out_type; UniqueOptionsT() - : idx_out_type(TensorType_INT32) { + : idx_out_type(tflite::TensorType_INT32) { } }; @@ -8350,8 +8350,8 @@ struct UniqueOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_IDX_OUT_TYPE = 4 }; - TensorType idx_out_type() const { - return static_cast(GetField(VT_IDX_OUT_TYPE, 2)); + tflite::TensorType idx_out_type() const { + return static_cast(GetField(VT_IDX_OUT_TYPE, 2)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -8366,7 +8366,7 @@ struct UniqueOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct UniqueOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_idx_out_type(TensorType idx_out_type) { + void add_idx_out_type(tflite::TensorType idx_out_type) { fbb_.AddElement(UniqueOptions::VT_IDX_OUT_TYPE, static_cast(idx_out_type), 2); } explicit UniqueOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -8383,7 +8383,7 @@ struct UniqueOptionsBuilder { inline flatbuffers::Offset CreateUniqueOptions( flatbuffers::FlatBufferBuilder &_fbb, - TensorType idx_out_type = TensorType_INT32) { + tflite::TensorType idx_out_type = tflite::TensorType_INT32) { UniqueOptionsBuilder builder_(_fbb); builder_.add_idx_out_type(idx_out_type); return builder_.Finish(); @@ -9111,11 +9111,11 @@ flatbuffers::Offset CreateSegmentSumOptions(flatbuffers::Flat struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; - BuiltinOperator builtin_code; + tflite::BuiltinOperator builtin_code; std::string custom_code; int32_t version; OperatorCodeT() - : builtin_code(BuiltinOperator_ADD), + : builtin_code(tflite::BuiltinOperator_ADD), version(1) { } }; @@ -9127,8 +9127,8 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_CUSTOM_CODE = 6, VT_VERSION = 8 }; - BuiltinOperator builtin_code() const { - return static_cast(GetField(VT_BUILTIN_CODE, 0)); + tflite::BuiltinOperator builtin_code() const { + return static_cast(GetField(VT_BUILTIN_CODE, 0)); } const flatbuffers::String *custom_code() const { return GetPointer(VT_CUSTOM_CODE); @@ -9152,7 +9152,7 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct OperatorCodeBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_builtin_code(BuiltinOperator builtin_code) { + void add_builtin_code(tflite::BuiltinOperator builtin_code) { fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); } void add_custom_code(flatbuffers::Offset custom_code) { @@ -9175,7 +9175,7 @@ struct OperatorCodeBuilder { inline flatbuffers::Offset CreateOperatorCode( flatbuffers::FlatBufferBuilder &_fbb, - BuiltinOperator builtin_code = BuiltinOperator_ADD, + tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD, flatbuffers::Offset custom_code = 0, int32_t version = 1) { OperatorCodeBuilder builder_(_fbb); @@ -9187,7 +9187,7 @@ inline flatbuffers::Offset CreateOperatorCode( inline flatbuffers::Offset CreateOperatorCodeDirect( flatbuffers::FlatBufferBuilder &_fbb, - BuiltinOperator 
builtin_code = BuiltinOperator_ADD, + tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD, const char *custom_code = nullptr, int32_t version = 1) { auto custom_code__ = custom_code ? _fbb.CreateString(custom_code) : 0; @@ -9205,14 +9205,14 @@ struct OperatorT : public flatbuffers::NativeTable { uint32_t opcode_index; std::vector inputs; std::vector outputs; - BuiltinOptionsUnion builtin_options; + tflite::BuiltinOptionsUnion builtin_options; std::vector custom_options; - CustomOptionsFormat custom_options_format; + tflite::CustomOptionsFormat custom_options_format; std::vector mutating_variable_inputs; std::vector intermediates; OperatorT() : opcode_index(0), - custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { + custom_options_format(tflite::CustomOptionsFormat_FLEXBUFFERS) { } }; @@ -9238,318 +9238,318 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *outputs() const { return GetPointer *>(VT_OUTPUTS); } - BuiltinOptions builtin_options_type() const { - return static_cast(GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); + tflite::BuiltinOptions builtin_options_type() const { + return static_cast(GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); } const void *builtin_options() const { return GetPointer(VT_BUILTIN_OPTIONS); } template const T *builtin_options_as() const; - const Conv2DOptions *builtin_options_as_Conv2DOptions() const { - return builtin_options_type() == BuiltinOptions_Conv2DOptions ? static_cast(builtin_options()) : nullptr; + const tflite::Conv2DOptions *builtin_options_as_Conv2DOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_Conv2DOptions ? static_cast(builtin_options()) : nullptr; } - const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const { - return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions ? static_cast(builtin_options()) : nullptr; + const tflite::DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DepthwiseConv2DOptions ? static_cast(builtin_options()) : nullptr; } - const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const { - return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ConcatEmbeddingsOptions ? static_cast(builtin_options()) : nullptr; } - const LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { - return builtin_options_type() == BuiltinOptions_LSHProjectionOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LSHProjectionOptions ? static_cast(builtin_options()) : nullptr; } - const Pool2DOptions *builtin_options_as_Pool2DOptions() const { - return builtin_options_type() == BuiltinOptions_Pool2DOptions ? static_cast(builtin_options()) : nullptr; + const tflite::Pool2DOptions *builtin_options_as_Pool2DOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_Pool2DOptions ? static_cast(builtin_options()) : nullptr; } - const SVDFOptions *builtin_options_as_SVDFOptions() const { - return builtin_options_type() == BuiltinOptions_SVDFOptions ? 
static_cast(builtin_options()) : nullptr; + const tflite::SVDFOptions *builtin_options_as_SVDFOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SVDFOptions ? static_cast(builtin_options()) : nullptr; } - const RNNOptions *builtin_options_as_RNNOptions() const { - return builtin_options_type() == BuiltinOptions_RNNOptions ? static_cast(builtin_options()) : nullptr; + const tflite::RNNOptions *builtin_options_as_RNNOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_RNNOptions ? static_cast(builtin_options()) : nullptr; } - const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const { - return builtin_options_type() == BuiltinOptions_FullyConnectedOptions ? static_cast(builtin_options()) : nullptr; + const tflite::FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FullyConnectedOptions ? static_cast(builtin_options()) : nullptr; } - const SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { - return builtin_options_type() == BuiltinOptions_SoftmaxOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SoftmaxOptions ? static_cast(builtin_options()) : nullptr; } - const ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { - return builtin_options_type() == BuiltinOptions_ConcatenationOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ConcatenationOptions ? static_cast(builtin_options()) : nullptr; } - const AddOptions *builtin_options_as_AddOptions() const { - return builtin_options_type() == BuiltinOptions_AddOptions ? static_cast(builtin_options()) : nullptr; + const tflite::AddOptions *builtin_options_as_AddOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_AddOptions ? static_cast(builtin_options()) : nullptr; } - const L2NormOptions *builtin_options_as_L2NormOptions() const { - return builtin_options_type() == BuiltinOptions_L2NormOptions ? static_cast(builtin_options()) : nullptr; + const tflite::L2NormOptions *builtin_options_as_L2NormOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_L2NormOptions ? static_cast(builtin_options()) : nullptr; } - const LocalResponseNormalizationOptions *builtin_options_as_LocalResponseNormalizationOptions() const { - return builtin_options_type() == BuiltinOptions_LocalResponseNormalizationOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LocalResponseNormalizationOptions *builtin_options_as_LocalResponseNormalizationOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LocalResponseNormalizationOptions ? static_cast(builtin_options()) : nullptr; } - const LSTMOptions *builtin_options_as_LSTMOptions() const { - return builtin_options_type() == BuiltinOptions_LSTMOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LSTMOptions *builtin_options_as_LSTMOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LSTMOptions ? static_cast(builtin_options()) : nullptr; } - const ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() const { - return builtin_options_type() == BuiltinOptions_ResizeBilinearOptions ? 
static_cast(builtin_options()) : nullptr; + const tflite::ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ResizeBilinearOptions ? static_cast(builtin_options()) : nullptr; } - const CallOptions *builtin_options_as_CallOptions() const { - return builtin_options_type() == BuiltinOptions_CallOptions ? static_cast(builtin_options()) : nullptr; + const tflite::CallOptions *builtin_options_as_CallOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CallOptions ? static_cast(builtin_options()) : nullptr; } - const ReshapeOptions *builtin_options_as_ReshapeOptions() const { - return builtin_options_type() == BuiltinOptions_ReshapeOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ReshapeOptions *builtin_options_as_ReshapeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ReshapeOptions ? static_cast(builtin_options()) : nullptr; } - const SkipGramOptions *builtin_options_as_SkipGramOptions() const { - return builtin_options_type() == BuiltinOptions_SkipGramOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SkipGramOptions *builtin_options_as_SkipGramOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SkipGramOptions ? static_cast(builtin_options()) : nullptr; } - const SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const { - return builtin_options_type() == BuiltinOptions_SpaceToDepthOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SpaceToDepthOptions ? static_cast(builtin_options()) : nullptr; } - const EmbeddingLookupSparseOptions *builtin_options_as_EmbeddingLookupSparseOptions() const { - return builtin_options_type() == BuiltinOptions_EmbeddingLookupSparseOptions ? static_cast(builtin_options()) : nullptr; + const tflite::EmbeddingLookupSparseOptions *builtin_options_as_EmbeddingLookupSparseOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_EmbeddingLookupSparseOptions ? static_cast(builtin_options()) : nullptr; } - const MulOptions *builtin_options_as_MulOptions() const { - return builtin_options_type() == BuiltinOptions_MulOptions ? static_cast(builtin_options()) : nullptr; + const tflite::MulOptions *builtin_options_as_MulOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MulOptions ? static_cast(builtin_options()) : nullptr; } - const PadOptions *builtin_options_as_PadOptions() const { - return builtin_options_type() == BuiltinOptions_PadOptions ? static_cast(builtin_options()) : nullptr; + const tflite::PadOptions *builtin_options_as_PadOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_PadOptions ? static_cast(builtin_options()) : nullptr; } - const GatherOptions *builtin_options_as_GatherOptions() const { - return builtin_options_type() == BuiltinOptions_GatherOptions ? static_cast(builtin_options()) : nullptr; + const tflite::GatherOptions *builtin_options_as_GatherOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GatherOptions ? static_cast(builtin_options()) : nullptr; } - const BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() const { - return builtin_options_type() == BuiltinOptions_BatchToSpaceNDOptions ? 
static_cast(builtin_options()) : nullptr; + const tflite::BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BatchToSpaceNDOptions ? static_cast(builtin_options()) : nullptr; } - const SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() const { - return builtin_options_type() == BuiltinOptions_SpaceToBatchNDOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SpaceToBatchNDOptions ? static_cast(builtin_options()) : nullptr; } - const TransposeOptions *builtin_options_as_TransposeOptions() const { - return builtin_options_type() == BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; + const tflite::TransposeOptions *builtin_options_as_TransposeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; } - const ReducerOptions *builtin_options_as_ReducerOptions() const { - return builtin_options_type() == BuiltinOptions_ReducerOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ReducerOptions *builtin_options_as_ReducerOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ReducerOptions ? static_cast(builtin_options()) : nullptr; } - const SubOptions *builtin_options_as_SubOptions() const { - return builtin_options_type() == BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SubOptions *builtin_options_as_SubOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; } - const DivOptions *builtin_options_as_DivOptions() const { - return builtin_options_type() == BuiltinOptions_DivOptions ? static_cast(builtin_options()) : nullptr; + const tflite::DivOptions *builtin_options_as_DivOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DivOptions ? static_cast(builtin_options()) : nullptr; } - const SqueezeOptions *builtin_options_as_SqueezeOptions() const { - return builtin_options_type() == BuiltinOptions_SqueezeOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SqueezeOptions *builtin_options_as_SqueezeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SqueezeOptions ? static_cast(builtin_options()) : nullptr; } - const SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const { - return builtin_options_type() == BuiltinOptions_SequenceRNNOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SequenceRNNOptions ? static_cast(builtin_options()) : nullptr; } - const StridedSliceOptions *builtin_options_as_StridedSliceOptions() const { - return builtin_options_type() == BuiltinOptions_StridedSliceOptions ? static_cast(builtin_options()) : nullptr; + const tflite::StridedSliceOptions *builtin_options_as_StridedSliceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_StridedSliceOptions ? static_cast(builtin_options()) : nullptr; } - const ExpOptions *builtin_options_as_ExpOptions() const { - return builtin_options_type() == BuiltinOptions_ExpOptions ? 
static_cast(builtin_options()) : nullptr; + const tflite::ExpOptions *builtin_options_as_ExpOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ExpOptions ? static_cast(builtin_options()) : nullptr; } - const TopKV2Options *builtin_options_as_TopKV2Options() const { - return builtin_options_type() == BuiltinOptions_TopKV2Options ? static_cast(builtin_options()) : nullptr; + const tflite::TopKV2Options *builtin_options_as_TopKV2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_TopKV2Options ? static_cast(builtin_options()) : nullptr; } - const SplitOptions *builtin_options_as_SplitOptions() const { - return builtin_options_type() == BuiltinOptions_SplitOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SplitOptions *builtin_options_as_SplitOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SplitOptions ? static_cast(builtin_options()) : nullptr; } - const LogSoftmaxOptions *builtin_options_as_LogSoftmaxOptions() const { - return builtin_options_type() == BuiltinOptions_LogSoftmaxOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LogSoftmaxOptions *builtin_options_as_LogSoftmaxOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LogSoftmaxOptions ? static_cast(builtin_options()) : nullptr; } - const CastOptions *builtin_options_as_CastOptions() const { - return builtin_options_type() == BuiltinOptions_CastOptions ? static_cast(builtin_options()) : nullptr; + const tflite::CastOptions *builtin_options_as_CastOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CastOptions ? static_cast(builtin_options()) : nullptr; } - const DequantizeOptions *builtin_options_as_DequantizeOptions() const { - return builtin_options_type() == BuiltinOptions_DequantizeOptions ? static_cast(builtin_options()) : nullptr; + const tflite::DequantizeOptions *builtin_options_as_DequantizeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DequantizeOptions ? static_cast(builtin_options()) : nullptr; } - const MaximumMinimumOptions *builtin_options_as_MaximumMinimumOptions() const { - return builtin_options_type() == BuiltinOptions_MaximumMinimumOptions ? static_cast(builtin_options()) : nullptr; + const tflite::MaximumMinimumOptions *builtin_options_as_MaximumMinimumOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MaximumMinimumOptions ? static_cast(builtin_options()) : nullptr; } - const ArgMaxOptions *builtin_options_as_ArgMaxOptions() const { - return builtin_options_type() == BuiltinOptions_ArgMaxOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ArgMaxOptions *builtin_options_as_ArgMaxOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ArgMaxOptions ? static_cast(builtin_options()) : nullptr; } - const LessOptions *builtin_options_as_LessOptions() const { - return builtin_options_type() == BuiltinOptions_LessOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LessOptions *builtin_options_as_LessOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LessOptions ? static_cast(builtin_options()) : nullptr; } - const NegOptions *builtin_options_as_NegOptions() const { - return builtin_options_type() == BuiltinOptions_NegOptions ? static_cast(builtin_options()) : nullptr; + const tflite::NegOptions *builtin_options_as_NegOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_NegOptions ? 
static_cast(builtin_options()) : nullptr; } - const PadV2Options *builtin_options_as_PadV2Options() const { - return builtin_options_type() == BuiltinOptions_PadV2Options ? static_cast(builtin_options()) : nullptr; + const tflite::PadV2Options *builtin_options_as_PadV2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_PadV2Options ? static_cast(builtin_options()) : nullptr; } - const GreaterOptions *builtin_options_as_GreaterOptions() const { - return builtin_options_type() == BuiltinOptions_GreaterOptions ? static_cast(builtin_options()) : nullptr; + const tflite::GreaterOptions *builtin_options_as_GreaterOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GreaterOptions ? static_cast(builtin_options()) : nullptr; } - const GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const { - return builtin_options_type() == BuiltinOptions_GreaterEqualOptions ? static_cast(builtin_options()) : nullptr; + const tflite::GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GreaterEqualOptions ? static_cast(builtin_options()) : nullptr; } - const LessEqualOptions *builtin_options_as_LessEqualOptions() const { - return builtin_options_type() == BuiltinOptions_LessEqualOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LessEqualOptions *builtin_options_as_LessEqualOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LessEqualOptions ? static_cast(builtin_options()) : nullptr; } - const SelectOptions *builtin_options_as_SelectOptions() const { - return builtin_options_type() == BuiltinOptions_SelectOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SelectOptions *builtin_options_as_SelectOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SelectOptions ? static_cast(builtin_options()) : nullptr; } - const SliceOptions *builtin_options_as_SliceOptions() const { - return builtin_options_type() == BuiltinOptions_SliceOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SliceOptions *builtin_options_as_SliceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SliceOptions ? static_cast(builtin_options()) : nullptr; } - const TransposeConvOptions *builtin_options_as_TransposeConvOptions() const { - return builtin_options_type() == BuiltinOptions_TransposeConvOptions ? static_cast(builtin_options()) : nullptr; + const tflite::TransposeConvOptions *builtin_options_as_TransposeConvOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_TransposeConvOptions ? static_cast(builtin_options()) : nullptr; } - const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { - return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SparseToDenseOptions ? static_cast(builtin_options()) : nullptr; } - const TileOptions *builtin_options_as_TileOptions() const { - return builtin_options_type() == BuiltinOptions_TileOptions ? static_cast(builtin_options()) : nullptr; + const tflite::TileOptions *builtin_options_as_TileOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_TileOptions ? 
static_cast(builtin_options()) : nullptr; } - const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { - return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ExpandDimsOptions ? static_cast(builtin_options()) : nullptr; } - const EqualOptions *builtin_options_as_EqualOptions() const { - return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast(builtin_options()) : nullptr; + const tflite::EqualOptions *builtin_options_as_EqualOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_EqualOptions ? static_cast(builtin_options()) : nullptr; } - const NotEqualOptions *builtin_options_as_NotEqualOptions() const { - return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; + const tflite::NotEqualOptions *builtin_options_as_NotEqualOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; } - const ShapeOptions *builtin_options_as_ShapeOptions() const { - return builtin_options_type() == BuiltinOptions_ShapeOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ShapeOptions *builtin_options_as_ShapeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ShapeOptions ? static_cast(builtin_options()) : nullptr; } - const PowOptions *builtin_options_as_PowOptions() const { - return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast(builtin_options()) : nullptr; + const tflite::PowOptions *builtin_options_as_PowOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_PowOptions ? static_cast(builtin_options()) : nullptr; } - const ArgMinOptions *builtin_options_as_ArgMinOptions() const { - return builtin_options_type() == BuiltinOptions_ArgMinOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ArgMinOptions *builtin_options_as_ArgMinOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ArgMinOptions ? static_cast(builtin_options()) : nullptr; } - const FakeQuantOptions *builtin_options_as_FakeQuantOptions() const { - return builtin_options_type() == BuiltinOptions_FakeQuantOptions ? static_cast(builtin_options()) : nullptr; + const tflite::FakeQuantOptions *builtin_options_as_FakeQuantOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FakeQuantOptions ? static_cast(builtin_options()) : nullptr; } - const PackOptions *builtin_options_as_PackOptions() const { - return builtin_options_type() == BuiltinOptions_PackOptions ? static_cast(builtin_options()) : nullptr; + const tflite::PackOptions *builtin_options_as_PackOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_PackOptions ? static_cast(builtin_options()) : nullptr; } - const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const { - return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LogicalOrOptions *builtin_options_as_LogicalOrOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LogicalOrOptions ? static_cast(builtin_options()) : nullptr; } - const OneHotOptions *builtin_options_as_OneHotOptions() const { - return builtin_options_type() == BuiltinOptions_OneHotOptions ? 
static_cast(builtin_options()) : nullptr; + const tflite::OneHotOptions *builtin_options_as_OneHotOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_OneHotOptions ? static_cast(builtin_options()) : nullptr; } - const LogicalAndOptions *builtin_options_as_LogicalAndOptions() const { - return builtin_options_type() == BuiltinOptions_LogicalAndOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LogicalAndOptions *builtin_options_as_LogicalAndOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LogicalAndOptions ? static_cast(builtin_options()) : nullptr; } - const LogicalNotOptions *builtin_options_as_LogicalNotOptions() const { - return builtin_options_type() == BuiltinOptions_LogicalNotOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LogicalNotOptions *builtin_options_as_LogicalNotOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LogicalNotOptions ? static_cast(builtin_options()) : nullptr; } - const UnpackOptions *builtin_options_as_UnpackOptions() const { - return builtin_options_type() == BuiltinOptions_UnpackOptions ? static_cast(builtin_options()) : nullptr; + const tflite::UnpackOptions *builtin_options_as_UnpackOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnpackOptions ? static_cast(builtin_options()) : nullptr; } - const FloorDivOptions *builtin_options_as_FloorDivOptions() const { - return builtin_options_type() == BuiltinOptions_FloorDivOptions ? static_cast(builtin_options()) : nullptr; + const tflite::FloorDivOptions *builtin_options_as_FloorDivOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FloorDivOptions ? static_cast(builtin_options()) : nullptr; } - const SquareOptions *builtin_options_as_SquareOptions() const { - return builtin_options_type() == BuiltinOptions_SquareOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SquareOptions *builtin_options_as_SquareOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SquareOptions ? static_cast(builtin_options()) : nullptr; } - const ZerosLikeOptions *builtin_options_as_ZerosLikeOptions() const { - return builtin_options_type() == BuiltinOptions_ZerosLikeOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ZerosLikeOptions *builtin_options_as_ZerosLikeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ZerosLikeOptions ? static_cast(builtin_options()) : nullptr; } - const FillOptions *builtin_options_as_FillOptions() const { - return builtin_options_type() == BuiltinOptions_FillOptions ? static_cast(builtin_options()) : nullptr; + const tflite::FillOptions *builtin_options_as_FillOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FillOptions ? static_cast(builtin_options()) : nullptr; } - const BidirectionalSequenceLSTMOptions *builtin_options_as_BidirectionalSequenceLSTMOptions() const { - return builtin_options_type() == BuiltinOptions_BidirectionalSequenceLSTMOptions ? static_cast(builtin_options()) : nullptr; + const tflite::BidirectionalSequenceLSTMOptions *builtin_options_as_BidirectionalSequenceLSTMOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions ? static_cast(builtin_options()) : nullptr; } - const BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const { - return builtin_options_type() == BuiltinOptions_BidirectionalSequenceRNNOptions ? 
static_cast(builtin_options()) : nullptr; + const tflite::BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast(builtin_options()) : nullptr; } - const UnidirectionalSequenceLSTMOptions *builtin_options_as_UnidirectionalSequenceLSTMOptions() const { - return builtin_options_type() == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? static_cast(builtin_options()) : nullptr; + const tflite::UnidirectionalSequenceLSTMOptions *builtin_options_as_UnidirectionalSequenceLSTMOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions ? static_cast(builtin_options()) : nullptr; } - const FloorModOptions *builtin_options_as_FloorModOptions() const { - return builtin_options_type() == BuiltinOptions_FloorModOptions ? static_cast(builtin_options()) : nullptr; + const tflite::FloorModOptions *builtin_options_as_FloorModOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FloorModOptions ? static_cast(builtin_options()) : nullptr; } - const RangeOptions *builtin_options_as_RangeOptions() const { - return builtin_options_type() == BuiltinOptions_RangeOptions ? static_cast(builtin_options()) : nullptr; + const tflite::RangeOptions *builtin_options_as_RangeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_RangeOptions ? static_cast(builtin_options()) : nullptr; } - const ResizeNearestNeighborOptions *builtin_options_as_ResizeNearestNeighborOptions() const { - return builtin_options_type() == BuiltinOptions_ResizeNearestNeighborOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ResizeNearestNeighborOptions *builtin_options_as_ResizeNearestNeighborOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ResizeNearestNeighborOptions ? static_cast(builtin_options()) : nullptr; } - const LeakyReluOptions *builtin_options_as_LeakyReluOptions() const { - return builtin_options_type() == BuiltinOptions_LeakyReluOptions ? static_cast(builtin_options()) : nullptr; + const tflite::LeakyReluOptions *builtin_options_as_LeakyReluOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LeakyReluOptions ? static_cast(builtin_options()) : nullptr; } - const SquaredDifferenceOptions *builtin_options_as_SquaredDifferenceOptions() const { - return builtin_options_type() == BuiltinOptions_SquaredDifferenceOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SquaredDifferenceOptions *builtin_options_as_SquaredDifferenceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SquaredDifferenceOptions ? static_cast(builtin_options()) : nullptr; } - const MirrorPadOptions *builtin_options_as_MirrorPadOptions() const { - return builtin_options_type() == BuiltinOptions_MirrorPadOptions ? static_cast(builtin_options()) : nullptr; + const tflite::MirrorPadOptions *builtin_options_as_MirrorPadOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MirrorPadOptions ? static_cast(builtin_options()) : nullptr; } - const AbsOptions *builtin_options_as_AbsOptions() const { - return builtin_options_type() == BuiltinOptions_AbsOptions ? static_cast(builtin_options()) : nullptr; + const tflite::AbsOptions *builtin_options_as_AbsOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_AbsOptions ? 
static_cast(builtin_options()) : nullptr; } - const SplitVOptions *builtin_options_as_SplitVOptions() const { - return builtin_options_type() == BuiltinOptions_SplitVOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SplitVOptions *builtin_options_as_SplitVOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SplitVOptions ? static_cast(builtin_options()) : nullptr; } - const UniqueOptions *builtin_options_as_UniqueOptions() const { - return builtin_options_type() == BuiltinOptions_UniqueOptions ? static_cast(builtin_options()) : nullptr; + const tflite::UniqueOptions *builtin_options_as_UniqueOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UniqueOptions ? static_cast(builtin_options()) : nullptr; } - const ReverseV2Options *builtin_options_as_ReverseV2Options() const { - return builtin_options_type() == BuiltinOptions_ReverseV2Options ? static_cast(builtin_options()) : nullptr; + const tflite::ReverseV2Options *builtin_options_as_ReverseV2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_ReverseV2Options ? static_cast(builtin_options()) : nullptr; } - const AddNOptions *builtin_options_as_AddNOptions() const { - return builtin_options_type() == BuiltinOptions_AddNOptions ? static_cast(builtin_options()) : nullptr; + const tflite::AddNOptions *builtin_options_as_AddNOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_AddNOptions ? static_cast(builtin_options()) : nullptr; } - const GatherNdOptions *builtin_options_as_GatherNdOptions() const { - return builtin_options_type() == BuiltinOptions_GatherNdOptions ? static_cast(builtin_options()) : nullptr; + const tflite::GatherNdOptions *builtin_options_as_GatherNdOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GatherNdOptions ? static_cast(builtin_options()) : nullptr; } - const CosOptions *builtin_options_as_CosOptions() const { - return builtin_options_type() == BuiltinOptions_CosOptions ? static_cast(builtin_options()) : nullptr; + const tflite::CosOptions *builtin_options_as_CosOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CosOptions ? static_cast(builtin_options()) : nullptr; } - const WhereOptions *builtin_options_as_WhereOptions() const { - return builtin_options_type() == BuiltinOptions_WhereOptions ? static_cast(builtin_options()) : nullptr; + const tflite::WhereOptions *builtin_options_as_WhereOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_WhereOptions ? static_cast(builtin_options()) : nullptr; } - const RankOptions *builtin_options_as_RankOptions() const { - return builtin_options_type() == BuiltinOptions_RankOptions ? static_cast(builtin_options()) : nullptr; + const tflite::RankOptions *builtin_options_as_RankOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_RankOptions ? static_cast(builtin_options()) : nullptr; } - const ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const { - return builtin_options_type() == BuiltinOptions_ReverseSequenceOptions ? static_cast(builtin_options()) : nullptr; + const tflite::ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ReverseSequenceOptions ? static_cast(builtin_options()) : nullptr; } - const MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const { - return builtin_options_type() == BuiltinOptions_MatrixDiagOptions ? 
static_cast(builtin_options()) : nullptr; + const tflite::MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MatrixDiagOptions ? static_cast(builtin_options()) : nullptr; } - const QuantizeOptions *builtin_options_as_QuantizeOptions() const { - return builtin_options_type() == BuiltinOptions_QuantizeOptions ? static_cast(builtin_options()) : nullptr; + const tflite::QuantizeOptions *builtin_options_as_QuantizeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_QuantizeOptions ? static_cast(builtin_options()) : nullptr; } - const MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const { - return builtin_options_type() == BuiltinOptions_MatrixSetDiagOptions ? static_cast(builtin_options()) : nullptr; + const tflite::MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MatrixSetDiagOptions ? static_cast(builtin_options()) : nullptr; } - const HardSwishOptions *builtin_options_as_HardSwishOptions() const { - return builtin_options_type() == BuiltinOptions_HardSwishOptions ? static_cast(builtin_options()) : nullptr; + const tflite::HardSwishOptions *builtin_options_as_HardSwishOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_HardSwishOptions ? static_cast(builtin_options()) : nullptr; } - const IfOptions *builtin_options_as_IfOptions() const { - return builtin_options_type() == BuiltinOptions_IfOptions ? static_cast(builtin_options()) : nullptr; + const tflite::IfOptions *builtin_options_as_IfOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_IfOptions ? static_cast(builtin_options()) : nullptr; } - const WhileOptions *builtin_options_as_WhileOptions() const { - return builtin_options_type() == BuiltinOptions_WhileOptions ? static_cast(builtin_options()) : nullptr; + const tflite::WhileOptions *builtin_options_as_WhileOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_WhileOptions ? static_cast(builtin_options()) : nullptr; } - const DepthToSpaceOptions *builtin_options_as_DepthToSpaceOptions() const { - return builtin_options_type() == BuiltinOptions_DepthToSpaceOptions ? static_cast(builtin_options()) : nullptr; + const tflite::DepthToSpaceOptions *builtin_options_as_DepthToSpaceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DepthToSpaceOptions ? static_cast(builtin_options()) : nullptr; } - const NonMaxSuppressionV4Options *builtin_options_as_NonMaxSuppressionV4Options() const { - return builtin_options_type() == BuiltinOptions_NonMaxSuppressionV4Options ? static_cast(builtin_options()) : nullptr; + const tflite::NonMaxSuppressionV4Options *builtin_options_as_NonMaxSuppressionV4Options() const { + return builtin_options_type() == tflite::BuiltinOptions_NonMaxSuppressionV4Options ? static_cast(builtin_options()) : nullptr; } - const NonMaxSuppressionV5Options *builtin_options_as_NonMaxSuppressionV5Options() const { - return builtin_options_type() == BuiltinOptions_NonMaxSuppressionV5Options ? static_cast(builtin_options()) : nullptr; + const tflite::NonMaxSuppressionV5Options *builtin_options_as_NonMaxSuppressionV5Options() const { + return builtin_options_type() == tflite::BuiltinOptions_NonMaxSuppressionV5Options ? static_cast(builtin_options()) : nullptr; } - const ScatterNdOptions *builtin_options_as_ScatterNdOptions() const { - return builtin_options_type() == BuiltinOptions_ScatterNdOptions ? 
static_cast(builtin_options()) : nullptr; + const tflite::ScatterNdOptions *builtin_options_as_ScatterNdOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ScatterNdOptions ? static_cast(builtin_options()) : nullptr; } - const SelectV2Options *builtin_options_as_SelectV2Options() const { - return builtin_options_type() == BuiltinOptions_SelectV2Options ? static_cast(builtin_options()) : nullptr; + const tflite::SelectV2Options *builtin_options_as_SelectV2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_SelectV2Options ? static_cast(builtin_options()) : nullptr; } - const DensifyOptions *builtin_options_as_DensifyOptions() const { - return builtin_options_type() == BuiltinOptions_DensifyOptions ? static_cast(builtin_options()) : nullptr; + const tflite::DensifyOptions *builtin_options_as_DensifyOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DensifyOptions ? static_cast(builtin_options()) : nullptr; } - const SegmentSumOptions *builtin_options_as_SegmentSumOptions() const { - return builtin_options_type() == BuiltinOptions_SegmentSumOptions ? static_cast(builtin_options()) : nullptr; + const tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SegmentSumOptions ? static_cast(builtin_options()) : nullptr; } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } - CustomOptionsFormat custom_options_format() const { - return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); + tflite::CustomOptionsFormat custom_options_format() const { + return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); } const flatbuffers::Vector *mutating_variable_inputs() const { return GetPointer *>(VT_MUTATING_VARIABLE_INPUTS); @@ -9581,403 +9581,403 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -template<> inline const Conv2DOptions *Operator::builtin_options_as() const { +template<> inline const tflite::Conv2DOptions *Operator::builtin_options_as() const { return builtin_options_as_Conv2DOptions(); } -template<> inline const DepthwiseConv2DOptions *Operator::builtin_options_as() const { +template<> inline const tflite::DepthwiseConv2DOptions *Operator::builtin_options_as() const { return builtin_options_as_DepthwiseConv2DOptions(); } -template<> inline const ConcatEmbeddingsOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ConcatEmbeddingsOptions *Operator::builtin_options_as() const { return builtin_options_as_ConcatEmbeddingsOptions(); } -template<> inline const LSHProjectionOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LSHProjectionOptions *Operator::builtin_options_as() const { return builtin_options_as_LSHProjectionOptions(); } -template<> inline const Pool2DOptions *Operator::builtin_options_as() const { +template<> inline const tflite::Pool2DOptions *Operator::builtin_options_as() const { return builtin_options_as_Pool2DOptions(); } -template<> inline const SVDFOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SVDFOptions *Operator::builtin_options_as() const { return builtin_options_as_SVDFOptions(); } -template<> inline const RNNOptions *Operator::builtin_options_as() const { +template<> inline const tflite::RNNOptions 
*Operator::builtin_options_as() const { return builtin_options_as_RNNOptions(); } -template<> inline const FullyConnectedOptions *Operator::builtin_options_as() const { +template<> inline const tflite::FullyConnectedOptions *Operator::builtin_options_as() const { return builtin_options_as_FullyConnectedOptions(); } -template<> inline const SoftmaxOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SoftmaxOptions *Operator::builtin_options_as() const { return builtin_options_as_SoftmaxOptions(); } -template<> inline const ConcatenationOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ConcatenationOptions *Operator::builtin_options_as() const { return builtin_options_as_ConcatenationOptions(); } -template<> inline const AddOptions *Operator::builtin_options_as() const { +template<> inline const tflite::AddOptions *Operator::builtin_options_as() const { return builtin_options_as_AddOptions(); } -template<> inline const L2NormOptions *Operator::builtin_options_as() const { +template<> inline const tflite::L2NormOptions *Operator::builtin_options_as() const { return builtin_options_as_L2NormOptions(); } -template<> inline const LocalResponseNormalizationOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LocalResponseNormalizationOptions *Operator::builtin_options_as() const { return builtin_options_as_LocalResponseNormalizationOptions(); } -template<> inline const LSTMOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LSTMOptions *Operator::builtin_options_as() const { return builtin_options_as_LSTMOptions(); } -template<> inline const ResizeBilinearOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ResizeBilinearOptions *Operator::builtin_options_as() const { return builtin_options_as_ResizeBilinearOptions(); } -template<> inline const CallOptions *Operator::builtin_options_as() const { +template<> inline const tflite::CallOptions *Operator::builtin_options_as() const { return builtin_options_as_CallOptions(); } -template<> inline const ReshapeOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ReshapeOptions *Operator::builtin_options_as() const { return builtin_options_as_ReshapeOptions(); } -template<> inline const SkipGramOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SkipGramOptions *Operator::builtin_options_as() const { return builtin_options_as_SkipGramOptions(); } -template<> inline const SpaceToDepthOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SpaceToDepthOptions *Operator::builtin_options_as() const { return builtin_options_as_SpaceToDepthOptions(); } -template<> inline const EmbeddingLookupSparseOptions *Operator::builtin_options_as() const { +template<> inline const tflite::EmbeddingLookupSparseOptions *Operator::builtin_options_as() const { return builtin_options_as_EmbeddingLookupSparseOptions(); } -template<> inline const MulOptions *Operator::builtin_options_as() const { +template<> inline const tflite::MulOptions *Operator::builtin_options_as() const { return builtin_options_as_MulOptions(); } -template<> inline const PadOptions *Operator::builtin_options_as() const { +template<> inline const tflite::PadOptions *Operator::builtin_options_as() const { return builtin_options_as_PadOptions(); } -template<> inline const GatherOptions *Operator::builtin_options_as() const { +template<> inline const tflite::GatherOptions 
*Operator::builtin_options_as() const { return builtin_options_as_GatherOptions(); } -template<> inline const BatchToSpaceNDOptions *Operator::builtin_options_as() const { +template<> inline const tflite::BatchToSpaceNDOptions *Operator::builtin_options_as() const { return builtin_options_as_BatchToSpaceNDOptions(); } -template<> inline const SpaceToBatchNDOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SpaceToBatchNDOptions *Operator::builtin_options_as() const { return builtin_options_as_SpaceToBatchNDOptions(); } -template<> inline const TransposeOptions *Operator::builtin_options_as() const { +template<> inline const tflite::TransposeOptions *Operator::builtin_options_as() const { return builtin_options_as_TransposeOptions(); } -template<> inline const ReducerOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ReducerOptions *Operator::builtin_options_as() const { return builtin_options_as_ReducerOptions(); } -template<> inline const SubOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SubOptions *Operator::builtin_options_as() const { return builtin_options_as_SubOptions(); } -template<> inline const DivOptions *Operator::builtin_options_as() const { +template<> inline const tflite::DivOptions *Operator::builtin_options_as() const { return builtin_options_as_DivOptions(); } -template<> inline const SqueezeOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SqueezeOptions *Operator::builtin_options_as() const { return builtin_options_as_SqueezeOptions(); } -template<> inline const SequenceRNNOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SequenceRNNOptions *Operator::builtin_options_as() const { return builtin_options_as_SequenceRNNOptions(); } -template<> inline const StridedSliceOptions *Operator::builtin_options_as() const { +template<> inline const tflite::StridedSliceOptions *Operator::builtin_options_as() const { return builtin_options_as_StridedSliceOptions(); } -template<> inline const ExpOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ExpOptions *Operator::builtin_options_as() const { return builtin_options_as_ExpOptions(); } -template<> inline const TopKV2Options *Operator::builtin_options_as() const { +template<> inline const tflite::TopKV2Options *Operator::builtin_options_as() const { return builtin_options_as_TopKV2Options(); } -template<> inline const SplitOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SplitOptions *Operator::builtin_options_as() const { return builtin_options_as_SplitOptions(); } -template<> inline const LogSoftmaxOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LogSoftmaxOptions *Operator::builtin_options_as() const { return builtin_options_as_LogSoftmaxOptions(); } -template<> inline const CastOptions *Operator::builtin_options_as() const { +template<> inline const tflite::CastOptions *Operator::builtin_options_as() const { return builtin_options_as_CastOptions(); } -template<> inline const DequantizeOptions *Operator::builtin_options_as() const { +template<> inline const tflite::DequantizeOptions *Operator::builtin_options_as() const { return builtin_options_as_DequantizeOptions(); } -template<> inline const MaximumMinimumOptions *Operator::builtin_options_as() const { +template<> inline const tflite::MaximumMinimumOptions *Operator::builtin_options_as() const { return 
builtin_options_as_MaximumMinimumOptions(); } -template<> inline const ArgMaxOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ArgMaxOptions *Operator::builtin_options_as() const { return builtin_options_as_ArgMaxOptions(); } -template<> inline const LessOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LessOptions *Operator::builtin_options_as() const { return builtin_options_as_LessOptions(); } -template<> inline const NegOptions *Operator::builtin_options_as() const { +template<> inline const tflite::NegOptions *Operator::builtin_options_as() const { return builtin_options_as_NegOptions(); } -template<> inline const PadV2Options *Operator::builtin_options_as() const { +template<> inline const tflite::PadV2Options *Operator::builtin_options_as() const { return builtin_options_as_PadV2Options(); } -template<> inline const GreaterOptions *Operator::builtin_options_as() const { +template<> inline const tflite::GreaterOptions *Operator::builtin_options_as() const { return builtin_options_as_GreaterOptions(); } -template<> inline const GreaterEqualOptions *Operator::builtin_options_as() const { +template<> inline const tflite::GreaterEqualOptions *Operator::builtin_options_as() const { return builtin_options_as_GreaterEqualOptions(); } -template<> inline const LessEqualOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LessEqualOptions *Operator::builtin_options_as() const { return builtin_options_as_LessEqualOptions(); } -template<> inline const SelectOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SelectOptions *Operator::builtin_options_as() const { return builtin_options_as_SelectOptions(); } -template<> inline const SliceOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SliceOptions *Operator::builtin_options_as() const { return builtin_options_as_SliceOptions(); } -template<> inline const TransposeConvOptions *Operator::builtin_options_as() const { +template<> inline const tflite::TransposeConvOptions *Operator::builtin_options_as() const { return builtin_options_as_TransposeConvOptions(); } -template<> inline const SparseToDenseOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SparseToDenseOptions *Operator::builtin_options_as() const { return builtin_options_as_SparseToDenseOptions(); } -template<> inline const TileOptions *Operator::builtin_options_as() const { +template<> inline const tflite::TileOptions *Operator::builtin_options_as() const { return builtin_options_as_TileOptions(); } -template<> inline const ExpandDimsOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ExpandDimsOptions *Operator::builtin_options_as() const { return builtin_options_as_ExpandDimsOptions(); } -template<> inline const EqualOptions *Operator::builtin_options_as() const { +template<> inline const tflite::EqualOptions *Operator::builtin_options_as() const { return builtin_options_as_EqualOptions(); } -template<> inline const NotEqualOptions *Operator::builtin_options_as() const { +template<> inline const tflite::NotEqualOptions *Operator::builtin_options_as() const { return builtin_options_as_NotEqualOptions(); } -template<> inline const ShapeOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ShapeOptions *Operator::builtin_options_as() const { return builtin_options_as_ShapeOptions(); } -template<> inline const PowOptions *Operator::builtin_options_as() const { 
+template<> inline const tflite::PowOptions *Operator::builtin_options_as() const { return builtin_options_as_PowOptions(); } -template<> inline const ArgMinOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ArgMinOptions *Operator::builtin_options_as() const { return builtin_options_as_ArgMinOptions(); } -template<> inline const FakeQuantOptions *Operator::builtin_options_as() const { +template<> inline const tflite::FakeQuantOptions *Operator::builtin_options_as() const { return builtin_options_as_FakeQuantOptions(); } -template<> inline const PackOptions *Operator::builtin_options_as() const { +template<> inline const tflite::PackOptions *Operator::builtin_options_as() const { return builtin_options_as_PackOptions(); } -template<> inline const LogicalOrOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LogicalOrOptions *Operator::builtin_options_as() const { return builtin_options_as_LogicalOrOptions(); } -template<> inline const OneHotOptions *Operator::builtin_options_as() const { +template<> inline const tflite::OneHotOptions *Operator::builtin_options_as() const { return builtin_options_as_OneHotOptions(); } -template<> inline const LogicalAndOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LogicalAndOptions *Operator::builtin_options_as() const { return builtin_options_as_LogicalAndOptions(); } -template<> inline const LogicalNotOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LogicalNotOptions *Operator::builtin_options_as() const { return builtin_options_as_LogicalNotOptions(); } -template<> inline const UnpackOptions *Operator::builtin_options_as() const { +template<> inline const tflite::UnpackOptions *Operator::builtin_options_as() const { return builtin_options_as_UnpackOptions(); } -template<> inline const FloorDivOptions *Operator::builtin_options_as() const { +template<> inline const tflite::FloorDivOptions *Operator::builtin_options_as() const { return builtin_options_as_FloorDivOptions(); } -template<> inline const SquareOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SquareOptions *Operator::builtin_options_as() const { return builtin_options_as_SquareOptions(); } -template<> inline const ZerosLikeOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ZerosLikeOptions *Operator::builtin_options_as() const { return builtin_options_as_ZerosLikeOptions(); } -template<> inline const FillOptions *Operator::builtin_options_as() const { +template<> inline const tflite::FillOptions *Operator::builtin_options_as() const { return builtin_options_as_FillOptions(); } -template<> inline const BidirectionalSequenceLSTMOptions *Operator::builtin_options_as() const { +template<> inline const tflite::BidirectionalSequenceLSTMOptions *Operator::builtin_options_as() const { return builtin_options_as_BidirectionalSequenceLSTMOptions(); } -template<> inline const BidirectionalSequenceRNNOptions *Operator::builtin_options_as() const { +template<> inline const tflite::BidirectionalSequenceRNNOptions *Operator::builtin_options_as() const { return builtin_options_as_BidirectionalSequenceRNNOptions(); } -template<> inline const UnidirectionalSequenceLSTMOptions *Operator::builtin_options_as() const { +template<> inline const tflite::UnidirectionalSequenceLSTMOptions *Operator::builtin_options_as() const { return builtin_options_as_UnidirectionalSequenceLSTMOptions(); } -template<> inline const FloorModOptions 
*Operator::builtin_options_as() const { +template<> inline const tflite::FloorModOptions *Operator::builtin_options_as() const { return builtin_options_as_FloorModOptions(); } -template<> inline const RangeOptions *Operator::builtin_options_as() const { +template<> inline const tflite::RangeOptions *Operator::builtin_options_as() const { return builtin_options_as_RangeOptions(); } -template<> inline const ResizeNearestNeighborOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ResizeNearestNeighborOptions *Operator::builtin_options_as() const { return builtin_options_as_ResizeNearestNeighborOptions(); } -template<> inline const LeakyReluOptions *Operator::builtin_options_as() const { +template<> inline const tflite::LeakyReluOptions *Operator::builtin_options_as() const { return builtin_options_as_LeakyReluOptions(); } -template<> inline const SquaredDifferenceOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SquaredDifferenceOptions *Operator::builtin_options_as() const { return builtin_options_as_SquaredDifferenceOptions(); } -template<> inline const MirrorPadOptions *Operator::builtin_options_as() const { +template<> inline const tflite::MirrorPadOptions *Operator::builtin_options_as() const { return builtin_options_as_MirrorPadOptions(); } -template<> inline const AbsOptions *Operator::builtin_options_as() const { +template<> inline const tflite::AbsOptions *Operator::builtin_options_as() const { return builtin_options_as_AbsOptions(); } -template<> inline const SplitVOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SplitVOptions *Operator::builtin_options_as() const { return builtin_options_as_SplitVOptions(); } -template<> inline const UniqueOptions *Operator::builtin_options_as() const { +template<> inline const tflite::UniqueOptions *Operator::builtin_options_as() const { return builtin_options_as_UniqueOptions(); } -template<> inline const ReverseV2Options *Operator::builtin_options_as() const { +template<> inline const tflite::ReverseV2Options *Operator::builtin_options_as() const { return builtin_options_as_ReverseV2Options(); } -template<> inline const AddNOptions *Operator::builtin_options_as() const { +template<> inline const tflite::AddNOptions *Operator::builtin_options_as() const { return builtin_options_as_AddNOptions(); } -template<> inline const GatherNdOptions *Operator::builtin_options_as() const { +template<> inline const tflite::GatherNdOptions *Operator::builtin_options_as() const { return builtin_options_as_GatherNdOptions(); } -template<> inline const CosOptions *Operator::builtin_options_as() const { +template<> inline const tflite::CosOptions *Operator::builtin_options_as() const { return builtin_options_as_CosOptions(); } -template<> inline const WhereOptions *Operator::builtin_options_as() const { +template<> inline const tflite::WhereOptions *Operator::builtin_options_as() const { return builtin_options_as_WhereOptions(); } -template<> inline const RankOptions *Operator::builtin_options_as() const { +template<> inline const tflite::RankOptions *Operator::builtin_options_as() const { return builtin_options_as_RankOptions(); } -template<> inline const ReverseSequenceOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ReverseSequenceOptions *Operator::builtin_options_as() const { return builtin_options_as_ReverseSequenceOptions(); } -template<> inline const MatrixDiagOptions *Operator::builtin_options_as() const { +template<> inline const 
tflite::MatrixDiagOptions *Operator::builtin_options_as() const { return builtin_options_as_MatrixDiagOptions(); } -template<> inline const QuantizeOptions *Operator::builtin_options_as() const { +template<> inline const tflite::QuantizeOptions *Operator::builtin_options_as() const { return builtin_options_as_QuantizeOptions(); } -template<> inline const MatrixSetDiagOptions *Operator::builtin_options_as() const { +template<> inline const tflite::MatrixSetDiagOptions *Operator::builtin_options_as() const { return builtin_options_as_MatrixSetDiagOptions(); } -template<> inline const HardSwishOptions *Operator::builtin_options_as() const { +template<> inline const tflite::HardSwishOptions *Operator::builtin_options_as() const { return builtin_options_as_HardSwishOptions(); } -template<> inline const IfOptions *Operator::builtin_options_as() const { +template<> inline const tflite::IfOptions *Operator::builtin_options_as() const { return builtin_options_as_IfOptions(); } -template<> inline const WhileOptions *Operator::builtin_options_as() const { +template<> inline const tflite::WhileOptions *Operator::builtin_options_as() const { return builtin_options_as_WhileOptions(); } -template<> inline const DepthToSpaceOptions *Operator::builtin_options_as() const { +template<> inline const tflite::DepthToSpaceOptions *Operator::builtin_options_as() const { return builtin_options_as_DepthToSpaceOptions(); } -template<> inline const NonMaxSuppressionV4Options *Operator::builtin_options_as() const { +template<> inline const tflite::NonMaxSuppressionV4Options *Operator::builtin_options_as() const { return builtin_options_as_NonMaxSuppressionV4Options(); } -template<> inline const NonMaxSuppressionV5Options *Operator::builtin_options_as() const { +template<> inline const tflite::NonMaxSuppressionV5Options *Operator::builtin_options_as() const { return builtin_options_as_NonMaxSuppressionV5Options(); } -template<> inline const ScatterNdOptions *Operator::builtin_options_as() const { +template<> inline const tflite::ScatterNdOptions *Operator::builtin_options_as() const { return builtin_options_as_ScatterNdOptions(); } -template<> inline const SelectV2Options *Operator::builtin_options_as() const { +template<> inline const tflite::SelectV2Options *Operator::builtin_options_as() const { return builtin_options_as_SelectV2Options(); } -template<> inline const DensifyOptions *Operator::builtin_options_as() const { +template<> inline const tflite::DensifyOptions *Operator::builtin_options_as() const { return builtin_options_as_DensifyOptions(); } -template<> inline const SegmentSumOptions *Operator::builtin_options_as() const { +template<> inline const tflite::SegmentSumOptions *Operator::builtin_options_as() const { return builtin_options_as_SegmentSumOptions(); } @@ -9993,7 +9993,7 @@ struct OperatorBuilder { void add_outputs(flatbuffers::Offset> outputs) { fbb_.AddOffset(Operator::VT_OUTPUTS, outputs); } - void add_builtin_options_type(BuiltinOptions builtin_options_type) { + void add_builtin_options_type(tflite::BuiltinOptions builtin_options_type) { fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_TYPE, static_cast(builtin_options_type), 0); } void add_builtin_options(flatbuffers::Offset builtin_options) { @@ -10002,7 +10002,7 @@ struct OperatorBuilder { void add_custom_options(flatbuffers::Offset> custom_options) { fbb_.AddOffset(Operator::VT_CUSTOM_OPTIONS, custom_options); } - void add_custom_options_format(CustomOptionsFormat custom_options_format) { + void 
add_custom_options_format(tflite::CustomOptionsFormat custom_options_format) { fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); } void add_mutating_variable_inputs(flatbuffers::Offset> mutating_variable_inputs) { @@ -10028,10 +10028,10 @@ inline flatbuffers::Offset CreateOperator( uint32_t opcode_index = 0, flatbuffers::Offset> inputs = 0, flatbuffers::Offset> outputs = 0, - BuiltinOptions builtin_options_type = BuiltinOptions_NONE, + tflite::BuiltinOptions builtin_options_type = tflite::BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, flatbuffers::Offset> custom_options = 0, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + tflite::CustomOptionsFormat custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS, flatbuffers::Offset> mutating_variable_inputs = 0, flatbuffers::Offset> intermediates = 0) { OperatorBuilder builder_(_fbb); @@ -10052,10 +10052,10 @@ inline flatbuffers::Offset CreateOperatorDirect( uint32_t opcode_index = 0, const std::vector *inputs = nullptr, const std::vector *outputs = nullptr, - BuiltinOptions builtin_options_type = BuiltinOptions_NONE, + tflite::BuiltinOptions builtin_options_type = tflite::BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, const std::vector *custom_options = nullptr, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + tflite::CustomOptionsFormat custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS, const std::vector *mutating_variable_inputs = nullptr, const std::vector *intermediates = nullptr) { auto inputs__ = inputs ? _fbb.CreateVector(*inputs) : 0; @@ -10080,10 +10080,10 @@ flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fb struct SubGraphT : public flatbuffers::NativeTable { typedef SubGraph TableType; - std::vector> tensors; + std::vector> tensors; std::vector inputs; std::vector outputs; - std::vector> operators; + std::vector> operators; std::string name; SubGraphT() { } @@ -10098,8 +10098,8 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_OPERATORS = 10, VT_NAME = 12 }; - const flatbuffers::Vector> *tensors() const { - return GetPointer> *>(VT_TENSORS); + const flatbuffers::Vector> *tensors() const { + return GetPointer> *>(VT_TENSORS); } const flatbuffers::Vector *inputs() const { return GetPointer *>(VT_INPUTS); @@ -10107,8 +10107,8 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *outputs() const { return GetPointer *>(VT_OUTPUTS); } - const flatbuffers::Vector> *operators() const { - return GetPointer> *>(VT_OPERATORS); + const flatbuffers::Vector> *operators() const { + return GetPointer> *>(VT_OPERATORS); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); @@ -10137,7 +10137,7 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct SubGraphBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_tensors(flatbuffers::Offset>> tensors) { + void add_tensors(flatbuffers::Offset>> tensors) { fbb_.AddOffset(SubGraph::VT_TENSORS, tensors); } void add_inputs(flatbuffers::Offset> inputs) { @@ -10146,7 +10146,7 @@ struct SubGraphBuilder { void add_outputs(flatbuffers::Offset> outputs) { fbb_.AddOffset(SubGraph::VT_OUTPUTS, outputs); } - void add_operators(flatbuffers::Offset>> operators) { + void add_operators(flatbuffers::Offset>> operators) { fbb_.AddOffset(SubGraph::VT_OPERATORS, operators); } void 
add_name(flatbuffers::Offset name) { @@ -10166,10 +10166,10 @@ struct SubGraphBuilder { inline flatbuffers::Offset CreateSubGraph( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset>> tensors = 0, + flatbuffers::Offset>> tensors = 0, flatbuffers::Offset> inputs = 0, flatbuffers::Offset> outputs = 0, - flatbuffers::Offset>> operators = 0, + flatbuffers::Offset>> operators = 0, flatbuffers::Offset name = 0) { SubGraphBuilder builder_(_fbb); builder_.add_name(name); @@ -10182,15 +10182,15 @@ inline flatbuffers::Offset CreateSubGraph( inline flatbuffers::Offset CreateSubGraphDirect( flatbuffers::FlatBufferBuilder &_fbb, - const std::vector> *tensors = nullptr, + const std::vector> *tensors = nullptr, const std::vector *inputs = nullptr, const std::vector *outputs = nullptr, - const std::vector> *operators = nullptr, + const std::vector> *operators = nullptr, const char *name = nullptr) { - auto tensors__ = tensors ? _fbb.CreateVector>(*tensors) : 0; + auto tensors__ = tensors ? _fbb.CreateVector>(*tensors) : 0; auto inputs__ = inputs ? _fbb.CreateVector(*inputs) : 0; auto outputs__ = outputs ? _fbb.CreateVector(*outputs) : 0; - auto operators__ = operators ? _fbb.CreateVector>(*operators) : 0; + auto operators__ = operators ? _fbb.CreateVector>(*operators) : 0; auto name__ = name ? _fbb.CreateString(name) : 0; return tflite::CreateSubGraph( _fbb, @@ -10347,12 +10347,12 @@ flatbuffers::Offset CreateMetadata(flatbuffers::FlatBufferBuilder &_fb struct ModelT : public flatbuffers::NativeTable { typedef Model TableType; uint32_t version; - std::vector> operator_codes; - std::vector> subgraphs; + std::vector> operator_codes; + std::vector> subgraphs; std::string description; - std::vector> buffers; + std::vector> buffers; std::vector metadata_buffer; - std::vector> metadata; + std::vector> metadata; ModelT() : version(0) { } @@ -10372,23 +10372,23 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { uint32_t version() const { return GetField(VT_VERSION, 0); } - const flatbuffers::Vector> *operator_codes() const { - return GetPointer> *>(VT_OPERATOR_CODES); + const flatbuffers::Vector> *operator_codes() const { + return GetPointer> *>(VT_OPERATOR_CODES); } - const flatbuffers::Vector> *subgraphs() const { - return GetPointer> *>(VT_SUBGRAPHS); + const flatbuffers::Vector> *subgraphs() const { + return GetPointer> *>(VT_SUBGRAPHS); } const flatbuffers::String *description() const { return GetPointer(VT_DESCRIPTION); } - const flatbuffers::Vector> *buffers() const { - return GetPointer> *>(VT_BUFFERS); + const flatbuffers::Vector> *buffers() const { + return GetPointer> *>(VT_BUFFERS); } const flatbuffers::Vector *metadata_buffer() const { return GetPointer *>(VT_METADATA_BUFFER); } - const flatbuffers::Vector> *metadata() const { - return GetPointer> *>(VT_METADATA); + const flatbuffers::Vector> *metadata() const { + return GetPointer> *>(VT_METADATA); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -10422,22 +10422,22 @@ struct ModelBuilder { void add_version(uint32_t version) { fbb_.AddElement(Model::VT_VERSION, version, 0); } - void add_operator_codes(flatbuffers::Offset>> operator_codes) { + void add_operator_codes(flatbuffers::Offset>> operator_codes) { fbb_.AddOffset(Model::VT_OPERATOR_CODES, operator_codes); } - void add_subgraphs(flatbuffers::Offset>> subgraphs) { + void add_subgraphs(flatbuffers::Offset>> subgraphs) { fbb_.AddOffset(Model::VT_SUBGRAPHS, subgraphs); } void add_description(flatbuffers::Offset description) 
{ fbb_.AddOffset(Model::VT_DESCRIPTION, description); } - void add_buffers(flatbuffers::Offset>> buffers) { + void add_buffers(flatbuffers::Offset>> buffers) { fbb_.AddOffset(Model::VT_BUFFERS, buffers); } void add_metadata_buffer(flatbuffers::Offset> metadata_buffer) { fbb_.AddOffset(Model::VT_METADATA_BUFFER, metadata_buffer); } - void add_metadata(flatbuffers::Offset>> metadata) { + void add_metadata(flatbuffers::Offset>> metadata) { fbb_.AddOffset(Model::VT_METADATA, metadata); } explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -10455,12 +10455,12 @@ struct ModelBuilder { inline flatbuffers::Offset CreateModel( flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, - flatbuffers::Offset>> operator_codes = 0, - flatbuffers::Offset>> subgraphs = 0, + flatbuffers::Offset>> operator_codes = 0, + flatbuffers::Offset>> subgraphs = 0, flatbuffers::Offset description = 0, - flatbuffers::Offset>> buffers = 0, + flatbuffers::Offset>> buffers = 0, flatbuffers::Offset> metadata_buffer = 0, - flatbuffers::Offset>> metadata = 0) { + flatbuffers::Offset>> metadata = 0) { ModelBuilder builder_(_fbb); builder_.add_metadata(metadata); builder_.add_metadata_buffer(metadata_buffer); @@ -10475,18 +10475,18 @@ inline flatbuffers::Offset CreateModel( inline flatbuffers::Offset CreateModelDirect( flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, - const std::vector> *operator_codes = nullptr, - const std::vector> *subgraphs = nullptr, + const std::vector> *operator_codes = nullptr, + const std::vector> *subgraphs = nullptr, const char *description = nullptr, - const std::vector> *buffers = nullptr, + const std::vector> *buffers = nullptr, const std::vector *metadata_buffer = nullptr, - const std::vector> *metadata = nullptr) { - auto operator_codes__ = operator_codes ? _fbb.CreateVector>(*operator_codes) : 0; - auto subgraphs__ = subgraphs ? _fbb.CreateVector>(*subgraphs) : 0; + const std::vector> *metadata = nullptr) { + auto operator_codes__ = operator_codes ? _fbb.CreateVector>(*operator_codes) : 0; + auto subgraphs__ = subgraphs ? _fbb.CreateVector>(*subgraphs) : 0; auto description__ = description ? _fbb.CreateString(description) : 0; - auto buffers__ = buffers ? _fbb.CreateVector>(*buffers) : 0; + auto buffers__ = buffers ? _fbb.CreateVector>(*buffers) : 0; auto metadata_buffer__ = metadata_buffer ? _fbb.CreateVector(*metadata_buffer) : 0; - auto metadata__ = metadata ? _fbb.CreateVector>(*metadata) : 0; + auto metadata__ = metadata ? 
_fbb.CreateVector>(*metadata) : 0; return tflite::CreateModel( _fbb, version, @@ -10509,7 +10509,7 @@ inline CustomQuantizationT *CustomQuantization::UnPack(const flatbuffers::resolv inline void CustomQuantization::UnPackTo(CustomQuantizationT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = custom(); if (_e) { _o->custom.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom[_i] = _e->Get(_i); } } }; + { auto _e = custom(); if (_e) { _o->custom.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom[_i] = _e->Get(_i); } } } } inline flatbuffers::Offset CustomQuantization::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10536,13 +10536,13 @@ inline QuantizationParametersT *QuantizationParameters::UnPack(const flatbuffers inline void QuantizationParameters::UnPackTo(QuantizationParametersT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = min(); if (_e) { _o->min.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->min[_i] = _e->Get(_i); } } }; - { auto _e = max(); if (_e) { _o->max.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->max[_i] = _e->Get(_i); } } }; - { auto _e = scale(); if (_e) { _o->scale.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scale[_i] = _e->Get(_i); } } }; - { auto _e = zero_point(); if (_e) { _o->zero_point.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zero_point[_i] = _e->Get(_i); } } }; - { auto _e = details_type(); _o->details.type = _e; }; - { auto _e = details(); if (_e) _o->details.value = QuantizationDetailsUnion::UnPack(_e, details_type(), _resolver); }; - { auto _e = quantized_dimension(); _o->quantized_dimension = _e; }; + { auto _e = min(); if (_e) { _o->min.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->min[_i] = _e->Get(_i); } } } + { auto _e = max(); if (_e) { _o->max.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->max[_i] = _e->Get(_i); } } } + { auto _e = scale(); if (_e) { _o->scale.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scale[_i] = _e->Get(_i); } } } + { auto _e = zero_point(); if (_e) { _o->zero_point.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zero_point[_i] = _e->Get(_i); } } } + { auto _e = details_type(); _o->details.type = _e; } + { auto _e = details(); if (_e) _o->details.value = tflite::QuantizationDetailsUnion::UnPack(_e, details_type(), _resolver); } + { auto _e = quantized_dimension(); _o->quantized_dimension = _e; } } inline flatbuffers::Offset QuantizationParameters::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10580,7 +10580,7 @@ inline Int32VectorT *Int32Vector::UnPack(const flatbuffers::resolver_function_t inline void Int32Vector::UnPackTo(Int32VectorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } }; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for 
(flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } } } inline flatbuffers::Offset Int32Vector::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10606,7 +10606,7 @@ inline Uint16VectorT *Uint16Vector::UnPack(const flatbuffers::resolver_function_ inline void Uint16Vector::UnPackTo(Uint16VectorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } }; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } } } inline flatbuffers::Offset Uint16Vector::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10633,7 +10633,7 @@ inline Uint8VectorT *Uint8Vector::UnPack(const flatbuffers::resolver_function_t inline void Uint8Vector::UnPackTo(Uint8VectorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } }; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } } } inline flatbuffers::Offset Uint8Vector::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10660,12 +10660,12 @@ inline DimensionMetadataT *DimensionMetadata::UnPack(const flatbuffers::resolver inline void DimensionMetadata::UnPackTo(DimensionMetadataT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = format(); _o->format = _e; }; - { auto _e = dense_size(); _o->dense_size = _e; }; - { auto _e = array_segments_type(); _o->array_segments.type = _e; }; - { auto _e = array_segments(); if (_e) _o->array_segments.value = SparseIndexVectorUnion::UnPack(_e, array_segments_type(), _resolver); }; - { auto _e = array_indices_type(); _o->array_indices.type = _e; }; - { auto _e = array_indices(); if (_e) _o->array_indices.value = SparseIndexVectorUnion::UnPack(_e, array_indices_type(), _resolver); }; + { auto _e = format(); _o->format = _e; } + { auto _e = dense_size(); _o->dense_size = _e; } + { auto _e = array_segments_type(); _o->array_segments.type = _e; } + { auto _e = array_segments(); if (_e) _o->array_segments.value = tflite::SparseIndexVectorUnion::UnPack(_e, array_segments_type(), _resolver); } + { auto _e = array_indices_type(); _o->array_indices.type = _e; } + { auto _e = array_indices(); if (_e) _o->array_indices.value = tflite::SparseIndexVectorUnion::UnPack(_e, array_indices_type(), _resolver); } } inline flatbuffers::Offset DimensionMetadata::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10701,9 +10701,9 @@ inline SparsityParametersT *SparsityParameters::UnPack(const flatbuffers::resolv inline void SparsityParameters::UnPackTo(SparsityParametersT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = traversal_order(); if (_e) { _o->traversal_order.resize(_e->size()); for (flatbuffers::uoffset_t _i = 
0; _i < _e->size(); _i++) { _o->traversal_order[_i] = _e->Get(_i); } } }; - { auto _e = block_map(); if (_e) { _o->block_map.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->block_map[_i] = _e->Get(_i); } } }; - { auto _e = dim_metadata(); if (_e) { _o->dim_metadata.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->dim_metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = traversal_order(); if (_e) { _o->traversal_order.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->traversal_order[_i] = _e->Get(_i); } } } + { auto _e = block_map(); if (_e) { _o->block_map.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->block_map[_i] = _e->Get(_i); } } } + { auto _e = dim_metadata(); if (_e) { _o->dim_metadata.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->dim_metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } } inline flatbuffers::Offset SparsityParameters::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparsityParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10716,7 +10716,7 @@ inline flatbuffers::Offset CreateSparsityParameters(flatbuff struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SparsityParametersT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _traversal_order = _o->traversal_order.size() ? _fbb.CreateVector(_o->traversal_order) : 0; auto _block_map = _o->block_map.size() ? _fbb.CreateVector(_o->block_map) : 0; - auto _dim_metadata = _o->dim_metadata.size() ? _fbb.CreateVector> (_o->dim_metadata.size(), [](size_t i, _VectorArgs *__va) { return CreateDimensionMetadata(*__va->__fbb, __va->__o->dim_metadata[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _dim_metadata = _o->dim_metadata.size() ? 
_fbb.CreateVector> (_o->dim_metadata.size(), [](size_t i, _VectorArgs *__va) { return CreateDimensionMetadata(*__va->__fbb, __va->__o->dim_metadata[i].get(), __va->__rehasher); }, &_va ) : 0; return tflite::CreateSparsityParameters( _fbb, _traversal_order, @@ -10733,14 +10733,14 @@ inline TensorT *Tensor::UnPack(const flatbuffers::resolver_function_t *_resolver inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = shape(); if (_e) { _o->shape.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape[_i] = _e->Get(_i); } } }; - { auto _e = type(); _o->type = _e; }; - { auto _e = buffer(); _o->buffer = _e; }; - { auto _e = name(); if (_e) _o->name = _e->str(); }; - { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); }; - { auto _e = is_variable(); _o->is_variable = _e; }; - { auto _e = sparsity(); if (_e) _o->sparsity = std::unique_ptr(_e->UnPack(_resolver)); }; - { auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } }; + { auto _e = shape(); if (_e) { _o->shape.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape[_i] = _e->Get(_i); } } } + { auto _e = type(); _o->type = _e; } + { auto _e = buffer(); _o->buffer = _e; } + { auto _e = name(); if (_e) _o->name = _e->str(); } + { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); } + { auto _e = is_variable(); _o->is_variable = _e; } + { auto _e = sparsity(); if (_e) _o->sparsity = std::unique_ptr(_e->UnPack(_resolver)); } + { auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } } } inline flatbuffers::Offset Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10780,12 +10780,12 @@ inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_functio inline void Conv2DOptions::UnPackTo(Conv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = padding(); _o->padding = _e; }; - { auto _e = stride_w(); _o->stride_w = _e; }; - { auto _e = stride_h(); _o->stride_h = _e; }; - { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; - { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; }; - { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; }; + { auto _e = padding(); _o->padding = _e; } + { auto _e = stride_w(); _o->stride_w = _e; } + { auto _e = stride_h(); _o->stride_h = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; } + { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; } } inline flatbuffers::Offset Conv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10821,12 +10821,12 @@ inline Pool2DOptionsT *Pool2DOptions::UnPack(const flatbuffers::resolver_functio inline void Pool2DOptions::UnPackTo(Pool2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = padding(); _o->padding = _e; }; - { auto _e = stride_w(); 
_o->stride_w = _e; }; - { auto _e = stride_h(); _o->stride_h = _e; }; - { auto _e = filter_width(); _o->filter_width = _e; }; - { auto _e = filter_height(); _o->filter_height = _e; }; - { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = padding(); _o->padding = _e; } + { auto _e = stride_w(); _o->stride_w = _e; } + { auto _e = stride_h(); _o->stride_h = _e; } + { auto _e = filter_width(); _o->filter_width = _e; } + { auto _e = filter_height(); _o->filter_height = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } } inline flatbuffers::Offset Pool2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10862,13 +10862,13 @@ inline DepthwiseConv2DOptionsT *DepthwiseConv2DOptions::UnPack(const flatbuffers inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = padding(); _o->padding = _e; }; - { auto _e = stride_w(); _o->stride_w = _e; }; - { auto _e = stride_h(); _o->stride_h = _e; }; - { auto _e = depth_multiplier(); _o->depth_multiplier = _e; }; - { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; - { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; }; - { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; }; + { auto _e = padding(); _o->padding = _e; } + { auto _e = stride_w(); _o->stride_w = _e; } + { auto _e = stride_h(); _o->stride_h = _e; } + { auto _e = depth_multiplier(); _o->depth_multiplier = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; } + { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; } } inline flatbuffers::Offset DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10906,9 +10906,9 @@ inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffe inline void ConcatEmbeddingsOptions::UnPackTo(ConcatEmbeddingsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = num_channels(); _o->num_channels = _e; }; - { auto _e = num_columns_per_channel(); if (_e) { _o->num_columns_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->num_columns_per_channel[_i] = _e->Get(_i); } } }; - { auto _e = embedding_dim_per_channel(); if (_e) { _o->embedding_dim_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_dim_per_channel[_i] = _e->Get(_i); } } }; + { auto _e = num_channels(); _o->num_channels = _e; } + { auto _e = num_columns_per_channel(); if (_e) { _o->num_columns_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->num_columns_per_channel[_i] = _e->Get(_i); } } } + { auto _e = embedding_dim_per_channel(); if (_e) { _o->embedding_dim_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_dim_per_channel[_i] = _e->Get(_i); } } } } inline flatbuffers::Offset ConcatEmbeddingsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10938,7 +10938,7 @@ inline LSHProjectionOptionsT 
[Remaining hunks of the regenerated FlatBuffers object-API code (LSHProjectionOptions through Model, plus the QuantizationDetails, SparseIndexVector and BuiltinOptions verifier/union helpers). The hunks are mechanical and fall into two groups:
 - every field-unpacking statement block in the generated UnPackTo() methods drops its redundant trailing semicolon, e.g. "-  { auto _e = rank(); _o->rank = _e; };" becomes "+  { auto _e = rank(); _o->rank = _e; }";
 - the generated union and verifier helpers construct and reference the generated types through their fully qualified tflite:: names (e.g. tflite::CustomQuantizationT, tflite::Int32VectorT, tflite::Uint16VectorT, tflite::Uint8VectorT, tflite::BuiltinOptionsUnion::UnPack), with each case of VerifyBuiltinOptions and BuiltinOptionsUnion::UnPack, and the CreateVector calls in CreateSubGraph and CreateModel, rewritten accordingly.]
ptr->UnPack(resolver); } case BuiltinOptions_RangeOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_ResizeNearestNeighborOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_LeakyReluOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_SquaredDifferenceOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_MirrorPadOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_AbsOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_SplitVOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_UniqueOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_ReverseV2Options: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_AddNOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_GatherNdOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_CosOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_WhereOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_RankOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_ReverseSequenceOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_MatrixDiagOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_QuantizeOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_MatrixSetDiagOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_HardSwishOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_IfOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_WhileOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_DepthToSpaceOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_NonMaxSuppressionV4Options: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_NonMaxSuppressionV5Options: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_ScatterNdOptions: { - auto ptr = 
reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_SelectV2Options: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_DensifyOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_SegmentSumOptions: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } default: return nullptr; @@ -14622,403 +14622,403 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher) const { switch (type) { case BuiltinOptions_Conv2DOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateConv2DOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_DepthwiseConv2DOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateDepthwiseConv2DOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ConcatEmbeddingsOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateConcatEmbeddingsOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LSHProjectionOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLSHProjectionOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_Pool2DOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreatePool2DOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SVDFOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSVDFOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_RNNOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateRNNOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_FullyConnectedOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateFullyConnectedOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SoftmaxOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSoftmaxOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ConcatenationOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateConcatenationOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_AddOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateAddOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_L2NormOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateL2NormOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LocalResponseNormalizationOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLocalResponseNormalizationOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LSTMOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLSTMOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ResizeBilinearOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateResizeBilinearOptions(_fbb, ptr, _rehasher).Union(); } case 
BuiltinOptions_CallOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateCallOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ReshapeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateReshapeOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SkipGramOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSkipGramOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SpaceToDepthOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSpaceToDepthOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_EmbeddingLookupSparseOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateEmbeddingLookupSparseOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_MulOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateMulOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_PadOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreatePadOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_GatherOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateGatherOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_BatchToSpaceNDOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateBatchToSpaceNDOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SpaceToBatchNDOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSpaceToBatchNDOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_TransposeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateTransposeOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ReducerOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateReducerOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SubOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSubOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_DivOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateDivOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SqueezeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSqueezeOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SequenceRNNOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_StridedSliceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ExpOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateExpOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_TopKV2Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateTopKV2Options(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SplitOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSplitOptions(_fbb, ptr, _rehasher).Union(); } case 
BuiltinOptions_LogSoftmaxOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLogSoftmaxOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_CastOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateCastOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_DequantizeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateDequantizeOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_MaximumMinimumOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateMaximumMinimumOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ArgMaxOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateArgMaxOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LessOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLessOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_NegOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateNegOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_PadV2Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreatePadV2Options(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_GreaterOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateGreaterOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_GreaterEqualOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateGreaterEqualOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LessEqualOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLessEqualOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SelectOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSelectOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SliceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSliceOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_TransposeConvOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateTransposeConvOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SparseToDenseOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_TileOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateTileOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ExpandDimsOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_EqualOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateEqualOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_NotEqualOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ShapeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateShapeOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_PowOptions: { 
- auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreatePowOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ArgMinOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateArgMinOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_FakeQuantOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateFakeQuantOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_PackOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreatePackOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LogicalOrOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_OneHotOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateOneHotOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LogicalAndOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLogicalAndOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LogicalNotOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateLogicalNotOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_UnpackOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateUnpackOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_FloorDivOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateFloorDivOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SquareOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSquareOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ZerosLikeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateZerosLikeOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_FillOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateFillOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_BidirectionalSequenceLSTMOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateBidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_BidirectionalSequenceRNNOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateUnidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_FloorModOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateFloorModOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_RangeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateRangeOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ResizeNearestNeighborOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateResizeNearestNeighborOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LeakyReluOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return 
CreateLeakyReluOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SquaredDifferenceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSquaredDifferenceOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_MirrorPadOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateMirrorPadOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_AbsOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateAbsOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SplitVOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSplitVOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_UniqueOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateUniqueOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ReverseV2Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateReverseV2Options(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_AddNOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateAddNOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_GatherNdOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateGatherNdOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_CosOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateCosOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_WhereOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateWhereOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_RankOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateRankOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ReverseSequenceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateReverseSequenceOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_MatrixDiagOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateMatrixDiagOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_QuantizeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_MatrixSetDiagOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateMatrixSetDiagOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_HardSwishOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateHardSwishOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_IfOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateIfOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_WhileOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateWhileOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_DepthToSpaceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateDepthToSpaceOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_NonMaxSuppressionV4Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return 
CreateNonMaxSuppressionV4Options(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_NonMaxSuppressionV5Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateNonMaxSuppressionV5Options(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_ScatterNdOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateScatterNdOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SelectV2Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSelectV2Options(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_DensifyOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateDensifyOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SegmentSumOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union(); } default: return 0; @@ -15028,403 +15028,403 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FLATBUFFERS_NOEXCEPT : type(u.type), value(nullptr) { switch (type) { case BuiltinOptions_Conv2DOptions: { - value = new Conv2DOptionsT(*reinterpret_cast(u.value)); + value = new tflite::Conv2DOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_DepthwiseConv2DOptions: { - value = new DepthwiseConv2DOptionsT(*reinterpret_cast(u.value)); + value = new tflite::DepthwiseConv2DOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ConcatEmbeddingsOptions: { - value = new ConcatEmbeddingsOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ConcatEmbeddingsOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LSHProjectionOptions: { - value = new LSHProjectionOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LSHProjectionOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_Pool2DOptions: { - value = new Pool2DOptionsT(*reinterpret_cast(u.value)); + value = new tflite::Pool2DOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SVDFOptions: { - value = new SVDFOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SVDFOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_RNNOptions: { - value = new RNNOptionsT(*reinterpret_cast(u.value)); + value = new tflite::RNNOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_FullyConnectedOptions: { - value = new FullyConnectedOptionsT(*reinterpret_cast(u.value)); + value = new tflite::FullyConnectedOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SoftmaxOptions: { - value = new SoftmaxOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SoftmaxOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ConcatenationOptions: { - value = new ConcatenationOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ConcatenationOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_AddOptions: { - value = new AddOptionsT(*reinterpret_cast(u.value)); + value = new tflite::AddOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_L2NormOptions: { - value = new L2NormOptionsT(*reinterpret_cast(u.value)); + value = new tflite::L2NormOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LocalResponseNormalizationOptions: { - value = new LocalResponseNormalizationOptionsT(*reinterpret_cast(u.value)); + value = new 
tflite::LocalResponseNormalizationOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LSTMOptions: { - value = new LSTMOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LSTMOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ResizeBilinearOptions: { - value = new ResizeBilinearOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ResizeBilinearOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_CallOptions: { - value = new CallOptionsT(*reinterpret_cast(u.value)); + value = new tflite::CallOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ReshapeOptions: { - value = new ReshapeOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ReshapeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SkipGramOptions: { - value = new SkipGramOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SkipGramOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SpaceToDepthOptions: { - value = new SpaceToDepthOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SpaceToDepthOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_EmbeddingLookupSparseOptions: { - value = new EmbeddingLookupSparseOptionsT(*reinterpret_cast(u.value)); + value = new tflite::EmbeddingLookupSparseOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MulOptions: { - value = new MulOptionsT(*reinterpret_cast(u.value)); + value = new tflite::MulOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_PadOptions: { - value = new PadOptionsT(*reinterpret_cast(u.value)); + value = new tflite::PadOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_GatherOptions: { - value = new GatherOptionsT(*reinterpret_cast(u.value)); + value = new tflite::GatherOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_BatchToSpaceNDOptions: { - value = new BatchToSpaceNDOptionsT(*reinterpret_cast(u.value)); + value = new tflite::BatchToSpaceNDOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SpaceToBatchNDOptions: { - value = new SpaceToBatchNDOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SpaceToBatchNDOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_TransposeOptions: { - value = new TransposeOptionsT(*reinterpret_cast(u.value)); + value = new tflite::TransposeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ReducerOptions: { - value = new ReducerOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ReducerOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SubOptions: { - value = new SubOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SubOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_DivOptions: { - value = new DivOptionsT(*reinterpret_cast(u.value)); + value = new tflite::DivOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SqueezeOptions: { - value = new SqueezeOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SqueezeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SequenceRNNOptions: { - value = new SequenceRNNOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SequenceRNNOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_StridedSliceOptions: { - value = new StridedSliceOptionsT(*reinterpret_cast(u.value)); + value = new tflite::StridedSliceOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ExpOptions: { - value = new ExpOptionsT(*reinterpret_cast(u.value)); + value = 
new tflite::ExpOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_TopKV2Options: { - value = new TopKV2OptionsT(*reinterpret_cast(u.value)); + value = new tflite::TopKV2OptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SplitOptions: { - value = new SplitOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SplitOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LogSoftmaxOptions: { - value = new LogSoftmaxOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LogSoftmaxOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_CastOptions: { - value = new CastOptionsT(*reinterpret_cast(u.value)); + value = new tflite::CastOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_DequantizeOptions: { - value = new DequantizeOptionsT(*reinterpret_cast(u.value)); + value = new tflite::DequantizeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MaximumMinimumOptions: { - value = new MaximumMinimumOptionsT(*reinterpret_cast(u.value)); + value = new tflite::MaximumMinimumOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ArgMaxOptions: { - value = new ArgMaxOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ArgMaxOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LessOptions: { - value = new LessOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LessOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_NegOptions: { - value = new NegOptionsT(*reinterpret_cast(u.value)); + value = new tflite::NegOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_PadV2Options: { - value = new PadV2OptionsT(*reinterpret_cast(u.value)); + value = new tflite::PadV2OptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_GreaterOptions: { - value = new GreaterOptionsT(*reinterpret_cast(u.value)); + value = new tflite::GreaterOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_GreaterEqualOptions: { - value = new GreaterEqualOptionsT(*reinterpret_cast(u.value)); + value = new tflite::GreaterEqualOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LessEqualOptions: { - value = new LessEqualOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LessEqualOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SelectOptions: { - value = new SelectOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SelectOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SliceOptions: { - value = new SliceOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SliceOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_TransposeConvOptions: { - value = new TransposeConvOptionsT(*reinterpret_cast(u.value)); + value = new tflite::TransposeConvOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SparseToDenseOptions: { - value = new SparseToDenseOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SparseToDenseOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_TileOptions: { - value = new TileOptionsT(*reinterpret_cast(u.value)); + value = new tflite::TileOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ExpandDimsOptions: { - value = new ExpandDimsOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ExpandDimsOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_EqualOptions: { - value = new EqualOptionsT(*reinterpret_cast(u.value)); + value = new tflite::EqualOptionsT(*reinterpret_cast(u.value)); break; } 
case BuiltinOptions_NotEqualOptions: { - value = new NotEqualOptionsT(*reinterpret_cast(u.value)); + value = new tflite::NotEqualOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ShapeOptions: { - value = new ShapeOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ShapeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_PowOptions: { - value = new PowOptionsT(*reinterpret_cast(u.value)); + value = new tflite::PowOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ArgMinOptions: { - value = new ArgMinOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ArgMinOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_FakeQuantOptions: { - value = new FakeQuantOptionsT(*reinterpret_cast(u.value)); + value = new tflite::FakeQuantOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_PackOptions: { - value = new PackOptionsT(*reinterpret_cast(u.value)); + value = new tflite::PackOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LogicalOrOptions: { - value = new LogicalOrOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LogicalOrOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_OneHotOptions: { - value = new OneHotOptionsT(*reinterpret_cast(u.value)); + value = new tflite::OneHotOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LogicalAndOptions: { - value = new LogicalAndOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LogicalAndOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LogicalNotOptions: { - value = new LogicalNotOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LogicalNotOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_UnpackOptions: { - value = new UnpackOptionsT(*reinterpret_cast(u.value)); + value = new tflite::UnpackOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_FloorDivOptions: { - value = new FloorDivOptionsT(*reinterpret_cast(u.value)); + value = new tflite::FloorDivOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SquareOptions: { - value = new SquareOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SquareOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ZerosLikeOptions: { - value = new ZerosLikeOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ZerosLikeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_FillOptions: { - value = new FillOptionsT(*reinterpret_cast(u.value)); + value = new tflite::FillOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_BidirectionalSequenceLSTMOptions: { - value = new BidirectionalSequenceLSTMOptionsT(*reinterpret_cast(u.value)); + value = new tflite::BidirectionalSequenceLSTMOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_BidirectionalSequenceRNNOptions: { - value = new BidirectionalSequenceRNNOptionsT(*reinterpret_cast(u.value)); + value = new tflite::BidirectionalSequenceRNNOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { - value = new UnidirectionalSequenceLSTMOptionsT(*reinterpret_cast(u.value)); + value = new tflite::UnidirectionalSequenceLSTMOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_FloorModOptions: { - value = new FloorModOptionsT(*reinterpret_cast(u.value)); + value = new tflite::FloorModOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_RangeOptions: { - value = new RangeOptionsT(*reinterpret_cast(u.value)); + value = new 
tflite::RangeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ResizeNearestNeighborOptions: { - value = new ResizeNearestNeighborOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ResizeNearestNeighborOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LeakyReluOptions: { - value = new LeakyReluOptionsT(*reinterpret_cast(u.value)); + value = new tflite::LeakyReluOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SquaredDifferenceOptions: { - value = new SquaredDifferenceOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SquaredDifferenceOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MirrorPadOptions: { - value = new MirrorPadOptionsT(*reinterpret_cast(u.value)); + value = new tflite::MirrorPadOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_AbsOptions: { - value = new AbsOptionsT(*reinterpret_cast(u.value)); + value = new tflite::AbsOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SplitVOptions: { - value = new SplitVOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SplitVOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_UniqueOptions: { - value = new UniqueOptionsT(*reinterpret_cast(u.value)); + value = new tflite::UniqueOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ReverseV2Options: { - value = new ReverseV2OptionsT(*reinterpret_cast(u.value)); + value = new tflite::ReverseV2OptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_AddNOptions: { - value = new AddNOptionsT(*reinterpret_cast(u.value)); + value = new tflite::AddNOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_GatherNdOptions: { - value = new GatherNdOptionsT(*reinterpret_cast(u.value)); + value = new tflite::GatherNdOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_CosOptions: { - value = new CosOptionsT(*reinterpret_cast(u.value)); + value = new tflite::CosOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_WhereOptions: { - value = new WhereOptionsT(*reinterpret_cast(u.value)); + value = new tflite::WhereOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_RankOptions: { - value = new RankOptionsT(*reinterpret_cast(u.value)); + value = new tflite::RankOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ReverseSequenceOptions: { - value = new ReverseSequenceOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ReverseSequenceOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MatrixDiagOptions: { - value = new MatrixDiagOptionsT(*reinterpret_cast(u.value)); + value = new tflite::MatrixDiagOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_QuantizeOptions: { - value = new QuantizeOptionsT(*reinterpret_cast(u.value)); + value = new tflite::QuantizeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MatrixSetDiagOptions: { - value = new MatrixSetDiagOptionsT(*reinterpret_cast(u.value)); + value = new tflite::MatrixSetDiagOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_HardSwishOptions: { - value = new HardSwishOptionsT(*reinterpret_cast(u.value)); + value = new tflite::HardSwishOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_IfOptions: { - value = new IfOptionsT(*reinterpret_cast(u.value)); + value = new tflite::IfOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_WhileOptions: { - value = new WhileOptionsT(*reinterpret_cast(u.value)); + value = new 
tflite::WhileOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_DepthToSpaceOptions: { - value = new DepthToSpaceOptionsT(*reinterpret_cast(u.value)); + value = new tflite::DepthToSpaceOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_NonMaxSuppressionV4Options: { - value = new NonMaxSuppressionV4OptionsT(*reinterpret_cast(u.value)); + value = new tflite::NonMaxSuppressionV4OptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_NonMaxSuppressionV5Options: { - value = new NonMaxSuppressionV5OptionsT(*reinterpret_cast(u.value)); + value = new tflite::NonMaxSuppressionV5OptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ScatterNdOptions: { - value = new ScatterNdOptionsT(*reinterpret_cast(u.value)); + value = new tflite::ScatterNdOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SelectV2Options: { - value = new SelectV2OptionsT(*reinterpret_cast(u.value)); + value = new tflite::SelectV2OptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_DensifyOptions: { - value = new DensifyOptionsT(*reinterpret_cast(u.value)); + value = new tflite::DensifyOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SegmentSumOptions: { - value = new SegmentSumOptionsT(*reinterpret_cast(u.value)); + value = new tflite::SegmentSumOptionsT(*reinterpret_cast(u.value)); break; } default: @@ -15435,502 +15435,502 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL inline void BuiltinOptionsUnion::Reset() { switch (type) { case BuiltinOptions_Conv2DOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_DepthwiseConv2DOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ConcatEmbeddingsOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LSHProjectionOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_Pool2DOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SVDFOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_RNNOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_FullyConnectedOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SoftmaxOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ConcatenationOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_AddOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_L2NormOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LocalResponseNormalizationOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LSTMOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ResizeBilinearOptions: { - auto ptr = 
reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_CallOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ReshapeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SkipGramOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SpaceToDepthOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_EmbeddingLookupSparseOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_MulOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_PadOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_GatherOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_BatchToSpaceNDOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SpaceToBatchNDOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_TransposeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ReducerOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SubOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_DivOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SqueezeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SequenceRNNOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_StridedSliceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ExpOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_TopKV2Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SplitOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LogSoftmaxOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_CastOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_DequantizeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_MaximumMinimumOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ArgMaxOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LessOptions: { - auto ptr = 
reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_NegOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_PadV2Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_GreaterOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_GreaterEqualOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LessEqualOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SelectOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SliceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_TransposeConvOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SparseToDenseOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_TileOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ExpandDimsOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_EqualOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_NotEqualOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ShapeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_PowOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ArgMinOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_FakeQuantOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_PackOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LogicalOrOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_OneHotOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LogicalAndOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LogicalNotOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_UnpackOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_FloorDivOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SquareOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ZerosLikeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr 
= reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_FillOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_BidirectionalSequenceLSTMOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_BidirectionalSequenceRNNOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_FloorModOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_RangeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ResizeNearestNeighborOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_LeakyReluOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SquaredDifferenceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_MirrorPadOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_AbsOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SplitVOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_UniqueOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ReverseV2Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_AddNOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_GatherNdOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_CosOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_WhereOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_RankOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ReverseSequenceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_MatrixDiagOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_QuantizeOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_MatrixSetDiagOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_HardSwishOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_IfOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_WhileOptions: { - auto ptr = 
reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_DepthToSpaceOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_NonMaxSuppressionV4Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_NonMaxSuppressionV5Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_ScatterNdOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SelectV2Options: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_DensifyOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } case BuiltinOptions_SegmentSumOptions: { - auto ptr = reinterpret_cast(value); + auto ptr = reinterpret_cast(value); delete ptr; break; } @@ -15983,10 +15983,16 @@ inline void FinishSizePrefixedModelBuffer( fbb.FinishSizePrefixed(root, ModelIdentifier()); } -inline std::unique_ptr UnPackModel( +inline std::unique_ptr UnPackModel( const void *buf, const flatbuffers::resolver_function_t *res = nullptr) { - return std::unique_ptr(GetModel(buf)->UnPack(res)); + return std::unique_ptr(GetModel(buf)->UnPack(res)); +} + +inline std::unique_ptr UnPackSizePrefixedModel( + const void *buf, + const flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr(GetSizePrefixedModel(buf)->UnPack(res)); } } // namespace tflite From c372452782693eaa194a260f3ec5d969dfd216de Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 18 Mar 2020 16:22:56 -0700 Subject: [PATCH 181/492] Make Div op not quantizable PiperOrigin-RevId: 301692244 Change-Id: I993e7dd806ac66169777f21ceb16efdd8d0d46cd --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 5624c7e2b73..96eb69f7c8f 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1083,7 +1083,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; } -def TFL_DivOp : TFL_Op<"div", [ResultsBroadcastableShape, NoSideEffect]> { +def TFL_DivOp : TFL_Op<"div", [ + ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { let summary = "Division operator"; let description = [{ From 2774308165496ef789306f4dd6c0e1c307dd1b81 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Wed, 18 Mar 2020 16:23:25 -0700 Subject: [PATCH 182/492] Override ConstantLiteral method in MlirHloBuilder Enable TF::InvOp that uses constant op for testing. 
PiperOrigin-RevId: 301692318 Change-Id: I03704d256f1293fd8418bb597657417ca8395585 --- tensorflow/compiler/mlir/xla/hlo_utils.cc | 8 ++++---- tensorflow/compiler/mlir/xla/hlo_utils.h | 2 +- tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc | 9 +++++++++ tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h | 2 ++ .../mlir/xla/tests/legalize-tf-with-tf2xla.mlir | 11 +++++++++++ .../mlir/xla/transforms/legalize_tf_with_tf2xla.cc | 4 ++-- 6 files changed, 29 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index a0ce8a796cb..e0c5c4a00f0 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -30,12 +30,12 @@ using mlir::AffineMap; using mlir::Builder; using mlir::DenseElementsAttr; using mlir::ShapedType; -using xla::Literal; +using xla::LiteralBase; using xla::StatusOr; template -::mlir::DenseElementsAttr CreateDenseAttrFromLiteral(const ShapedType& type, - const Literal& literal) { +::mlir::DenseElementsAttr CreateDenseAttrFromLiteral( + const ShapedType& type, const LiteralBase& literal) { auto data_span = literal.data(); return ::mlir::DenseElementsAttr::get( type, llvm::makeArrayRef(data_span.data(), data_span.size())); @@ -78,7 +78,7 @@ StatusOr ConvertTensorShapeToMemRefType( } StatusOr CreateDenseElementsAttrFromLiteral( - const Literal& literal, Builder builder) { + const LiteralBase& literal, Builder builder) { TF_ASSIGN_OR_RETURN(auto type, ConvertTensorShapeToType( literal.shape(), builder)); diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.h b/tensorflow/compiler/mlir/xla/hlo_utils.h index f4acc2484f0..003eda0b992 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/hlo_utils.h @@ -28,7 +28,7 @@ limitations under the License. 
namespace xla { StatusOr CreateDenseElementsAttrFromLiteral( - const Literal& literal, mlir::Builder builder); + const LiteralBase& literal, mlir::Builder builder); mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( const llvm::ArrayRef vector, mlir::Builder builder); diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 1573810bc90..dfb9ec4e837 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -67,6 +67,15 @@ StatusOr MlirHloBuilder::MakeXlaOp(mlir::Value val) { return XlaOp(handle, this); } +XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr, + CreateDenseElementsAttrFromLiteral(literal, builder_)); + auto op = builder_.create(loc_, attr); + return MakeXlaOp(op); + }); +} + StatusOr MlirHloBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) { diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 9bebbc025a5..232d1fa84e9 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -88,6 +88,8 @@ class MlirHloBuilder : public XlaBuilder { StatusOr GetShapePtr(XlaOp op) const override; private: + XlaOp ConstantLiteral(const LiteralSlice& literal) override; + StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) override; diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index f2dff2c9956..53df0d0a0fc 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -79,6 +79,17 @@ func @convert(%arg0: tensor<2xi32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +// CHECK-LABEL: func @constant +func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: %[[SCALAR_ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[ONE:.*]] = "xla_hlo.broadcast_in_dim"(%[[SCALAR_ONE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + // CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[ONE]], %arg0 : tensor<2xf32> + // CHECK: return %[[RESULT]] + + %0 = "tf.Inv"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + // TODO(hinsu): Add a test with variant type once one of the ops supporting // the type is whitelisted. It should be rejected with unsupported type remark. diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 962bf97c44d..913fc678558 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -77,8 +77,8 @@ static bool IsOpWhitelisted(Operation* op) { // building valid MLIR using MlirHloBuilder. // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for // all tf2xla kernels. 
- return isa(op) || isa(op) || - isa(op) || isa(op); + return isa(op) || isa(op) || isa(op) || + isa(op) || isa(op); } static llvm::Optional GetJitDevice( From fd82b0d889155002fc76eea7693c8f729255c65b Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Wed, 18 Mar 2020 16:56:37 -0700 Subject: [PATCH 183/492] Added AsString, Print, and PrintV2 to Tensorflow Generated Ops PiperOrigin-RevId: 301698279 Change-Id: I8a462226c736a053647c4644b7422b61db99818c --- .../mlir/tensorflow/ir/tf_generated_ops.td | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 39b24ad353f..3592fa62a25 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -243,6 +243,41 @@ Usage: TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } +def TF_AsStringOp : TF_Op<"AsString", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "Converts each entry in the given tensor to strings."; + + let description = [{ +Supports many numeric types and boolean. + +For Unicode, see the +[https://www.tensorflow.org/tutorials/representation/unicode](Working with Unicode text) +tutorial. + +Examples: + +>>> tf.strings.as_string([3, 2]) + +>>> tf.strings.as_string([3.1415926, 2.71828], precision=2).numpy() +array([b'3.14', b'2.72'], dtype=object) + }]; + + let arguments = (ins + TensorOf<[F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$input, + + DefaultValuedAttr:$precision, + DefaultValuedAttr:$scientific, + DefaultValuedAttr:$shortest, + DefaultValuedAttr:$width, + StrAttr:$fill + ); + + let results = (outs + TF_StrTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AssertOp : TF_Op<"Assert", []> { let summary = "Asserts that the given condition is true."; @@ -4611,6 +4646,23 @@ gradients in some corner cases. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_PrintV2Op : TF_Op<"PrintV2", []> { + let summary = "Prints a string scalar."; + + let description = [{ +Prints a string scalar to the desired output_stream. + }]; + + let arguments = (ins + TF_StrTensor:$input, + + DefaultValuedAttr:$output_stream, + DefaultValuedAttr:$end + ); + + let results = (outs); +} + def TF_ProdOp : TF_Op<"Prod", [NoSideEffect]> { let summary = [{ Computes the product of elements across dimensions of a tensor. From 554ec07c097aeff12b95c5bb773d76d4a060727e Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Wed, 18 Mar 2020 16:58:45 -0700 Subject: [PATCH 184/492] Release Python GIL when recovering from failures. This enables more concurrency of remote execution and remote copy during cluster update. It also exposes a race condition where client might use WorkerInterface to issue remote RPCs while we update the underlying WorkerSession. Previously, WorkerSession is the sole owner of WorkerCache, and resetting the cache invalidates all worker interfaces. This may lead to segfaults when executing distributed function after initialization (in ClusterFunctionLibraryRuntime) or calling callbacks after receiving tensors (RpcRendezvousMgr). This CL makes the WorkerCache pointer shared by the callers to avoid segfaults. Cleaned up some unnecessary code for recreating ClusterFLR and ProcessFLR. We no longer need this with the dynamic device managers. 
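The ownership pattern behind this fix, reduced to a minimal self-contained C++ sketch: the types and helpers below (`Cache`, `Session`, `IssueAsync`, `ResetCache`) are illustrative placeholders, not the actual TensorFlow classes; the real change applies the same idea to `WorkerCacheInterface` via `WorkerSession::GetSharedWorkerCache()`. A caller that copies the shared pointer before starting asynchronous work keeps the cache alive even if a concurrent cluster update swaps in a new one.

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Illustrative stand-in for the worker cache.
struct Cache {
  std::string Lookup(const std::string& target) const {
    return "worker for " + target;
  }
};

struct Session {
  std::shared_ptr<Cache> cache = std::make_shared<Cache>();
  // Analogous to WorkerSession::GetSharedWorkerCache() in this patch.
  std::shared_ptr<Cache> GetSharedCache() const { return cache; }
  // Analogous to the worker cache being recreated during a cluster update.
  void ResetCache() { cache = std::make_shared<Cache>(); }
};

// Queues "asynchronous" work; each callback keeps its own reference to the
// cache it was issued against.
void IssueAsync(Session* sess, const std::string& target,
                std::vector<std::function<void()>>* pending) {
  std::shared_ptr<Cache> cache = sess->GetSharedCache();  // keep-alive copy
  pending->push_back([cache, target]() {
    std::cout << cache->Lookup(target) << "\n";  // safe even after ResetCache()
  });
}

int main() {
  Session sess;
  std::vector<std::function<void()>> pending;
  IssueAsync(&sess, "/job:worker/task:0", &pending);
  sess.ResetCache();  // with a unique_ptr the old cache would be gone here
  for (auto& cb : pending) cb();  // still valid: each callback holds a reference
  return 0;
}
```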
Also cleaned up unnecessary executor.wait logic in eager/context.py---we no longer need this after supporting parallel cluster updates. PiperOrigin-RevId: 301698639 Change-Id: I9b7a8fb75f9d714732fcf9dca3cda0014f7bb852 --- tensorflow/c/eager/c_api.cc | 9 +-------- tensorflow/core/common_runtime/eager/context.cc | 8 +------- tensorflow/core/common_runtime/eager/context.h | 3 +-- .../cluster_function_library_runtime.cc | 12 +++++++----- .../cluster_function_library_runtime.h | 6 ++++++ .../eager/destroy_tensor_handle_node.h | 13 ++++++------- .../eager/remote_execute_node.cc | 1 + .../eager/remote_tensor_handle_data.cc | 2 +- .../distributed_runtime/rpc/rpc_rendezvous_mgr.cc | 7 ++++--- .../core/distributed_runtime/worker_session.cc | 2 +- .../core/distributed_runtime/worker_session.h | 10 +++++++++- tensorflow/python/eager/context.py | 4 ---- tensorflow/python/eager/remote_cluster_test.py | 11 +++++++++++ tensorflow/python/tfe_wrapper.cc | 2 ++ 14 files changed, 51 insertions(+), 39 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 96dc288f213..94a0a76ada1 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -651,16 +651,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( grpc_server->worker_env()->session_mgr->UpdateSession( session_name, server_def, base_request.cluster_device_attributes(), true)); - TF_RETURN_IF_ERROR( - grpc_server->worker_env()->session_mgr->WorkerSessionForSession( - session_name, &worker_session)); - tensorflow::DistributedFunctionLibraryRuntime* cluster_flr = - tensorflow::eager::CreateClusterFLR(context_id, context, - worker_session.get()); LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster( grpc_server->worker_env(), std::move(remote_eager_workers), - added_workers, removed_workers, context_id, r, device_mgr, - keep_alive_secs, cluster_flr)); + added_workers, removed_workers, context_id, r)); } #undef LOG_AND_RETURN_IF_ERROR diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 3628f6372da..49403c080f6 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -1062,8 +1062,7 @@ Status EagerContext::UpdateRemoteMaster( std::unique_ptr remote_eager_workers, const std::vector& add_remote_contexts, const std::vector& remove_remote_contexts, uint64 context_id, - Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, - DistributedFunctionLibraryRuntime* cluster_flr) { + Rendezvous* r) { { tf_shared_lock l(remote_state_mu_); if (context_id != context_id_) { @@ -1103,7 +1102,6 @@ Status EagerContext::UpdateRemoteMaster( if (rendezvous_ != nullptr) rendezvous_->Unref(); rendezvous_ = r; remote_eager_workers_ = std::move(remote_eager_workers); - ResetClusterFLR(cluster_flr); InitPrioritizedDeviceTypeList(); default_executor_.ClearError(); @@ -1113,10 +1111,6 @@ Status EagerContext::UpdateRemoteMaster( entry.second->ClearError(); } } - const auto* config = pflr_->config(); - ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION, - &func_lib_def_, config->graph_options().optimizer_options(), - thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_); } // Register existing functions to the newly added remote workers. 
Note that diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 4006ecb04de..f775a3976ad 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -361,8 +361,7 @@ class EagerContext : public core::RefCounted { std::unique_ptr remote_eager_workers, const std::vector& add_remote_contexts, const std::vector& remove_remote_contexts, uint64 context_id, - Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, - DistributedFunctionLibraryRuntime* cluster_flr); + Rendezvous* r); // Similar with InitializeRemoteMaster but this context will not kill remote // contexts in shutdown. diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc index 2f6e97a4aee..dc2a8fb6fce 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc @@ -186,8 +186,9 @@ void ClusterFunctionLibraryRuntime::Instantiate( auto target = options.target; VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target << " (this: " << this << ")"; - WorkerInterface* wi = - worker_session_->worker_cache()->GetOrCreateWorker(target); + std::shared_ptr worker_cache = + worker_session_->GetSharedWorkerCache(); + WorkerInterface* wi = worker_cache->GetOrCreateWorker(target); if (wi == nullptr) { std::vector workers; @@ -233,13 +234,14 @@ void ClusterFunctionLibraryRuntime::Instantiate( wi->RegisterGraphAsync( req, resp, - [this, handle, req, resp, wi, function_name, target, send_keys, recv_keys, - done](const Status& status) { + [this, handle, req, resp, worker_cache, wi, function_name, target, + send_keys, recv_keys, done](const Status& status) { if (status.ok()) { mutex_lock l(mu_); *handle = function_data_.size(); function_data_.push_back(FunctionData(resp->graph_handle(), target, - wi, *send_keys, *recv_keys)); + worker_cache, wi, *send_keys, + *recv_keys)); VLOG(1) << "CFLR::Instantiate: [Success] " << function_name << " on " << target << " (this: " << this << ")" << " with handle: " << *handle; diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h index 78fd550366b..b9763ffddc7 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ +#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/framework/function.h" @@ -68,15 +69,20 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { struct FunctionData { const string graph_handle; const string target; + // Hold a shared pointer to the underlying worker cache to avoid it being + // deleted in potential cluster update. 
+ const std::shared_ptr worker_cache; WorkerInterface* wi = nullptr; const std::vector send_keys; const std::vector recv_keys; FunctionData(const string& graph_handle, const string& target, + std::shared_ptr worker_cache, WorkerInterface* wi, const std::vector& send_keys, const std::vector& recv_keys) : graph_handle(graph_handle), target(target), + worker_cache(std::move(worker_cache)), wi(wi), send_keys(send_keys), recv_keys(recv_keys) {} diff --git a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h index 2f4f7b91280..a2ea5f615bd 100644 --- a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h +++ b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h @@ -30,15 +30,14 @@ namespace eager { class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode { public: DestroyTensorHandleNode(std::unique_ptr request, - EagerClient* eager_client, bool ready) + core::RefCountPtr eager_client, + bool ready) : tensorflow::AsyncEagerNode(), request_(std::move(request)), - eager_client_(eager_client), - ready_(ready) { - eager_client_->Ref(); - } + eager_client_(std::move(eager_client)), + ready_(ready) {} - ~DestroyTensorHandleNode() override { eager_client_->Unref(); } + ~DestroyTensorHandleNode() override {} void RunAsync(StatusCallback done) override { EnqueueResponse* response = new EnqueueResponse; @@ -78,7 +77,7 @@ class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode { private: std::unique_ptr request_; - EagerClient* eager_client_; + core::RefCountPtr eager_client_; const string remote_task_; bool ready_; }; diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc index 81547615706..3eab62b7c9d 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc @@ -77,6 +77,7 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) { LOG(ERROR) << "Ignoring an error encountered when setting " "remote shape of tensor handle: " << retvals[i] << " with status: " << status.ToString() + << " and SetRemoteShape status: " << s.ToString() << "\nThis should never happen. 
" "Please file an issue with the TensorFlow Team."; } diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc index 0e5c614f57d..1f0f5a43fe2 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc @@ -54,7 +54,7 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task, VLOG(3) << "Sending request to delete " << request->DebugString(); std::unique_ptr node( absl::make_unique( - std::move(request), eager_client.get(), ready)); + std::move(request), std::move(eager_client), ready)); auto& executor = ctx->Executor(); if (executor.Async()) { Status status = executor.AddOrExecute(std::move(node)); diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index b758ec9e08c..37e88bafadb 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -224,11 +224,12 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( " is invalid remote source device."); } WorkerSession* sess = session(); + std::shared_ptr worker_cache = + sess->GetSharedWorkerCache(); // The worker will be released in a subsequent call to // `sess->worker_cache()->ReleaseWorker()` (if the call has not yet been // initialized) or `call->ReleaseWorker()` (if it has been initialized). - WorkerInterface* rwi = - sess->worker_cache()->GetOrCreateWorker(call->src_worker_); + WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_); if (s.ok() && rwi == nullptr) { s = errors::Internal("No worker known as ", call->src_worker_); } @@ -265,7 +266,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( // Start "call". Ref(); - call->Start([this, call]() { + call->Start([this, call, worker_cache]() { // Removes "call" from active_. Prevent StartAbort(). DeregisterCall(call); // If StartAbort was called prior to DeregisterCall, then the diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index a8ee08e2f36..8a758700695 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -129,7 +129,7 @@ Status WorkerSession::UpdateWorkerCacheAndDevices( std::unique_ptr new_worker_cache, std::vector> added_remote_devices, const std::vector& removed_remote_devices) { - worker_cache_ = std::unique_ptr( + worker_cache_ = std::shared_ptr( new WorkerFreeListCache(std::move(new_worker_cache))); TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index ca915b6a03a..4dc46aba35b 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -65,6 +65,14 @@ class WorkerSession { DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, std::unique_ptr remote_device_mgr); + // In the eager runtime we allow WorkerSession to be updated, where the + // worker cache will be recreated. If WorkerSession upate is expected and a + // worker in the cache is used in RPCs, the caller should hold a shared + // pointer to avoid the workers getting deleted. 
+ std::shared_ptr GetSharedWorkerCache() { + return worker_cache_; + } + // Update an existing worker session with new set of remote workers and // devices. Added devices will be owned by the worker session, and removed // devices will be freed by their names. @@ -89,7 +97,7 @@ class WorkerSession { const string worker_name_; // Object from which WorkerInterface instances can be obtained. - std::unique_ptr worker_cache_; + std::shared_ptr worker_cache_; // graph_mgr keeps track of the registered graphs of this session. // diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index ab2e18ed99d..46331461d4a 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -594,10 +594,6 @@ class Context(object): if self._context_handle: server_def_str = server_def.SerializeToString() - # Current executor might have pending nodes that involves updated remote - # devices. Wait for them to finish before updating. - self.executor.wait() - self.executor.clear_error() pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle, keep_alive_secs, server_def_str) self._initialize_logical_devices() diff --git a/tensorflow/python/eager/remote_cluster_test.py b/tensorflow/python/eager/remote_cluster_test.py index e26b99a8aa0..78e7098d081 100644 --- a/tensorflow/python/eager/remote_cluster_test.py +++ b/tensorflow/python/eager/remote_cluster_test.py @@ -128,6 +128,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): def tearDown(self): super(DynamicClusterTest, self).tearDown() + ops.device(None).__enter__() context._reset_context() @test_util.run_in_async_and_sync_mode @@ -370,6 +371,9 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): def worker_fn(i): return math_ops.matmul(i, i) + # Forces function tracing and registration + worker_fn.get_concrete_function(x1) + def thread_fn(device, results): for i in range(num_calls): with self._coord.stop_on_exception(): @@ -505,6 +509,13 @@ class DynamicClusterWithoutLazyRemoteInputsCopyTest(DynamicClusterTest): context._reset_context() context.context().lazy_remote_inputs_copy = True + # TODO(haoyuzhang): When lazyh remote inputs copy is disabled, we use the + # WorkerService RunGraph request to execute component functions in distributed + # function execution. We currently do not have access control in WorkerService + # to allow concurrent cluster update and function execution. + def testMultiThreadPendingNodesLockFree(self): + self.skipTest("Unsupported case") + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 105d8810dd0..09221d8b0a2 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -461,9 +461,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) { tensorflow::make_safe(TF_NewStatus()); tensorflow::Safe_TF_BufferPtr buf = tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); + Py_BEGIN_ALLOW_THREADS; TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs, buf.get()->data, buf.get()->length, status.get()); + Py_END_ALLOW_THREADS; tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); }); m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) { From 8d3c68e5533fdf4f5abc728b0a7de9cb475a5f1a Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 18 Mar 2020 17:00:56 -0700 Subject: [PATCH 185/492] [Executor] Move kernel expensiveness tracking from `OpKernel` to `ExecutorImpl`. 
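The rationale is spelled out below. As a quick self-contained C++ illustration of the mechanism being moved, this sketch reproduces the decaying cost estimate that decides when a kernel stops being treated as "expensive"; the three constants and the update rule are taken from the patch, while the 1000-cycle "measurement" and the standalone `main()` are invented for the example.

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>

// Constants as in the patch: kernels start out "expensive" and the estimate
// decays toward the measured cost by a factor of kCostDecay per execution.
constexpr uint64_t kInitialCostEstimateCycles = 100 * 1000 * 1000;
constexpr uint64_t kOpIsExpensiveThresholdCycles = 5000;
constexpr uint64_t kCostDecay = 10;

int main() {
  std::atomic<uint64_t> cost_estimate{kInitialCostEstimateCycles};
  const uint64_t measured_cycles = 1000;  // pretend every run takes 1000 cycles
  int runs = 0;

  // Same weighted-average update as KernelStats::UpdateCostEstimate.
  while (cost_estimate.load(std::memory_order_relaxed) >
         kOpIsExpensiveThresholdCycles) {
    const uint64_t old_estimate = cost_estimate.load(std::memory_order_relaxed);
    const uint64_t new_estimate =
        (kCostDecay - 1) * old_estimate / kCostDecay +
        measured_cycles / kCostDecay;
    cost_estimate.store(new_estimate, std::memory_order_relaxed);
    ++runs;
  }
  // A kernel that consistently costs 1000 cycles needs on the order of a
  // hundred executions before it is reclassified as inexpensive; once the
  // flag flips, the estimate is no longer updated.
  std::cout << "reclassified as inexpensive after " << runs << " runs\n";
  return 0;
}
```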
This change introduces an internal `KernelStats` class that tracks the cost estimate and expensiveness for every kernel in an `ExecutorImpl`, instead of tracking it as a member of individual `OpKernel` objects. This change makes several optimizations to the expensiveness tracking, which is used in kernel tracing and dispatch: 1. Avoid making multiple virtual calls to access the result of `IsExpensive()`. The check is now inlinable, and requires reading one or two atomic values from an array. According to pprof, the `callq *0x20(%rax)` instruction is about 22x the cost of the logic for looking up the expensive bit. 2. When the cost estimate drops below the threshold, we stop updating it, so the change overwrites the "expensive" bit to false, avoiding an unnecessary read on each access. 3. (A little speculative.) By packing the state into two vectors, we could have better cache locality than the `OpKernel`-based approach, because the former `OpKernel::expensive_` and `OpKernel::cost_estimate_` are on a different cache line from the `OpKernel` vtable. In `ScheduleReady()` we avoid touching `OpKernel` fields at all. We also avoid i-cache pollution from executing the different overrides of `OpKernel::IsExpensive()`. PiperOrigin-RevId: 301699014 Change-Id: Ic8d85d24e59708b64d34584f59587843f7b2f7c1 --- tensorflow/core/common_runtime/executor.cc | 85 +++++++++++++++++++--- tensorflow/core/framework/op_kernel.cc | 11 +-- tensorflow/core/framework/op_kernel.h | 28 +------ 3 files changed, 77 insertions(+), 47 deletions(-) diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index a1b2224743f..f3cf11b274f 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -346,6 +346,68 @@ class ExecutorImpl : public Executor { private: friend class ExecutorState; + // Stores execution time information about the kernels in an executor's graph. + class KernelStats { + public: + KernelStats() = default; + + void Initialize(const GraphView& gview) { + is_expensive_ = absl::make_unique[]>(gview.num_nodes()); + cost_estimates_ = + absl::make_unique(gview.num_nodes()); + for (int32 i = 0; i < gview.num_nodes(); ++i) { + if (gview.node(i)) { + is_expensive_[i] = gview.node(i)->kernel->IsExpensive(); + cost_estimates_[i] = kInitialCostEstimateCycles; + } + } + } + + // Returns true iff the given node is considered "expensive". The + // executor uses this flag to optimize graph execution, for example + // by "inlining" inexpensive kernels. + bool IsExpensive(const NodeItem& node) const { + return is_expensive_[node.node_id] && + (cost_estimates_[node.node_id].load(std::memory_order_relaxed) > + kOpIsExpensiveThresholdCycles); + } + + // Updates the dynamic cost estimate, which is used to determine whether the + // given node is expensive. The new cost estimate is a weighted average of + // the old cost estimate and the latest cost. + // + // NOTE: We currently only expect updates to the cost estimate when + // `is_expensive_[node.node_id]` is true (or at least, it *was* true, when + // we started to execute the kernel. As a result, we expect that a kernel + // can only ever transition from "expensive" to "inexpensive", but not vice + // versa. + void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) { + // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous + // updates may result in one or more updates being ignored. 
This does not + // affect correctness but may slow down the update frequency. + std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id]; + uint64 new_estimate = (kCostDecay - 1) * + cost_estimate.load(std::memory_order_relaxed) / + kCostDecay + + (elapsed_cycles / kCostDecay); + cost_estimate.store(new_estimate, std::memory_order_relaxed); + if (new_estimate < kOpIsExpensiveThresholdCycles) { + is_expensive_[node.node_id].store(false, std::memory_order_relaxed); + } + } + + private: + // Initial time (in CPU cycles) we expect an operation to take. Used to + // determine whether an operation should be place in a threadpool. + // Operations start out "expensive". + static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000; + static const uint64 kOpIsExpensiveThresholdCycles = 5000; + static const uint64 kCostDecay = 10; + + std::unique_ptr[]> is_expensive_; + std::unique_ptr cost_estimates_; + }; + struct ControlFlowInfo { gtl::FlatSet unique_frame_names; std::vector frame_names; @@ -396,6 +458,7 @@ class ExecutorImpl : public Executor { // Owned. LocalExecutorParams params_; GraphView gview_; + mutable KernelStats kernel_stats_; // Root nodes (with no in edges) that should form the initial ready queue std::vector root_nodes_; @@ -732,7 +795,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) { // Initialize PendingCounts only after item->pending_id is initialized for // all nodes. InitializePending(&graph, cf_info); - + kernel_stats_.Initialize(gview_); return gview_.SetAllocAttrs(&graph, params_.device); } @@ -1713,8 +1776,8 @@ struct ExecutorState::AsyncState { // Returns true if `item` might be traced by the given trace and event // collectors. Returns false only if `item` definitely will not be traced. -bool MightTrace(const NodeItem& item, - const tracing::EventCollector* event_collector) { +bool MightTrace(const tracing::EventCollector* event_collector, + bool is_expensive) { // Tracing will only be enabled if either `event_collector` is non null, // or `trace_collector` is non-null and enabled for this particular kernel. // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and @@ -1728,8 +1791,7 @@ bool MightTrace(const NodeItem& item, if (profiler::ScopedAnnotation::IsEnabled()) return true; - return profiler::TraceMe::Active( - profiler::GetTFTraceMeLevel(item.kernel->IsExpensive())); + return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive)); } Status ExecutorState::ProcessSync(const NodeItem& item, @@ -1742,8 +1804,9 @@ Status ExecutorState::ProcessSync(const NodeItem& item, OpKernel* op_kernel = item.kernel; Device* device = impl_->params_.device; + const bool is_expensive = impl_->kernel_stats_.IsExpensive(item); - if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) { + if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) { tracing::ScopedRegion region(tracing::EventCategory::kCompute, op_kernel->name_view()); profiler::AnnotatedTraceMe activity( @@ -1751,16 +1814,16 @@ Status ExecutorState::ProcessSync(const NodeItem& item, return op_kernel->TraceString( &ctx, /*verbose=*/profiler::TfOpDetailsEnabled()); }, - profiler::GetTFTraceMeLevel(op_kernel->IsExpensive())); + profiler::GetTFTraceMeLevel(is_expensive)); device->Compute(op_kernel, &ctx); nodestats::SetOpEnd(stats); s = ProcessOutputs(item, &ctx, outputs, stats); } else { // In the common case, avoid creating any tracing objects. 
- if (op_kernel->IsExpensive()) { + if (is_expensive) { KernelTimer timer; device->Compute(op_kernel, &ctx); - op_kernel->UpdateCostEstimate(timer.ElapsedCycles()); + impl_->kernel_stats_.UpdateCostEstimate(item, timer.ElapsedCycles()); } else { device->Compute(op_kernel, &ctx); } @@ -1821,7 +1884,7 @@ void ExecutorState::ProcessAsync(const NodeItem& item, return async_kernel->TraceString( &state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled()); }, - profiler::GetTFTraceMeLevel(async_kernel->IsExpensive())); + profiler::GetTFTraceMeLevel(impl_->kernel_stats_.IsExpensive(item))); impl_->params_.device->ComputeAsync(async_kernel, &state->ctx, std::move(done)); } @@ -2443,7 +2506,7 @@ void ExecutorState::ScheduleReady(TaggedNodeSeq* ready, } else { for (auto& tagged_node : *ready) { const NodeItem& item = *tagged_node.node_item; - if (tagged_node.is_dead || !item.kernel->IsExpensive()) { + if (tagged_node.is_dead || !impl_->kernel_stats_.IsExpensive(item)) { // Inline this inexpensive node. inline_ready->push_back(tagged_node); } else { diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 15d55cc19d0..40e075ba737 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -105,8 +105,7 @@ OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred) name_view_(props_->node_def.name()), type_string_view_(props_->node_def.op()), graph_def_version_(context->graph_def_version()), - is_deferred_(is_deferred), - cost_estimate_(OpKernel::kInitialCostEstimateCycles) { + is_deferred_(is_deferred) { OP_REQUIRES_OK(context, NameRangesForNode(props_->node_def, *props_->op_def, &input_name_map_, &output_name_map_)); @@ -133,8 +132,7 @@ OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, name_view_(props_->node_def.name()), type_string_view_(props_->node_def.op()), graph_def_version_(context->graph_def_version()), - is_deferred_(is_deferred), - cost_estimate_(OpKernel::kInitialCostEstimateCycles) { + is_deferred_(is_deferred) { OP_REQUIRES_OK(context, NameRangesForNode(props_->node_def, *props_->op_def, &input_name_map_, &output_name_map_)); @@ -149,11 +147,6 @@ OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, OpKernel::~OpKernel() {} -const uint64 OpKernel::kInitialCostEstimateCycles; -const uint64 OpKernel::kOpIsExpensiveThresholdCycles; -const uint64 OpKernel::kCostDecay; - - Status OpKernel::InputRange(StringPiece input_name, int* start, int* stop) const { const auto result = input_name_map_.find(input_name); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 7a92a40e103..1644eff9319 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ #define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ -#include #include #include #include @@ -136,38 +135,14 @@ class OpKernel { // Returns nullptr iff this op kernel is synchronous. virtual AsyncOpKernel* AsAsync() { return nullptr; } - // Initial time (in CPU cycles) we expect an operation to take. Used to - // determine whether an operation should be place in a threadpool. Operations - // start out "expensive". - static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000; - static const uint64 kOpIsExpensiveThresholdCycles = 5000; - static const uint64 kCostDecay = 10; - // Returns true iff this op kernel is considered "expensive". 
The // runtime may use this flag to optimize graph execution for example // to "inline" inexpensive kernels. - virtual bool IsExpensive() { - return expensive_ && (cost_estimate_.load(std::memory_order_relaxed) > - kOpIsExpensiveThresholdCycles); - } + virtual bool IsExpensive() { return expensive_; } // Returns a pointer to the tensor stored inside constant ops. virtual const Tensor* const_tensor() const { return nullptr; } - // Updates the dynamic cost estimate, which is used to determine whether this - // op is expensive. The new cost estimate is a weighted average of the old - // cost estimate and the latest cost. - void UpdateCostEstimate(uint64 elapsed_cycles) { - // N.B. Updates to `cost_estimate_` are atomic but unlocked. Simultaneous - // updates may result in one or more updates being ignored. This does not - // affect correctness but may slow down the update frequency. - cost_estimate_.store( - (kCostDecay - 1) * cost_estimate_.load(std::memory_order_relaxed) / - kCostDecay + - (elapsed_cycles / kCostDecay), - std::memory_order_relaxed); - } - // Accessors. const NodeDef& def() const { return props_->node_def; } const string& name() const { return props_->node_def.name(); } @@ -220,7 +195,6 @@ class OpKernel { const int graph_def_version_; const bool is_deferred_; bool expensive_; - std::atomic_uint_fast64_t cost_estimate_; TF_DISALLOW_COPY_AND_ASSIGN(OpKernel); }; From 89257f79df38c4f470aa21f97c8bc1b586ad4444 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 18 Mar 2020 17:01:17 -0700 Subject: [PATCH 186/492] Build quantized ops into TF windows distribution. PiperOrigin-RevId: 301699062 Change-Id: I21d5a5bcac2c53e90707fc45bbda90ef20abc1c7 --- tensorflow/core/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 8efada20e24..9fafb6f8e47 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1046,6 +1046,7 @@ cc_library( "//tensorflow/core/kernels:parsing", "//tensorflow/core/kernels:partitioned_function_ops", "//tensorflow/core/kernels:pooling_ops", + "//tensorflow/core/kernels:quantized_ops", "//tensorflow/core/kernels:ragged_ops", "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:stateful_random_ops", @@ -1072,7 +1073,6 @@ cc_library( "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:array_not_windows", "//tensorflow/core/kernels:math_not_windows", - "//tensorflow/core/kernels:quantized_ops", "//tensorflow/core/kernels/neon:neon_depthwise_conv_op", ]) + if_mkl([ "//tensorflow/core/kernels:mkl_aggregate_ops", From 14392ca38756428aeec97e5db689d6daef174f27 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 18 Mar 2020 17:03:14 -0700 Subject: [PATCH 187/492] Update another test case I missed in 301407257 Test fails with: OverflowError: Python int too large to convert to C long On windows, sizeof(long) is 4 bytes. 
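A standalone C++ sketch, not part of the change, just to make the numbers concrete: of the two seed values fed by the test, only the second exceeds the 32-bit signed range, which is why only it is shortened.

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // The seed that overflowed and its replacement from this change.
  const uint64_t old_seed = 0xabcdef12;  // 2,882,400,018
  const uint64_t new_seed = 0xabcdef1;   //   180,150,001
  // On Windows a C 'long' is 4 bytes, so a Python int above INT32_MAX cannot
  // be converted to a long-backed numpy integer type (the failure mode the
  // message above describes).
  const uint64_t limit = std::numeric_limits<int32_t>::max();  // 2,147,483,647
  std::cout << std::boolalpha;
  std::cout << "old seed fits in a 32-bit long: " << (old_seed <= limit) << "\n";  // false
  std::cout << "new seed fits in a 32-bit long: " << (new_seed <= limit) << "\n";  // true
  return 0;
}
```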
Therefore, the large integers seem to be a problem when backed by some numpy types: https://stackoverflow.com/questions/38314118/overflowerror-python-int-too-large-to-convert-to-c-long-on-windows-but-not-ma PiperOrigin-RevId: 301699600 Change-Id: Ifc80587f4e2937f7590e25e1884593efb4b12c4d --- tensorflow/compiler/tests/stateless_random_ops_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 56b49689607..a56c9206861 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -126,7 +126,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): n = 10000000 x = stateless.stateless_truncated_normal( shape=[n], seed=seed_t, dtype=dtype) - y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) random_test_util.test_truncated_normal( self.assertEqual, self.assertAllClose, n, y, variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3) From 0fac468c40727ecb8eb2b1b71db7338c907f4870 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 17:26:30 -0700 Subject: [PATCH 188/492] Add common subgraph elimination to the tf.data optimization pipeline, This optimizer was pulled out of arithmetic optimizer as a separate pass. PiperOrigin-RevId: 301703580 Change-Id: I668d4659f11d4b426148a153c04df79659888489 --- tensorflow/core/grappler/optimizers/data/BUILD | 1 + tensorflow/core/grappler/optimizers/data/meta_optimizer.cc | 3 +++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index bdc36c97e59..519f689b278 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -635,6 +635,7 @@ cc_library( "@com_google_absl//absl/strings", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/optimizers:arithmetic_optimizer", + "//tensorflow/core/grappler/optimizers:common_subgraph_elimination", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler/optimizers:dependency_optimizer", diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc index 150a44f3035..39b59a229df 100644 --- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" +#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/function_optimizer.h" @@ -174,6 +175,8 @@ Status TFDataMetaOptimizer::Init( enabled_optimizers_["function"] = MakeUnique( RewriterConfig::ON, /*lower_control_flow=*/true); enabled_optimizers_["shape"] = MakeUnique(); + enabled_optimizers_["common_subgraph_elimination"] = + MakeUnique(); enabled_optimizers_["arithmetic"] = MakeUnique(); enabled_optimizers_["dependency"] = MakeUnique(); From cfbf43aa1f76f3889acbacd6cd649d0cfd6f079a Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 18 Mar 2020 17:39:11 -0700 Subject: [PATCH 189/492] Build more kernels into windows packages. BUILD issues around these seem to be resolved now. PiperOrigin-RevId: 301705450 Change-Id: I9ef82be3d3b3906dba520918aac4ad55a811c83a --- tensorflow/core/BUILD | 4 +--- tensorflow/core/kernels/BUILD | 16 ++-------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 9fafb6f8e47..df502b675b0 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1027,6 +1027,7 @@ cc_library( "//tensorflow/core/kernels:data_flow", "//tensorflow/core/kernels:decode_proto_op", "//tensorflow/core/kernels:encode_proto_op", + "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:functional_ops", @@ -1070,9 +1071,6 @@ cc_library( "//tensorflow/core/kernels:word2vec_kernels", "//tensorflow/core/kernels/sparse:kernels", ] + if_not_windows([ - "//tensorflow/core/kernels:fact_op", - "//tensorflow/core/kernels:array_not_windows", - "//tensorflow/core/kernels:math_not_windows", "//tensorflow/core/kernels/neon:neon_depthwise_conv_op", ]) + if_mkl([ "//tensorflow/core/kernels:mkl_aggregate_ops", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index fdb55e8c928..0477d260e10 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -984,13 +984,6 @@ ARRAY_DEPS = [ "//third_party/eigen3", ] + if_sycl(["//tensorflow/core:sycl_runtime"]) -cc_library( - name = "array_not_windows", - deps = [ - ":immutable_constant_op", - ], -) - tf_kernel_library( name = "immutable_constant_op", prefix = "immutable_constant_op", @@ -1036,6 +1029,7 @@ cc_library( ":host_constant_op", ":identity_n_op", ":identity_op", + ":immutable_constant_op", ":inplace_ops", ":listdiff_op", ":matrix_band_part_op", @@ -3922,13 +3916,6 @@ MATH_DEPS = [ "//third_party/eigen3", ] -cc_library( - name = "math_not_windows", - deps = [ - ":sparse_matmul_op", - ], -) - tf_kernel_library( name = "sparse_matmul_op", defines = select({ @@ -3967,6 +3954,7 @@ cc_library( ":scan_ops", ":segment_reduction_ops", ":sequence_ops", + ":sparse_matmul_op", "//tensorflow/core/kernels/special_math:special_math_op", ], ) From ed02dd2ff2646180273f177bbc85c93dc09a3241 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 18:01:55 -0700 Subject: [PATCH 190/492] Typo fix `def call(inputs, self)` -> `def call(self, inputs)` in `Layer.add_loss` documentation. 
PiperOrigin-RevId: 301708682 Change-Id: Ia18b9034c4d228343a63a1c0f6572fe6f76907b0 --- tensorflow/python/keras/engine/base_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index a6cb30aa181..8ae529fbfcb 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -1185,7 +1185,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ```python class MyLayer(tf.keras.layers.Layer): - def call(inputs, self): + def call(self, inputs): self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True) return inputs ``` From e64e3cd7014e93014349d68c9198cb651f56224d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 18:07:37 -0700 Subject: [PATCH 191/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301709885 Change-Id: I9aba334137004bdcd1e0820c4fee9a4215c163c2 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 3d05bb08fa3..7be0c66548c 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11815,7 +11815,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12072,7 +12072,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12083,7 +12083,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12301,7 +12301,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12312,7 +12312,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19153,7 +19153,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20224,7 +20224,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21396,7 +21396,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22104,7 +22104,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22300,7 +22300,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22484,7 +22484,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22543,7 +22543,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22717,7 +22717,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23098,7 +23098,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25441,7 +25441,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25504,7 +25504,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25747,7 +25747,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26370,7 +26370,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45499,7 +45499,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46287,7 +46287,7 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46350,7 +46350,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 2eb785c3fce905ff155d9ca49eb5e1d3ac9979ba Mon Sep 17 00:00:00 2001 From: Rick Chao Date: Wed, 18 Mar 2020 18:14:13 -0700 Subject: [PATCH 192/492] Keras api doc fixit: Supplement with more information and testable example in LearningRateScheduler callback. PiperOrigin-RevId: 301710883 Change-Id: Ie1bf9e21de4c16cd44ec9e42059a4d39f072cf38 --- tensorflow/python/keras/callbacks.py | 50 ++++++++++++++++++---------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 734b833fd62..3f9a3fd684b 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1562,28 +1562,42 @@ class RemoteMonitor(Callback): class LearningRateScheduler(Callback): """Learning rate scheduler. - Arguments: + At the beginning of every epoch, this callback gets the learning rate + value from `schedule` function provided at `__init__`, with the current epoch, + and applies that learning rate on the optimizer. + + Example: + + >>> # This function keeps the learning rate at 0.001 for the first ten epochs + >>> # and decreases it exponentially after that. + >>> def scheduler(epoch): + ... if epoch < 10: + ... return 0.001 + ... else: + ... 
return 0.001 * tf.math.exp(0.1 * (10 - epoch)) + >>> + >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) + >>> model.compile(tf.keras.optimizers.SGD(), loss='mse') + >>> round(model.optimizer.lr.numpy(), 5) + 0.01 + + >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler) + >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), + ... epochs=2, callbacks=[callback], verbose=0) + >>> round(model.optimizer.lr.numpy(), 5) + 0.001 + + """ + + def __init__(self, schedule, verbose=0): + """Initialize a `keras.callbacks.LearningRateScheduler` callback. + + Arguments: schedule: a function that takes an epoch index as input (integer, indexed from 0) and returns a new learning rate as output (float). verbose: int. 0: quiet, 1: update messages. - - ```python - # This function keeps the learning rate at 0.001 for the first ten epochs - # and decreases it exponentially after that. - def scheduler(epoch): - if epoch < 10: - return 0.001 - else: - return 0.001 * tf.math.exp(0.1 * (10 - epoch)) - - callback = tf.keras.callbacks.LearningRateScheduler(scheduler) - model.fit(data, labels, epochs=100, callbacks=[callback], - validation_data=(val_data, val_labels)) - ``` - """ - - def __init__(self, schedule, verbose=0): + """ super(LearningRateScheduler, self).__init__() self.schedule = schedule self.verbose = verbose From a6cdce91ccd875b8e7c3efbf967f16bbc5a0b48e Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 18 Mar 2020 18:35:46 -0700 Subject: [PATCH 193/492] Standardize name scopes used during model construction. PiperOrigin-RevId: 301713794 Change-Id: Ifa309e22955183968ad51c6989be5356b8266cc1 --- tensorflow/python/keras/engine/base_layer.py | 11 ++++++++- .../python/keras/engine/base_layer_test.py | 24 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 8ae529fbfcb..66de8e7bd5f 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2083,7 +2083,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._dtype_policy = policy.Policy(value) def _name_scope(self): - return self.name + name_scope = self.name + current_name_scope = ops.get_name_scope() + if current_name_scope: + name_scope = current_name_scope + '/' + name_scope + if name_scope: + # Note that the trailing `/` prevents autogenerated + # numerical suffixes to get appended. It will also fully reset + # nested name scope (i.e. the outer name scope has no effect). 
+ name_scope += '/' + return name_scope def _init_set_name(self, name, zero_based=True): if not name: diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 94766fe177a..1999f313d6b 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -936,6 +936,30 @@ class NameScopingTest(keras_parameterized.TestCase): self.assertEqual(layer.bias.name, 'MyName/bias:0') self.assertEqual(layer.kernel.name, 'MyName/kernel:0') + def test_name_scope_functional_api(self): + inputs = input_layer.Input((3,)) + layer = layers.Dense(10, name='MyName') + _ = layer(inputs) + self.assertEqual(layer.bias.name, 'MyName/bias:0') + self.assertEqual(layer.kernel.name, 'MyName/kernel:0') + + def test_name_scope_functional_api_nested(self): + + class NestedLayer(base_layer.Layer): + + def __init__(self, name='OuterName'): + super(NestedLayer, self).__init__(name=name) + self.dense = layers.Dense(10, name='InnerName') + + def call(self, inputs): + return self.dense(inputs) + + inputs = input_layer.Input((3,)) + layer = NestedLayer() + _ = layer(inputs) + self.assertEqual(layer.dense.bias.name, 'OuterName/InnerName/bias:0') + self.assertEqual(layer.dense.kernel.name, 'OuterName/InnerName/kernel:0') + def test_name_scope_sublayer(self): class NameScopeTracker(base_layer.Layer): From e377b6dbcfa480d10493fd5467ad8d127a0c6af8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 19:13:05 -0700 Subject: [PATCH 194/492] Remove the TENSORFLOW_MEM_DEBUG compilation flag from the path that passes TF op name etc. to BFCAllocator, i.e. enable the passing by default. PiperOrigin-RevId: 301718627 Change-Id: I80d75f1d7141b80f3454b79fbb4befe60b2d6d8c --- tensorflow/core/common_runtime/bfc_allocator.cc | 2 -- .../core/common_runtime/eager/eager_operation.cc | 2 -- .../core/common_runtime/eager/eager_operation.h | 2 -- tensorflow/core/framework/allocator.cc | 2 -- tensorflow/core/framework/allocator.h | 14 ++------------ 5 files changed, 2 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 1100ba9684c..df2bec93f0c 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -460,9 +460,7 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name, ",bytes_available=", bytes_available, ",peak_bytes_in_use=", stats.peak_bytes_in_use, ",requested_bytes=", requested_bytes, -#ifdef TENSORFLOW_MEM_DEBUG ",tf_op=", pending_op_name, ",id=", pending_step_id, -#endif "#"); }, traceme_level); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 94b85a190c1..7c4d04646a7 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -57,9 +57,7 @@ Status EagerOperation::Reset( cancellation_manager_ = nullptr; executor_ = executor ? 
executor : &ctx_.Executor(); remote_func_params_ = remote_func_params; -#ifdef TENSORFLOW_MEM_DEBUG op_name_ = op; -#endif return SetDeviceName(raw_device_name, true); } diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 4b46fc5c709..3e3474d6b61 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -121,10 +121,8 @@ class EagerOperation { return remote_func_params_; } -#ifdef TENSORFLOW_MEM_DEBUG const char* op_name() const { return op_name_; } const char* op_name_ = nullptr; -#endif Status MaybeInferSingleInputAttrs(TensorHandle* handle); Status InferInputListAttrs(int num_inputs); diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 6757a9b593e..7224aa8051b 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -27,10 +27,8 @@ limitations under the License. namespace tensorflow { -#ifdef TENSORFLOW_MEM_DEBUG thread_local const char* pending_op_name = nullptr; thread_local uint64 pending_step_id = 0; -#endif string AllocatorStats::DebugString() const { return strings::Printf( diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 2e239a4d6de..609fe716180 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -62,9 +62,8 @@ struct AllocationAttributes { TF_DISALLOW_COPY_AND_ASSIGN(AllocationAttributes); }; -// If defined, the runtime will cache Op names in thread-local memory -// and some allocators will try to tag allocations with the requesting Op. -#ifdef TENSORFLOW_MEM_DEBUG +// The runtime will cache Op names in thread-local memory and some allocators +// will try to tag allocations with the requesting Op. extern thread_local const char* pending_op_name; extern thread_local uint64 pending_step_id; #define MEMDEBUG_CACHE_OP(N) \ @@ -76,15 +75,6 @@ extern thread_local uint64 pending_step_id; pending_step_id = (N); \ } while (0) #define MEMDEBUG_CACHE_VAL pending_op_name -#else -#define MEMDEBUG_CACHE_OP(N) \ - do { \ - } while (0) -#define MEMDEBUG_CACHE_STEPID(N) \ - do { \ - } while (0) -#define MEMDEBUG_CACHE_VAL nullptr -#endif // Runtime statistics collected by an allocator. Exactly the same as // stream_executor::AllocatorStats, but independently defined to preserve the From 7fe72602b44f2a2b04fea25a849694b939139329 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 19:15:11 -0700 Subject: [PATCH 195/492] Standardize name scopes used during model construction. 
PiperOrigin-RevId: 301718829 Change-Id: I09d0ffe16b08c369b864290c9e33ebc3b0d85edb --- tensorflow/python/keras/engine/base_layer.py | 11 +-------- .../python/keras/engine/base_layer_test.py | 24 ------------------- 2 files changed, 1 insertion(+), 34 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 66de8e7bd5f..8ae529fbfcb 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2083,16 +2083,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._dtype_policy = policy.Policy(value) def _name_scope(self): - name_scope = self.name - current_name_scope = ops.get_name_scope() - if current_name_scope: - name_scope = current_name_scope + '/' + name_scope - if name_scope: - # Note that the trailing `/` prevents autogenerated - # numerical suffixes to get appended. It will also fully reset - # nested name scope (i.e. the outer name scope has no effect). - name_scope += '/' - return name_scope + return self.name def _init_set_name(self, name, zero_based=True): if not name: diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 1999f313d6b..94766fe177a 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -936,30 +936,6 @@ class NameScopingTest(keras_parameterized.TestCase): self.assertEqual(layer.bias.name, 'MyName/bias:0') self.assertEqual(layer.kernel.name, 'MyName/kernel:0') - def test_name_scope_functional_api(self): - inputs = input_layer.Input((3,)) - layer = layers.Dense(10, name='MyName') - _ = layer(inputs) - self.assertEqual(layer.bias.name, 'MyName/bias:0') - self.assertEqual(layer.kernel.name, 'MyName/kernel:0') - - def test_name_scope_functional_api_nested(self): - - class NestedLayer(base_layer.Layer): - - def __init__(self, name='OuterName'): - super(NestedLayer, self).__init__(name=name) - self.dense = layers.Dense(10, name='InnerName') - - def call(self, inputs): - return self.dense(inputs) - - inputs = input_layer.Input((3,)) - layer = NestedLayer() - _ = layer(inputs) - self.assertEqual(layer.dense.bias.name, 'OuterName/InnerName/bias:0') - self.assertEqual(layer.dense.kernel.name, 'OuterName/InnerName/kernel:0') - def test_name_scope_sublayer(self): class NameScopeTracker(base_layer.Layer): From 60fcd57c93dd0bb8a1683338127480118c70d9bb Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Wed, 18 Mar 2020 19:41:09 -0700 Subject: [PATCH 196/492] [tf.data service] Add compression utils for dataset elements. 
PiperOrigin-RevId: 301721552 Change-Id: I359ffe13df37e53e8cc217dbc05fa696ddbc1f35 --- tensorflow/core/data/service/BUILD | 37 +++++ .../core/data/service/compression_utils.cc | 151 ++++++++++++++++++ .../core/data/service/compression_utils.h | 40 +++++ .../data/service/compression_utils_test.cc | 55 +++++++ 4 files changed, 283 insertions(+) create mode 100644 tensorflow/core/data/service/compression_utils.cc create mode 100644 tensorflow/core/data/service/compression_utils.h create mode 100644 tensorflow/core/data/service/compression_utils_test.cc diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 6003362406f..68c0f2d47d7 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -4,6 +4,10 @@ load( "tf_additional_all_protos", "tf_proto_library", ) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) package( default_visibility = [ @@ -41,6 +45,39 @@ tf_proto_library( ], ) +cc_library( + name = "compression_utils", + srcs = ["compression_utils.cc"], + hdrs = [ + "compression_utils.h", + ], + deps = [ + ":common_proto_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "compression_utils_test", + srcs = ["compression_utils_test.cc"], + deps = [ + ":compression_utils", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels/data:dataset_test_base", + ], +) + cc_grpc_library( name = "master_cc_grpc_proto", srcs = [":master_proto"], diff --git a/tensorflow/core/data/service/compression_utils.cc b/tensorflow/core/data/service/compression_utils.cc new file mode 100644 index 00000000000..c4a47e1b00e --- /dev/null +++ b/tensorflow/core/data/service/compression_utils.cc @@ -0,0 +1,151 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/data/service/compression_utils.h" + +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/snappy.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +namespace tensorflow { +namespace data { +namespace service_util { + +Status Compress(const std::vector& element, CompressedElement* out) { + tensorflow::profiler::TraceMe activity( + "Compress", tensorflow::profiler::TraceMeLevel::kInfo); + + // Step 1: Determine the total uncompressed size. This requires serializing + // non-memcopyable tensors, which we save to use again later. 
+ std::vector non_memcpy_components; + int64 total_size = 0; + for (auto& component : element) { + if (DataTypeCanUseMemcpy(component.dtype())) { + // Some datatypes can be memcopied, allowing us to save two copies + // (AsProtoTensorContent and SerializeToArray). + total_size += DMAHelper::buffer(&component)->size(); + } else { + non_memcpy_components.emplace_back(); + component.AsProtoTensorContent(&non_memcpy_components.back()); + total_size += non_memcpy_components.back().ByteSizeLong(); + } + } + + // Step 2: Write the tensor data to a buffer, and compress that buffer. + // We use tstring for access to resize_uninitialized. + tstring uncompressed; + uncompressed.resize_uninitialized(total_size); + // Position in `uncompressed` to write the next component. + char* position = uncompressed.mdata(); + int non_memcpy_component_index = 0; + for (auto& component : element) { + ComponentMetadata* metadata = out->mutable_component_metadata()->Add(); + metadata->set_dtype(component.dtype()); + component.shape().AsProto(metadata->mutable_tensor_shape()); + if (DataTypeCanUseMemcpy(component.dtype())) { + const TensorBuffer* buffer = DMAHelper::buffer(&component); + memcpy(position, buffer->data(), buffer->size()); + metadata->set_tensor_size_bytes(buffer->size()); + } else { + TensorProto& proto = non_memcpy_components[non_memcpy_component_index++]; + proto.SerializeToArray(position, proto.ByteSizeLong()); + metadata->set_tensor_size_bytes(proto.ByteSizeLong()); + } + position += metadata->tensor_size_bytes(); + } + DCHECK_EQ(position, uncompressed.mdata() + total_size); + + if (!port::Snappy_Compress(uncompressed.mdata(), total_size, + out->mutable_data())) { + return errors::Internal("Failed to compress using snappy."); + } + return Status::OK(); +} + +Status Uncompress(const CompressedElement& compressed, + std::vector* out) { + tensorflow::profiler::TraceMe activity( + "Uncompress", tensorflow::profiler::TraceMeLevel::kInfo); + int num_components = compressed.component_metadata_size(); + out->clear(); + out->reserve(num_components); + + // Step 1: Prepare the memory that we will uncompress into. + std::vector iov(num_components); + // We use tstring for access to resize_uninitialized. + std::vector tensor_proto_strs; + // num_components is a conservative estimate. It is important to reserve + // vector space so that the vector doesn't resize itself, which could + // invalidate pointers to its strings' data. + tensor_proto_strs.reserve(num_components); + int64 total_size = 0; + for (int i = 0; i < num_components; ++i) { + const ComponentMetadata& metadata = compressed.component_metadata(i); + if (DataTypeCanUseMemcpy(metadata.dtype())) { + out->emplace_back(metadata.dtype(), metadata.tensor_shape()); + TensorBuffer* buffer = DMAHelper::buffer(&out->back()); + iov[i].iov_base = buffer->data(); + iov[i].iov_len = buffer->size(); + } else { + // Allocate an empty Tensor. We will fill it out later after + // uncompressing into the tensor_proto_str. + out->emplace_back(); + tensor_proto_strs.emplace_back(); + tstring& tensor_proto_str = tensor_proto_strs.back(); + tensor_proto_str.resize_uninitialized(metadata.tensor_size_bytes()); + iov[i].iov_base = tensor_proto_str.mdata(); + iov[i].iov_len = tensor_proto_str.size(); + } + total_size += iov[i].iov_len; + } + + // Step 2: Uncompress into the iovec. 
+ const std::string& compressed_data = compressed.data(); + size_t uncompressed_size; + if (!port::Snappy_GetUncompressedLength( + compressed_data.data(), compressed_data.size(), &uncompressed_size)) { + return errors::Internal("Could not get snappy uncompressed length"); + } + if (uncompressed_size != total_size) { + return errors::Internal( + "Uncompressed size mismatch. Snappy expects ", uncompressed_size, + " whereas the tensor metadata suggests ", total_size); + } + if (!port::Snappy_UncompressToIOVec(compressed_data.data(), + compressed_data.size(), iov.data(), + num_components)) { + return errors::Internal("Failed to perform snappy decompression."); + } + + // Step 3: Deserialize tensor proto strings to tensors. + int tensor_proto_strs_index = 0; + for (int i = 0; i < num_components; ++i) { + if (DataTypeCanUseMemcpy(compressed.component_metadata(i).dtype())) { + continue; + } + TensorProto tp; + if (!tp.ParseFromString(tensor_proto_strs[tensor_proto_strs_index++])) { + return errors::Internal("Could not parse TensorProto"); + } + if (!out->at(i).FromProto(tp)) { + return errors::Internal("Could not parse Tensor"); + } + } + return Status::OK(); +} + +} // namespace service_util +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/compression_utils.h b/tensorflow/core/data/service/compression_utils.h new file mode 100644 index 00000000000..96698aaaf09 --- /dev/null +++ b/tensorflow/core/data/service/compression_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_ + +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { +namespace service_util { + +// Compresses the components of `element` into the `CompressedElement` proto. +// +// In addition to writing the actual compressed bytes, `Compress` fills +// out the per-component metadata for the `CompressedElement`. +Status Compress(const std::vector& element, CompressedElement* out); + +// Uncompresses a `CompressedElement` into a vector of tensor components. +Status Uncompress(const CompressedElement& compressed, + std::vector* out); + +} // namespace service_util +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_ diff --git a/tensorflow/core/data/service/compression_utils_test.cc b/tensorflow/core/data/service/compression_utils_test.cc new file mode 100644 index 00000000000..b5da13efeed --- /dev/null +++ b/tensorflow/core/data/service/compression_utils_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/data/service/compression_utils.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace service_util { + +class ParameterizedCompressionUtilsTest + : public DatasetOpsTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) { + std::vector element = GetParam(); + CompressedElement compressed; + TF_ASSERT_OK(Compress(element, &compressed)); + std::vector round_trip_element; + TF_ASSERT_OK(Uncompress(compressed, &round_trip_element)); + TF_EXPECT_OK( + ExpectEqual(element, round_trip_element, /*compare_order=*/true)); +} + +std::vector> TestCases() { + return { + CreateTensors(TensorShape{1}, {{1}}), // int64 + CreateTensors(TensorShape{1}, {{1}, {2}}), // multiple int64 + CreateTensors(TensorShape{1}, {{"a"}, {"b"}}), // tstring + {CreateTensor(TensorShape{1}, {"a"}), + CreateTensor(TensorShape{1}, {1})}, // mixed tstring/int64 + {}, // empty + }; +} + +INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest, + ::testing::ValuesIn(TestCases())); + +} // namespace service_util +} // namespace data +} // namespace tensorflow From a90d94d4fc0882b7c371216ba263bf2d4baadb97 Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Wed, 18 Mar 2020 19:43:20 -0700 Subject: [PATCH 197/492] Add `get_tpu_system_metadata` API to TPUClusterResolver. Also export `tf.tpu.experimental.TPUSystemMetadata` and `tf.tpu.experimental.Topology` symbols. 
PiperOrigin-RevId: 301721761 Change-Id: I765e04f0e8cb3e2f556b3486d6ee692dcb0456ac --- tensorflow/python/distribute/BUILD | 1 - .../python/distribute/cluster_resolver/BUILD | 1 + .../cluster_resolver/cluster_resolver_test.py | 6 +-- .../gce_cluster_resolver_test.py | 4 +- .../kubernetes_cluster_resolver_test.py | 2 +- .../slurm_cluster_resolver_test.py | 5 +- .../tfconfig_cluster_resolver_test.py | 2 +- .../cluster_resolver/tpu_cluster_resolver.py | 11 ++++ tensorflow/python/distribute/tpu_strategy.py | 19 +------ tensorflow/python/eager/BUILD | 1 + tensorflow/python/eager/remote_test.py | 2 +- tensorflow/python/tpu/topology.py | 2 + tensorflow/python/tpu/tpu_strategy_util.py | 2 +- tensorflow/python/tpu/tpu_system_metadata.py | 39 ++++++++++---- ...ter_resolver.-t-p-u-cluster-resolver.pbtxt | 4 ++ ....experimental.-t-p-u-system-metadata.pbtxt | 35 ++++++++++++ ...ensorflow.tpu.experimental.-topology.pbtxt | 53 +++++++++++++++++++ .../v1/tensorflow.tpu.experimental.pbtxt | 8 +++ ...ter_resolver.-t-p-u-cluster-resolver.pbtxt | 4 ++ ....experimental.-t-p-u-system-metadata.pbtxt | 35 ++++++++++++ ...ensorflow.tpu.experimental.-topology.pbtxt | 53 +++++++++++++++++++ .../v2/tensorflow.tpu.experimental.pbtxt | 8 +++ 22 files changed, 257 insertions(+), 40 deletions(-) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-t-p-u-system-metadata.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-topology.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.-t-p-u-system-metadata.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.-topology.pbtxt diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index fd4895e6e02..459cfb6b1bf 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -164,7 +164,6 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/data", - "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/ops/losses", "//tensorflow/python/ops/losses:loss_reduction", "//tensorflow/tools/docs:doc_controls", diff --git a/tensorflow/python/distribute/cluster_resolver/BUILD b/tensorflow/python/distribute/cluster_resolver/BUILD index 3c105758527..1a9d0202837 100644 --- a/tensorflow/python/distribute/cluster_resolver/BUILD +++ b/tensorflow/python/distribute/cluster_resolver/BUILD @@ -67,6 +67,7 @@ py_library( deps = [ ":base_cluster_resolver_py", "//tensorflow/python:training_server_lib", + "//tensorflow/python/tpu:tpu_lib", "//tensorflow/python/tpu/client", ] + tf_additional_rpc_deps(), ) diff --git a/tensorflow/python/distribute/cluster_resolver/cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/cluster_resolver_test.py index d4ebd2c8e14..225166c28aa 100644 --- a/tensorflow/python/distribute/cluster_resolver/cluster_resolver_test.py +++ b/tensorflow/python/distribute/cluster_resolver/cluster_resolver_test.py @@ -20,9 +20,9 @@ from __future__ import print_function from tensorflow.python import framework from tensorflow.python.client import session -from tensorflow.python.distribute.cluster_resolver import ClusterResolver -from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver -from tensorflow.python.distribute.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from 
tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver from tensorflow.python.eager.context import LogicalDevice from tensorflow.python.framework import test_util from tensorflow.python.platform import test diff --git a/tensorflow/python/distribute/cluster_resolver/gce_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/gce_cluster_resolver_test.py index 47d1cdc0da9..f39c86a0495 100644 --- a/tensorflow/python/distribute/cluster_resolver/gce_cluster_resolver_test.py +++ b/tensorflow/python/distribute/cluster_resolver/gce_cluster_resolver_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.distribute.cluster_resolver import GCEClusterResolver -from tensorflow.python.distribute.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver from tensorflow.python.platform import test from tensorflow.python.training import server_lib diff --git a/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver_test.py index f4e4cd82129..598c3da4642 100644 --- a/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver_test.py +++ b/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.distribute.cluster_resolver import KubernetesClusterResolver +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver from tensorflow.python.platform import test from tensorflow.python.training import server_lib diff --git a/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver_test.py index 07f9e81994a..0c2eb7d5254 100644 --- a/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver_test.py +++ b/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver_test.py @@ -20,8 +20,9 @@ from __future__ import print_function import os -from tensorflow.python.distribute.cluster_resolver import SlurmClusterResolver -from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import expand_hostlist, expand_tasks_per_node +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import expand_hostlist +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import expand_tasks_per_node +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver from tensorflow.python.platform import test from tensorflow.python.training import server_lib diff --git a/tensorflow/python/distribute/cluster_resolver/tfconfig_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/tfconfig_cluster_resolver_test.py index 2989e24c284..08ae0c08e3a 100644 --- a/tensorflow/python/distribute/cluster_resolver/tfconfig_cluster_resolver_test.py +++ b/tensorflow/python/distribute/cluster_resolver/tfconfig_cluster_resolver_test.py @@ -22,7 +22,7 @@ import os from 
tensorflow.python import framework from tensorflow.python.client import session -from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver from tensorflow.python.eager.context import LogicalDevice from tensorflow.python.framework import test_util from tensorflow.python.platform import test diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py index 2f874de1b87..a1e95fc380d 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py @@ -24,6 +24,7 @@ import re from tensorflow.python.distribute.cluster_resolver import cluster_resolver from tensorflow.python.framework import errors from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -219,6 +220,16 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver): def get_job_name(self): return self.task_type + def get_tpu_system_metadata(self): + """Retrieves TPU system metadata given a TPUClusterResolver.""" + cluster_spec = self.cluster_spec() + cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None + tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata( # pylint: disable=protected-access + self.master(), + cluster_def=cluster_def, + query_topology=False)) + + return tpu_system_metadata + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information.
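For reference, a minimal usage sketch of the new resolver method added above. The TPU address is a placeholder and the printed attributes come from the `TPUSystemMetadata` tuple defined later in this patch, so treat the snippet as illustrative rather than part of the change:

    # Hypothetical usage sketch; assumes a reachable TPU worker at the
    # (placeholder) address below.
    import tensorflow as tf

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='grpc://10.0.0.1:8470')  # placeholder address, not part of this patch
    metadata = resolver.get_tpu_system_metadata()
    print(metadata.num_cores)              # total TPU cores in the system
    print(metadata.num_hosts)              # number of TPU workers
    print(metadata.num_of_cores_per_host)  # cores per worker
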
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 2a216118f22..a45ac5785ee 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -53,29 +53,12 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import from tensorflow.python.tpu import tpu from tensorflow.python.tpu import tpu_strategy_util -from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.python.tpu import training_loop from tensorflow.python.tpu.ops import tpu_ops from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export -def get_tpu_system_metadata(tpu_cluster_resolver): - """Retrieves TPU system metadata given a TPUClusterResolver.""" - master = tpu_cluster_resolver.master() - - # pylint: disable=protected-access - cluster_spec = tpu_cluster_resolver.cluster_spec() - cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None - tpu_system_metadata = ( - tpu_system_metadata_lib._query_tpu_system_metadata( - master, - cluster_def=cluster_def, - query_topology=False)) - - return tpu_system_metadata - - @contextlib.contextmanager def maybe_init_scope(): if ops.executing_eagerly_outside_functions(): @@ -287,7 +270,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): self._tpu_function_cache = weakref.WeakKeyDictionary() self._tpu_cluster_resolver = tpu_cluster_resolver - self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) + self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata() self._device_assignment = device_assignment tpu_devices_flat = [ diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 5696915dc80..55bc5942253 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -843,6 +843,7 @@ py_library( ":context", "//tensorflow/core:protos_all_py", "//tensorflow/python:platform", + "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py", ], ) diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index ce560db2f14..e0a9523ef57 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -25,7 +25,7 @@ import numpy as np import six from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import remote diff --git a/tensorflow/python/tpu/topology.py b/tensorflow/python/tpu/topology.py index de233949a60..9c7941fb896 100644 --- a/tensorflow/python/tpu/topology.py +++ b/tensorflow/python/tpu/topology.py @@ -22,6 +22,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.protobuf.tpu import topology_pb2 +from tensorflow.python.util.tf_export import tf_export def _tpu_device_name(job, task, device): @@ -40,6 +41,7 @@ def _tpu_host_device_name(job, task): return "/job:%s/task:%d/device:CPU:0" % (job, task) +@tf_export("tpu.experimental.Topology") class Topology(object): """Describes a set of TPU devices. 
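With the `tf_export` decoration above, the class becomes reachable as `tf.tpu.experimental.Topology`. A minimal sketch of what that enables, assuming `tf.tpu.experimental.initialize_tpu_system` returns the system's `Topology` (not shown in this patch) and again using a placeholder TPU address:

    # Hypothetical sketch; the address and the assumption about
    # initialize_tpu_system's return value are not part of this patch.
    import tensorflow as tf

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='grpc://10.0.0.1:8470')  # placeholder address
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    assert isinstance(topology, tf.tpu.experimental.Topology)
    print(topology.mesh_shape)         # physical mesh shape of the TPU system
    print(topology.num_tpus_per_task)  # TPU devices attached to each worker
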
diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py index e28aea87f96..543c91167cd 100644 --- a/tensorflow/python/tpu/tpu_strategy_util.py +++ b/tensorflow/python/tpu/tpu_strategy_util.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_lib -from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import device diff --git a/tensorflow/python/tpu/tpu_system_metadata.py b/tensorflow/python/tpu/tpu_system_metadata.py index fcfb4bf68a6..4becb13f3aa 100644 --- a/tensorflow/python/tpu/tpu_system_metadata.py +++ b/tensorflow/python/tpu/tpu_system_metadata.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.tpu import tpu +from tensorflow.python.util.tf_export import tf_export _PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000 # 10 min _RETRY_TIMES = 12 * 24 # 1 day @@ -39,15 +40,33 @@ _DEFAULT_JOB_NAME = 'tpu_worker' _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' _LOCAL_MASTERS = ('', 'local') -# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration, -# including num_cores and num_hosts. -_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ - 'num_cores', - 'num_hosts', - 'num_of_cores_per_host', - 'topology', - 'devices', -]) + +@tf_export('tpu.experimental.TPUSystemMetadata') +class TPUSystemMetadata( + collections.namedtuple('TPUSystemMetadata', [ + 'num_cores', + 'num_hosts', + 'num_of_cores_per_host', + 'topology', + 'devices', + ])): + """Describes some metadata about the TPU system. + + Attributes: + num_cores: integer. Total number of TPU cores in the TPU system. + num_hosts: integer. Total number of hosts (TPU workers) in the TPU system. + num_of_cores_per_host: integer. Number of TPU cores per host (TPU worker). + topology: an instance of `tf.tpu.experimental.Topology`, which describes the + physical topology of the TPU system. + devices: a tuple of strings, which describes all the TPU devices in the + system.
+ """ + + def __new__(cls, num_cores, num_hosts, num_of_cores_per_host, topology, + devices): + return super(TPUSystemMetadata, + cls).__new__(cls, num_cores, num_hosts, num_of_cores_per_host, + topology, devices) def _query_tpu_system_metadata(master_address, cluster_def=None, @@ -129,7 +148,7 @@ def _query_tpu_system_metadata(master_address, cluster_def=None, spec.device_index) devices = tuple(sorted(devices, key=_sort_key)) - metadata = _TPUSystemMetadata( + metadata = TPUSystemMetadata( num_cores=tpu_core_count, num_hosts=len(device_dict), num_of_cores_per_host=num_of_cores_per_host, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt index dbc76c24813..c0dc0054165 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "get_master" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_tpu_system_metadata" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "master" argspec: "args=[\'self\', \'task_type\', \'task_id\', \'rpc_layer\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-t-p-u-system-metadata.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-t-p-u-system-metadata.pbtxt new file mode 100644 index 00000000000..4c0c2252b6d --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-t-p-u-system-metadata.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.tpu.experimental.TPUSystemMetadata" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "devices" + mtype: "" + } + member { + name: "num_cores" + mtype: "" + } + member { + name: "num_hosts" + mtype: "" + } + member { + name: "num_of_cores_per_host" + mtype: "" + } + member { + name: "topology" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-topology.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-topology.pbtxt new file mode 100644 index 00000000000..e8a2ca82ca0 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-topology.pbtxt @@ -0,0 +1,53 @@ +path: "tensorflow.tpu.experimental.Topology" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "device_coordinates" + mtype: "" + } + member { + name: "mesh_rank" + mtype: "" + } + member { + name: "mesh_shape" + mtype: "" + } + member { + name: "missing_devices" + mtype: "" + } + member { + name: "num_tasks" + mtype: "" + } + member { + name: "num_tpus_per_task" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'serialized\', \'mesh_shape\', \'device_coordinates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "cpu_device_name_at_coordinates" + argspec: "args=[\'self\', \'device_coordinates\', \'job\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialized" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + 
} + member_method { + name: "task_ordinal_at_coordinates" + argspec: "args=[\'self\', \'device_coordinates\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tpu_device_name_at_coordinates" + argspec: "args=[\'self\', \'device_coordinates\', \'job\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tpu_device_ordinal_at_coordinates" + argspec: "args=[\'self\', \'device_coordinates\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.pbtxt index f4a5a71ada7..ef1c8078cca 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.pbtxt @@ -20,6 +20,14 @@ tf_module { name: "StochasticGradientDescentParameters" mtype: "" } + member { + name: "TPUSystemMetadata" + mtype: "" + } + member { + name: "Topology" + mtype: "" + } member_method { name: "embedding_column" argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'max_sequence_length\', \'learning_rate_fn\', \'embedding_lookup_device\', \'tensor_core_shape\', \'use_safe_embedding_lookup\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'0\', \'None\', \'None\', \'None\', \'True\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt index dbc76c24813..c0dc0054165 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "get_master" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_tpu_system_metadata" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "master" argspec: "args=[\'self\', \'task_type\', \'task_id\', \'rpc_layer\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.-t-p-u-system-metadata.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.-t-p-u-system-metadata.pbtxt new file mode 100644 index 00000000000..4c0c2252b6d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.-t-p-u-system-metadata.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.tpu.experimental.TPUSystemMetadata" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "devices" + mtype: "" + } + member { + name: "num_cores" + mtype: "" + } + member { + name: "num_hosts" + mtype: "" + } + member { + name: "num_of_cores_per_host" + mtype: "" + } + member { + name: "topology" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.-topology.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.-topology.pbtxt new file mode 100644 index 00000000000..e8a2ca82ca0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.-topology.pbtxt @@ -0,0 +1,53 @@ +path: "tensorflow.tpu.experimental.Topology" +tf_class { + is_instance: "" + is_instance: "" + 
member { + name: "device_coordinates" + mtype: "" + } + member { + name: "mesh_rank" + mtype: "" + } + member { + name: "mesh_shape" + mtype: "" + } + member { + name: "missing_devices" + mtype: "" + } + member { + name: "num_tasks" + mtype: "" + } + member { + name: "num_tpus_per_task" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'serialized\', \'mesh_shape\', \'device_coordinates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "cpu_device_name_at_coordinates" + argspec: "args=[\'self\', \'device_coordinates\', \'job\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialized" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "task_ordinal_at_coordinates" + argspec: "args=[\'self\', \'device_coordinates\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tpu_device_name_at_coordinates" + argspec: "args=[\'self\', \'device_coordinates\', \'job\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tpu_device_ordinal_at_coordinates" + argspec: "args=[\'self\', \'device_coordinates\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.pbtxt index 0f6d2766b2e..df31799828c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.pbtxt @@ -4,6 +4,14 @@ tf_module { name: "DeviceAssignment" mtype: "" } + member { + name: "TPUSystemMetadata" + mtype: "" + } + member { + name: "Topology" + mtype: "" + } member_method { name: "initialize_tpu_system" argspec: "args=[\'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], " From 295c3c9f0780654c44729b75eb8db56614e6643b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 19:46:19 -0700 Subject: [PATCH 198/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301722070 Change-Id: Ic41563dc28d9df6951ebaee54d21885410c0c65b --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 7be0c66548c..3d05bb08fa3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11815,7 +11815,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12072,7 +12072,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. 
-// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12083,7 +12083,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12301,7 +12301,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12312,7 +12312,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19153,7 +19153,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20224,7 +20224,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21396,7 +21396,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22104,7 +22104,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22300,7 +22300,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22484,7 +22484,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22543,7 +22543,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22717,7 +22717,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23098,7 +23098,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25441,7 +25441,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25504,7 +25504,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25747,7 +25747,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26370,7 +26370,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45499,7 +45499,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46287,7 +46287,7 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46350,7 +46350,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 6c7e338ae7f0b0f2e224319de7e2165141c148fb Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Wed, 18 Mar 2020 19:47:04 -0700 Subject: [PATCH 199/492] Upgrade abseil-cpp to Abseil LTS branch, Feb 2020, Patch 1, i.e. commit df3ea785d8c30a9503321a3d35ee7d35808f190d. 
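One of the things the refreshed com_google_absl_fix_mac_and_nvcc_build.patch below keeps working around is an NVCC constexpr bug in absl::string_view: CUDA releases before 10.2 reject comma expressions such as `return ABSL_ASSERT(len <= kMaxSize), len;` as non-constant inside a constexpr function. A minimal, self-contained sketch of that pattern and its guard (hypothetical names, not Abseil code):

    #include <cassert>

    // Comma-expression form: evaluate the assertion, then yield the value.
    // Affected NVCC releases (< CUDA 10.2) wrongly treat this as non-constant
    // inside a constexpr function, so the guarded branch just returns the value.
    constexpr int CheckedLength(int len, int max_size) {
    #if defined(__NVCC__) && (__CUDACC_VER_MAJOR__ < 10 || \
        (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ < 2))
      return len;  // workaround: skip the assertion on affected compilers
    #else
      return assert(len <= max_size), len;
    #endif
    }

    int main() { return CheckedLength(3, 10) == 3 ? 0 : 1; }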
Unfortunately, commit e9324d926a9189e222741fce6e676f0944661a72 from June 21, 2019 includes a change that modifies absl/container/internal/compressed_tuple.h not compatible with CUDA on Windows. The aforementioned file is reverted to commit 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f which is the last known "safe" version. PiperOrigin-RevId: 301722155 Change-Id: I0dd338fd00ef25f8e1204907c6f5366f3e511759 --- tensorflow/workspace.bzl | 8 +- ...m_google_absl_fix_mac_and_nvcc_build.patch | 297 +++++++++++++++++- 2 files changed, 287 insertions(+), 18 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 3d29648c1ba..085ae961cd6 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -189,11 +189,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): # TODO: Remove the patch when https://github.com/abseil/abseil-cpp/issues/326 is resolved # and when TensorFlow is build against CUDA 10.2 patch_file = clean_dep("//third_party:com_google_absl_fix_mac_and_nvcc_build.patch"), - sha256 = "acd93f6baaedc4414ebd08b33bebca7c7a46888916101d8c0b8083573526d070", # SHARED_ABSL_SHA - strip_prefix = "abseil-cpp-43ef2148c0936ebf7cb4be6b19927a9d9d145b8f", + sha256 = "f368a8476f4e2e0eccf8a7318b98dafbe30b2600f4e3cf52636e5eb145aba06a", # SHARED_ABSL_SHA + strip_prefix = "abseil-cpp-df3ea785d8c30a9503321a3d35ee7d35808f190d", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/43ef2148c0936ebf7cb4be6b19927a9d9d145b8f.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/43ef2148c0936ebf7cb4be6b19927a9d9d145b8f.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz", ], ) diff --git a/third_party/com_google_absl_fix_mac_and_nvcc_build.patch b/third_party/com_google_absl_fix_mac_and_nvcc_build.patch index 038e618de44..271e941bfe8 100644 --- a/third_party/com_google_absl_fix_mac_and_nvcc_build.patch +++ b/third_party/com_google_absl_fix_mac_and_nvcc_build.patch @@ -1,6 +1,6 @@ --- ./absl/time/internal/cctz/BUILD.bazel 2019-09-23 13:20:52.000000000 -0700 +++ ./absl/time/internal/cctz/BUILD.bazel.fixed 2019-09-23 13:20:48.000000000 -0700 -@@ -76,15 +76,6 @@ +@@ -74,15 +74,6 @@ "include/cctz/time_zone.h", "include/cctz/zone_info_source.h", ], @@ -14,22 +14,291 @@ - "//conditions:default": [], - }), visibility = ["//visibility:public"], - deps = [":civil_time"], - ) + deps = [ + ":civil_time", --- ./absl/strings/string_view.h 2019-09-23 13:20:52.000000000 -0700 +++ ./absl/strings/string_view.h.fixed 2019-09-23 13:20:48.000000000 -0700 -@@ -492,7 +492,14 @@ - (std::numeric_limits::max)(); - - static constexpr size_type CheckLengthInternal(size_type len) { -+#if defined(__NVCC__) && (__CUDACC_VER_MAJOR__<10 || (__CUDACC_VER_MAJOR__==10 && __CUDACC_VER_MINOR__<2)) && !defined(NDEBUG) -+ // An nvcc bug treats the original return expression as a non-constant, -+ // which is not allowed in a constexpr function. This only happens when -+ // NDEBUG is not defined. This will be fixed in the CUDA 10.2 release. -+ return len; +@@ -283,7 +283,14 @@ + // Returns the ith element of the `string_view` using the array operator. + // Note that this operator does not perform any bounds checking. 
+ constexpr const_reference operator[](size_type i) const { ++#if defined(__NVCC__) && (__CUDACC_VER_MAJOR__ < 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ < 2)) ++ // An NVCC bug treats the original return expression as a non-constant, ++ // which is not allowed in a constexpr function. This will be fixed in the ++ // CUDA 10.2 release. ++ return ptr_[i]; +#else - return ABSL_ASSERT(len <= kMaxSize), len; + return ABSL_ASSERT(i < size()), ptr_[i]; +#endif } - const char* ptr_; + // string_view::at() +@@ -292,25 +299,46 @@ + // and an exception of type `std::out_of_range` will be thrown on invalid + // access. + constexpr const_reference at(size_type i) const { ++#if defined(__NVCC__) && (__CUDACC_VER_MAJOR__ < 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ < 2)) ++ // An NVCC bug treats the original return expression as a non-constant, ++ // which is not allowed in a constexpr function. This will be fixed in the ++ // CUDA 10.2 release. ++ return ptr_[i]; ++#else + return ABSL_PREDICT_TRUE(i < size()) + ? ptr_[i] + : ((void)base_internal::ThrowStdOutOfRange( + "absl::string_view::at"), + ptr_[i]); ++#endif + } + + // string_view::front() + // + // Returns the first element of a `string_view`. + constexpr const_reference front() const { ++#if defined(__NVCC__) && (__CUDACC_VER_MAJOR__ < 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ < 2)) ++ // An NVCC bug treats the original return expression as a non-constant, ++ // which is not allowed in a constexpr function. This will be fixed in the ++ // CUDA 10.2 release. ++ return ptr_[0]; ++#else + return ABSL_ASSERT(!empty()), ptr_[0]; ++#endif + } + + // string_view::back() + // + // Returns the last element of a `string_view`. + constexpr const_reference back() const { ++#if defined(__NVCC__) && (__CUDACC_VER_MAJOR__ < 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ < 2)) ++ // An NVCC bug treats the original return expression as a non-constant, ++ // which is not allowed in a constexpr function. This will be fixed in the ++ // CUDA 10.2 release. ++ return ptr_[size() - 1]; ++#else + return ABSL_ASSERT(!empty()), ptr_[size() - 1]; ++#endif + } + + // string_view::data() +@@ -519,7 +547,14 @@ + (std::numeric_limits::max)(); + + static constexpr size_type CheckLengthInternal(size_type len) { ++#if defined(__NVCC__) && (__CUDACC_VER_MAJOR__ < 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ < 2)) ++ // An NVCC bug treats the original return expression as a non-constant, ++ // which is not allowed in a constexpr function. This will be fixed in the ++ // CUDA 10.2 release. ++ return len; ++#else + return (void)ABSL_ASSERT(len <= kMaxSize), len; ++#endif + } + + static constexpr size_type StrlenInternal(const char* str) { +--- ./absl/container/internal/compressed_tuple.h 2020-03-04 12:57:37.000000000 -0800 ++++ ./absl/container/internal/compressed_tuple.h.fixed 2019-06-20 11:54:01.000000000 -0700 +@@ -32,7 +32,6 @@ Revert to commit 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f as commit e9324d926a9189e222741fce6e676f0944661a72 includes a change not compatible with CUDA on Windows. + #ifndef ABSL_CONTAINER_INTERNAL_COMPRESSED_TUPLE_H_ + #define ABSL_CONTAINER_INTERNAL_COMPRESSED_TUPLE_H_ + +-#include + #include + #include + #include +@@ -77,110 +76,61 @@ + #endif + } + +-// We can't use EBCO on other CompressedTuples because that would mean that we +-// derive from multiple Storage<> instantiations with the same I parameter, +-// and potentially from multiple identical Storage<> instantiations. 
So anytime +-// we use type inheritance rather than encapsulation, we mark +-// CompressedTupleImpl, to make this easy to detect. +-struct uses_inheritance {}; +- + template + constexpr bool ShouldUseBase() { +- return std::is_class::value && std::is_empty::value && !IsFinal() && +- !std::is_base_of::value; ++ return std::is_class::value && std::is_empty::value && !IsFinal(); + } + + // The storage class provides two specializations: + // - For empty classes, it stores T as a base class. + // - For everything else, it stores T as a member. +-template ::type>()> +-#else +- bool UseBase = ShouldUseBase()> +-#endif ++template >()> + struct Storage { ++ using T = ElemT; + T value; + constexpr Storage() = default; +- template +- explicit constexpr Storage(absl::in_place_t, V&& v) +- : value(absl::forward(v)) {} ++ explicit constexpr Storage(T&& v) : value(absl::forward(v)) {} + constexpr const T& get() const& { return value; } + T& get() & { return value; } + constexpr const T&& get() const&& { return absl::move(*this).value; } + T&& get() && { return std::move(*this).value; } + }; + +-template +-struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC Storage : T { ++template ++struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC Storage ++ : ElemT { ++ using T = internal_compressed_tuple::ElemT; + constexpr Storage() = default; +- +- template +- explicit constexpr Storage(absl::in_place_t, V&& v) +- : T(absl::forward(v)) {} +- ++ explicit constexpr Storage(T&& v) : T(absl::forward(v)) {} + constexpr const T& get() const& { return *this; } + T& get() & { return *this; } + constexpr const T&& get() const&& { return absl::move(*this); } + T&& get() && { return std::move(*this); } + }; + +-template ++template + struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTupleImpl; + +-template +-struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTupleImpl< +- CompressedTuple, absl::index_sequence, ShouldAnyUseBase> ++template ++struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC ++ CompressedTupleImpl, absl::index_sequence> + // We use the dummy identity function through std::integral_constant to + // convince MSVC of accepting and expanding I in that context. Without it + // you would get: + // error C3548: 'I': parameter pack cannot be used in this context +- : uses_inheritance, +- Storage::value>... { +- constexpr CompressedTupleImpl() = default; +- template +- explicit constexpr CompressedTupleImpl(absl::in_place_t, Vs&&... args) +- : Storage(absl::in_place, absl::forward(args))... {} +- friend CompressedTuple; +-}; +- +-template +-struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTupleImpl< +- CompressedTuple, absl::index_sequence, false> +- // We use the dummy identity function as above... +- : Storage::value, false>... { ++ : Storage, ++ std::integral_constant::value>... { + constexpr CompressedTupleImpl() = default; +- template +- explicit constexpr CompressedTupleImpl(absl::in_place_t, Vs&&... args) +- : Storage(absl::in_place, absl::forward(args))... {} +- friend CompressedTuple; ++ explicit constexpr CompressedTupleImpl(Ts&&... args) ++ : Storage, I>(absl::forward(args))... {} + }; + +-std::false_type Or(std::initializer_list); +-std::true_type Or(std::initializer_list); +- +-// MSVC requires this to be done separately rather than within the declaration +-// of CompressedTuple below. 
+-template +-constexpr bool ShouldAnyUseBase() { +- return decltype( +- Or({std::integral_constant()>()...})){}; +-} +- +-template +-using TupleMoveConstructible = typename std::conditional< +- std::is_reference::value, std::is_convertible, +- std::is_constructible>::type; +- + } // namespace internal_compressed_tuple + + // Helper class to perform the Empty Base Class Optimization. + // Ts can contain classes and non-classes, empty or not. For the ones that + // are empty classes, we perform the CompressedTuple. If all types in Ts are +-// empty classes, then CompressedTuple is itself an empty class. (This +-// does not apply when one or more of those empty classes is itself an empty +-// CompressedTuple.) ++// empty classes, then CompressedTuple is itself an empty class. + // + // To access the members, use member .get() function. + // +@@ -196,58 +146,36 @@ + template + class ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTuple + : private internal_compressed_tuple::CompressedTupleImpl< +- CompressedTuple, absl::index_sequence_for, +- internal_compressed_tuple::ShouldAnyUseBase()> { ++ CompressedTuple, absl::index_sequence_for> { + private: + template + using ElemT = internal_compressed_tuple::ElemT; + +- template +- using StorageT = internal_compressed_tuple::Storage, I>; +- + public: +- // There seems to be a bug in MSVC dealing in which using '=default' here will +- // cause the compiler to ignore the body of other constructors. The work- +- // around is to explicitly implement the default constructor. +-#if defined(_MSC_VER) +- constexpr CompressedTuple() : CompressedTuple::CompressedTupleImpl() {} +-#else + constexpr CompressedTuple() = default; +-#endif +- explicit constexpr CompressedTuple(const Ts&... base) +- : CompressedTuple::CompressedTupleImpl(absl::in_place, base...) {} +- +- template ...)>>, +- internal_compressed_tuple::TupleMoveConstructible< +- Ts, Vs&&>...>::value, +- bool> = true> +- explicit constexpr CompressedTuple(Vs&&... base) +- : CompressedTuple::CompressedTupleImpl(absl::in_place, +- absl::forward(base)...) {} ++ explicit constexpr CompressedTuple(Ts... base) ++ : CompressedTuple::CompressedTupleImpl(absl::forward(base)...) {} + + template + ElemT& get() & { +- return internal_compressed_tuple::Storage, I>::get(); ++ return internal_compressed_tuple::Storage::get(); + } + + template + constexpr const ElemT& get() const& { +- return StorageT::get(); ++ return internal_compressed_tuple::Storage::get(); + } + + template + ElemT&& get() && { +- return std::move(*this).StorageT::get(); ++ return std::move(*this) ++ .internal_compressed_tuple::template Storage::get(); + } + + template + constexpr const ElemT&& get() const&& { +- return absl::move(*this).StorageT::get(); ++ return absl::move(*this) ++ .internal_compressed_tuple::template Storage::get(); + } + }; + From 292f5650c8431a70b29a509106d9fffe4150247d Mon Sep 17 00:00:00 2001 From: Yifei Feng Date: Wed, 18 Mar 2020 20:12:10 -0700 Subject: [PATCH 200/492] Add more details to how to fix license check failure in sanity build. 
PiperOrigin-RevId: 301725066 Change-Id: Ie45918d16efee9298e5e2061fe3b7275b419e417 --- tensorflow/tools/ci_build/ci_sanity.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index aff42505215..4f7ce00eb53 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -397,9 +397,10 @@ do_external_licenses_check(){ if [[ -s ${MISSING_LICENSES_FILE} ]] ; then echo "Missing the licenses for the following external dependencies:" cat ${MISSING_LICENSES_FILE} + echo "Please add the license(s) to ${LICENSES_TARGET}." fi if [[ -s ${EXTRA_LICENSES_FILE} ]] ; then - echo "Please remove the licenses for the following external dependencies:" + echo "Please remove the licenses for the following external dependencies from target ${LICENSES_TARGET}." cat ${EXTRA_LICENSES_FILE} fi rm -rf ${EXTERNAL_DEPENDENCIES_FILE} From 6c47413fa04ee04b85e7ce2a44a0fa1b3b16464f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BF=97=E8=B1=AA?= Date: Thu, 19 Mar 2020 11:26:23 +0800 Subject: [PATCH 201/492] Fix `'for' loop initial declarations are only allowed in C99 mode` Fix `'for' loop initial declarations are only allowed in C99 mode` --- tensorflow/lite/c/common.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c index 7196f32b62a..eb5584d1b9d 100644 --- a/tensorflow/lite/c/common.c +++ b/tensorflow/lite/c/common.c @@ -119,7 +119,8 @@ void TfLiteSparsityFree(TfLiteSparsity* sparsity) { } if (sparsity->dim_metadata) { - for (int i = 0; i < sparsity->dim_metadata_size; i++) { + int i; + for (i = 0; i < sparsity->dim_metadata_size; i++) { TfLiteDimensionMetadata metadata = sparsity->dim_metadata[i]; if (metadata.format == kTfLiteDimSparseCSR) { TfLiteIntArrayFree(metadata.array_segments); From ffbdfbbe6a86ce03480470f403e19f194d7592f7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 21:45:56 -0700 Subject: [PATCH 202/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301735872 Change-Id: Ie90dd25176dd2889a375239231134938f53b7a6d --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 3d05bb08fa3..7be0c66548c 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11815,7 +11815,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12072,7 +12072,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. 
-// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12083,7 +12083,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12301,7 +12301,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12312,7 +12312,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19153,7 +19153,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20224,7 +20224,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21396,7 +21396,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22104,7 +22104,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22300,7 +22300,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22484,7 +22484,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22543,7 +22543,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22717,7 +22717,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23098,7 +23098,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25441,7 +25441,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25504,7 +25504,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25747,7 +25747,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26370,7 +26370,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45499,7 +45499,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46287,7 +46287,7 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46350,7 +46350,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 6eca56262213a7ad6646e50edd3b38729d9ce947 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 18 Mar 2020 21:48:13 -0700 Subject: [PATCH 203/492] Support negative axis for reverse_v2. 
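In effect, the kernel change below wraps a negative axis by the input rank before the existing bounds check, matching the TensorFlow semantics where axis -1 means the last dimension. A minimal sketch of the normalization (hypothetical helper, not the actual TFLite kernel code):

    #include <cassert>

    // Map a possibly negative axis into [0, rank): -1 is the last dimension,
    // -rank the first; anything else out of range still fails the check.
    int NormalizeAxis(int axis, int rank) {
      if (axis < 0) axis += rank;
      assert(axis >= 0 && axis < rank);
      return axis;
    }

    int main() {
      // For a rank-3 tensor, axis -1 maps to 2 and axis -3 maps to 0.
      return (NormalizeAxis(-1, 3) == 2 && NormalizeAxis(-3, 3) == 0) ? 0 : 1;
    }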
PiperOrigin-RevId: 301736084 Change-Id: Id23154b459566d5f85ef0c8d8bc98117c510b963 --- tensorflow/lite/kernels/reverse.cc | 6 +++++- tensorflow/lite/testing/op_tests/reverse_v2.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/kernels/reverse.cc b/tensorflow/lite/kernels/reverse.cc index 75114ee863a..760236ad6a7 100644 --- a/tensorflow/lite/kernels/reverse.cc +++ b/tensorflow/lite/kernels/reverse.cc @@ -68,8 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor); int axis = GetTensorData(axis_tensor)[0]; + const int rank = NumDimensions(input); + if (axis < 0) { + axis += rank; + } - TF_LITE_ENSURE(context, axis >= 0 && axis < NumDimensions(input)); + TF_LITE_ENSURE(context, axis >= 0 && axis < rank); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); switch (output->type) { diff --git a/tensorflow/lite/testing/op_tests/reverse_v2.py b/tensorflow/lite/testing/op_tests/reverse_v2.py index 97e833781dc..ed86b57a218 100644 --- a/tensorflow/lite/testing/op_tests/reverse_v2.py +++ b/tensorflow/lite/testing/op_tests/reverse_v2.py @@ -30,7 +30,7 @@ def make_reverse_v2_tests(options): test_parameters = [{ "dtype": [tf.float32, tf.bool], "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]], - "axis": [0, 1, 2, 3], + "axis": [-2, -1, 0, 1, 2, 3], }] def get_valid_axis(parameters): From 7fb2f93d40a9a6e98fed29c296d40c537b92c148 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 18 Mar 2020 21:59:57 -0700 Subject: [PATCH 204/492] remove bidi rnn seq_len test, we didn't support seq_len anyway in tflite kernel anyway PiperOrigin-RevId: 301737256 Change-Id: I6bf748c78abe99075aeb30ec20a1ca265ab6717c --- .../lite/experimental/examples/lstm/BUILD | 1 - .../lstm/bidirectional_sequence_rnn_test.py | 55 ------------------- 2 files changed, 56 deletions(-) diff --git a/tensorflow/lite/experimental/examples/lstm/BUILD b/tensorflow/lite/experimental/examples/lstm/BUILD index 719e59c6a8c..cb5a98e4078 100644 --- a/tensorflow/lite/experimental/examples/lstm/BUILD +++ b/tensorflow/lite/experimental/examples/lstm/BUILD @@ -111,7 +111,6 @@ py_test( tags = [ "no_oss", "no_pip", - "notap", # b/141373014 ], deps = [ ":rnn", diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py index 00fdb4a2f96..2f0a7821572 100644 --- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py +++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py @@ -297,33 +297,6 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False) self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) - def testStaticRnnMultiRnnCellWithSequenceLength(self): - sess = tf.compat.v1.Session() - - x, prediction, output_class = self.buildModel( - self.buildRnnLayer(), - self.buildRnnLayer(), - False, - is_inference=False, - use_sequence_length=True) - self.trainModel(x, prediction, output_class, sess) - - saver = tf.train.Saver() - x, prediction, output_class, new_sess = self.saveAndRestoreModel( - self.buildRnnLayer(), - self.buildRnnLayer(), - sess, - saver, - False, - use_sequence_length=True) - - test_inputs, expected_output = self.getInferenceResult( - x, output_class, new_sess) - - # Test 
Toco-converted model. - result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False) - self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) - @test_util.enable_control_flow_v2 def testDynamicRnnMultiRnnCell(self): sess = tf.compat.v1.Session() @@ -347,34 +320,6 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False) self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) - @test_util.enable_control_flow_v2 - def testDynamicRnnMultiRnnCellWithSequenceLength(self): - sess = tf.compat.v1.Session() - - x, prediction, output_class = self.buildModel( - self.buildRnnLayer(), - self.buildRnnLayer(), - True, - is_inference=False, - use_sequence_length=True) - self.trainModel(x, prediction, output_class, sess) - - saver = tf.compat.v1.train.Saver() - x, prediction, output_class, new_sess = self.saveAndRestoreModel( - self.buildRnnLayer(), - self.buildRnnLayer(), - sess, - saver, - is_dynamic_rnn=True, - use_sequence_length=True) - - test_inputs, expected_output = self.getInferenceResult( - x, output_class, new_sess) - - # Test Toco-converted model. - result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False) - self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) - if __name__ == "__main__": tf.disable_v2_behavior() From f3e1c5c94dc776885758244b1434ff3ffcde5293 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Mar 2020 22:01:39 -0700 Subject: [PATCH 205/492] Make exported MetaGraphDefs deterministic Nondeterministic serialized protos hurt caching. PiperOrigin-RevId: 301737437 Change-Id: Ia354d81bac2f2140d0ca84feca85d70a03c25a67 --- tensorflow/python/framework/graph_io.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/graph_io.py b/tensorflow/python/framework/graph_io.py index 68d7f35f825..1e014a61e96 100644 --- a/tensorflow/python/framework/graph_io.py +++ b/tensorflow/python/framework/graph_io.py @@ -71,5 +71,6 @@ def write_graph(graph_or_graph_def, logdir, name, as_text=True): text_format.MessageToString( graph_def, float_format='')) else: - file_io.atomic_write_string_to_file(path, graph_def.SerializeToString()) + file_io.atomic_write_string_to_file( + path, graph_def.SerializeToString(deterministic=True)) return path From c66ecc5c0e5cd3612ec52d307d601aef88f6570c Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 18 Mar 2020 22:15:51 -0700 Subject: [PATCH 206/492] Disable broken select_and_scatter_test PiperOrigin-RevId: 301739389 Change-Id: I7c0687e276aac637eef1ca920b3899313eb8d50d --- tensorflow/compiler/xla/tests/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 3255aa84685..115d3a0edfa 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1497,6 +1497,8 @@ xla_test( timeout = "long", srcs = ["select_and_scatter_test.cc"], tags = [ + "manual", # TODO(b/151876386) + "no_oss", "no_rocm", "optonly", ], From 732bce892c0559c35bf7f3bb4408498e7dba8acd Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 18 Mar 2020 22:32:39 -0700 Subject: [PATCH 207/492] Fix a floating-point check in the max pool kernel PiperOrigin-RevId: 301741177 Change-Id: I46de90f4781165733aba6e1a4c84517a2cfed24c --- tensorflow/lite/kernels/pooling.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/kernels/pooling.cc 
b/tensorflow/lite/kernels/pooling.cc index e871b72f4a1..63c6eb1239f 100644 --- a/tensorflow/lite/kernels/pooling.cc +++ b/tensorflow/lite/kernels/pooling.cc @@ -91,9 +91,9 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { if (pool_type == kAverage || pool_type == kMax) { - TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); - TF_LITE_ENSURE_EQ(context, input->params.zero_point, - output->params.zero_point); + TFLITE_DCHECK_LE(std::abs(input->params.scale - output->params.scale), + 1.0e-6); + TFLITE_DCHECK_EQ(input->params.zero_point, output->params.zero_point); } if (pool_type == kL2) { // We currently don't have a quantized implementation of L2Pool From 7b2850ee7de47b4a270e6a1ec38c73b9395b2483 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Wed, 18 Mar 2020 22:49:31 -0700 Subject: [PATCH 208/492] [XLA] Fix select and scatter tests. - Use std::vector instead of absl::span - Reduce the size of one large test (192s -> 35s) - Reenable the test in OSS. PiperOrigin-RevId: 301743200 Change-Id: I2788859a9210a43e4fcd65478701ed7df420d181 --- tensorflow/compiler/xla/tests/BUILD | 2 -- .../compiler/xla/tests/select_and_scatter_test.cc | 10 +++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 115d3a0edfa..3255aa84685 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1497,8 +1497,6 @@ xla_test( timeout = "long", srcs = ["select_and_scatter_test.cc"], tags = [ - "manual", # TODO(b/151876386) - "no_oss", "no_rocm", "optonly", ], diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 79ac469d89b..9c8b935428f 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -42,8 +42,8 @@ struct SelectAndScatterTestParam { std::vector operand_shape; std::vector source_shape; Padding padding_type; - absl::Span window_dimensions; - absl::Span window_strides; + std::vector window_dimensions; + std::vector window_strides; }; class SelectAndScatterTest @@ -186,11 +186,11 @@ INSTANTIATE_TEST_CASE_P( Padding::kValid, {2, 1, 1}, {3, 1, 1}}, - SelectAndScatterTestParam{{160, 160, 8, 256}, + SelectAndScatterTestParam{{10, 10, 8, 256}, {5, 5, 8, 256}, Padding::kSame, - {32, 32, 1, 1}, - {32, 32, 1, 1}}, + {2, 2, 1, 1}, + {2, 2, 1, 1}}, SelectAndScatterTestParam{ {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, SelectAndScatterTestParam{ From d6a7beddcc971018d79f3be676dd687357334d19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BF=97=E8=B1=AA?= Date: Thu, 19 Mar 2020 14:18:25 +0800 Subject: [PATCH 209/492] Apply suggestion --- tensorflow/lite/c/common.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c index eb5584d1b9d..a9d92b223ca 100644 --- a/tensorflow/lite/c/common.c +++ b/tensorflow/lite/c/common.c @@ -119,8 +119,8 @@ void TfLiteSparsityFree(TfLiteSparsity* sparsity) { } if (sparsity->dim_metadata) { - int i; - for (i = 0; i < sparsity->dim_metadata_size; i++) { + int i = 0; + for (; i < sparsity->dim_metadata_size; i++) { TfLiteDimensionMetadata metadata = sparsity->dim_metadata[i]; if (metadata.format == kTfLiteDimSparseCSR) { TfLiteIntArrayFree(metadata.array_segments); From 
367d4a9690cc039ac03f3045f4b3b80f3e7a726c Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Wed, 18 Mar 2020 23:15:30 -0700 Subject: [PATCH 210/492] Expose tensorflow/c/**/*.h in tensorflow pip package includes. These are required for code that relies on e.g. tensorflow/python/**/*.h includes, since these use the C api. PiperOrigin-RevId: 301747036 Change-Id: Ib04535523f4a922cc7db32a3483d004876eb4d28 --- tensorflow/tools/pip_package/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 4dfe616263b..64a4469e0da 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -246,6 +246,7 @@ headers = ( list(find_files('*.proto', 'tensorflow/compiler')) + list(find_files('*.proto', 'tensorflow/core')) + list(find_files('*.proto', 'tensorflow/python')) + + list(find_files('*.h', 'tensorflow/c')) + list(find_files('*.h', 'tensorflow/cc')) + list(find_files('*.h', 'tensorflow/compiler')) + list(find_files('*.h', 'tensorflow/core')) + From 48ca2c65fa3861ef9824cf1787af4e79e4d71dc2 Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Wed, 18 Mar 2020 23:23:24 -0700 Subject: [PATCH 211/492] Upgrade to Flatbuffer 1.12 PiperOrigin-RevId: 301748175 Change-Id: I3fd58c962bbb972d6c6c81931cde58df7472ed4c --- tensorflow/lite/tools/make/download_dependencies.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/tools/make/download_dependencies.sh b/tensorflow/lite/tools/make/download_dependencies.sh index 30ae7579d3d..2156feafef0 100755 --- a/tensorflow/lite/tools/make/download_dependencies.sh +++ b/tensorflow/lite/tools/make/download_dependencies.sh @@ -44,8 +44,8 @@ ABSL_SHA="$(eval echo $(grep '# SHARED_ABSL_SHA' "${BZL_FILE_PATH}" | grep -o '\ NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" FARMHASH_SHA="$(eval echo $(grep '# SHARED_FARMHASH_SHA' "${BZL_FILE_PATH}" | grep -o '\".*\"'))" -FLATBUFFERS_URL="https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.11.0.tar.gz" -FLATBUFFERS_SHA="3f4a286642094f45b1b77228656fbd7ea123964f19502f9ecfd29933fd23a50b" +FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz" +FLATBUFFERS_SHA="62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45" FFT2D_URL="https://storage.googleapis.com/mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz" FP16_URL="https://github.com/Maratyszcza/FP16/archive/febbb1c163726b5db24bed55cc9dc42529068997.zip" FFT2D_SHA="ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9" From 8b219365ebccda053c38f471e1b6c17c7afac3a8 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 19 Mar 2020 00:01:30 -0700 Subject: [PATCH 212/492] Moved VectorOps to Vector PiperOrigin-RevId: 301751648 Change-Id: I8e95ed1be7d0e58d897760aa3060eeb4abc6227c --- .../compiler/mlir/lite/flatbuffer_import.cc | 4 +- .../mlir/lite/flatbuffer_translate.cc | 2 +- tensorflow/compiler/mlir/lite/ir/tfl_ops.h | 2 +- .../quantization/import_quant_stats_pass.cc | 4 +- .../mlir/lite/quantization/lite/tfl_to_std.cc | 2 +- .../mlir/lite/quantization/quantization.td | 2 +- .../lite/quantization/quantization_driver.cc | 4 +- .../lite/quantization/quantization_traits.h | 2 +- .../lite/quantization/quantization_utils.cc | 10 +- .../lite/quantization/quantization_utils.h | 6 +- .../quantization/tensorflow/tf_to_quant.cc | 2 +- .../mlir/lite/quantization/xla/materialize.cc | 2 +- .../lite/transforms/default_quant_params.cc | 4 +- .../mlir/lite/transforms/legalize_tf.cc | 4 +- .../transforms/load_quantization_recipe.cc | 2 +- .../mlir/lite/transforms/prepare_quantize.cc | 2 +- .../mlir/lite/transforms/prepare_tf.cc | 4 +- .../compiler/mlir/lite/transforms/quantize.cc | 2 +- tensorflow/workspace.bzl | 4 +- third_party/mlir/BUILD | 119 ++++++++++++------ 20 files changed, 115 insertions(+), 68 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 29233f86e4a..4b888764053 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -44,8 +44,8 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index a5831559546..a75c1b3bab2 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -41,7 +41,7 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index cfe18a218bc..ffdafc1844f 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -18,7 +18,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project #include "mlir/Dialect/Traits.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 617f968b958..26062b96de0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -23,8 +23,8 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/AffineExpr.h" // TF:llvm-project #include "mlir/IR/AffineMap.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc index 41efadde20d..d680c889d2c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization.td b/tensorflow/compiler/mlir/lite/quantization/quantization.td index 966740e605f..7bfcdb65686 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/quantization.td @@ -20,7 +20,7 @@ limitations under the License. #define TF_Quantization include "mlir/IR/OpBase.td" -include "mlir/Dialect/QuantOps/QuantOpsBase.td" +include "mlir/Dialect/Quant/QuantOpsBase.td" //===----------------------------------------------------------------------===// // QuantizedType definitions. diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 5f52c892421..531a442fd6b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -24,8 +24,8 @@ limitations under the License. 
#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h index db2567fbda0..885831ad0ce 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h @@ -18,7 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index a321170349a..f5c7287631a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -22,11 +22,11 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantizeUtils.h" // TF:llvm-project +#include "mlir/Dialect/Quant/UniformSupport.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index e9d29758823..6a54262363c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -23,9 +23,9 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index d2884edafdf..1a310de8b01 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc index 59704b4c73a..ab170def2b5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc @@ -25,7 +25,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 30fe391762f..bb48c392a5f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -25,8 +25,8 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 80689f7b7c4..3210ac7bc2b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -28,8 +28,8 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/UniformSupport.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc index 4fde08bc1cf..59b1dcce35d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc +++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc @@ -19,7 +19,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 316a9d2cf2a..287b9ca911c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project #include "mlir/IR/Value.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 1ff321780a4..efcc950cae0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -39,8 +39,8 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/Quant/UniformSupport.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index a9570625400..c78b04df247 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -19,7 +19,7 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 085ae961cd6..1bdcc0abd4f 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -597,8 +597,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "98369178bc695ba5d64314beb62d5ba5c9f14e2e" - LLVM_SHA256 = "c30eb278889c64e5a57e31d9bad794c6019d5396ce58a6ba874b0e4763f21097" + LLVM_COMMIT = "c5b81466c2bcc194e5563f39f5be3638760b4849" + LLVM_SHA256 = "f623a7e9585e76831abc967547dfbcd5a6ecd148ed5c4e088bdae94dc7d8bda7" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 0e4ac2c07b6..e9fef46c4df 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -406,6 +406,52 @@ cc_library( ], ) +gentbl( + name = "ShapeOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-op-decls", + "include/mlir/Dialect/Shape/IR/ShapeOps.h.inc", + ), + ( + "-gen-op-defs", + "include/mlir/Dialect/Shape/IR/ShapeOps.cpp.inc", + ), + ( + "-gen-dialect-decls", + "include/mlir/Dialect/Shape/IR/ShapeOpsDialect.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Shape/IR/ShapeOps.td", + td_srcs = [ + ":StdOpsTdFiles", + ], +) + +cc_library( + name = "Shape", + srcs = glob( + [ + "lib/Dialect/Shape/IR/*.cpp", + ], + ), + hdrs = glob([ + "include/mlir/Dialect/Shape/IR/*.h", + ]), + includes = ["include"], + deps = [ + ":CallOpInterfaces", + ":CommonFolders", + ":IR", + ":ShapeOpsIncGen", + ":SideEffects", + ":Support", + "@llvm-project//llvm:support", + ], +) + cc_library( name = "StandardOps", srcs = glob( @@ -439,15 +485,15 @@ cc_library( name = "VectorOps", srcs = glob( [ - "lib/Dialect/VectorOps/*.cpp", - "lib/Dialect/VectorOps/*.h", - "lib/Dialect/VectorOps/EDSC/*.cpp", - "lib/Dialect/VectorOps/EDSC/*.h", + "lib/Dialect/Vector/*.cpp", + "lib/Dialect/Vector/*.h", + "lib/Dialect/Vector/EDSC/*.cpp", + "lib/Dialect/Vector/EDSC/*.h", ], ), hdrs = glob([ - "include/mlir/Dialect/VectorOps/*.h", - 
"include/mlir/Dialect/VectorOps/EDSC/*.h", + "include/mlir/Dialect/Vector/*.h", + "include/mlir/Dialect/Vector/EDSC/*.h", ]), includes = ["include"], deps = [ @@ -2004,6 +2050,7 @@ cc_library( ":SDBM", ":SPIRVDialect", ":SPIRVLowering", + ":Shape", ":StandardOps", ":StandardToSPIRVConversions", ":Transforms", @@ -2256,8 +2303,8 @@ cc_library( filegroup( name = "QuantizationOpsTdFiles", srcs = [ - "include/mlir/Dialect/QuantOps/QuantOps.td", - "include/mlir/Dialect/QuantOps/QuantOpsBase.td", + "include/mlir/Dialect/Quant/QuantOps.td", + "include/mlir/Dialect/Quant/QuantOpsBase.td", "include/mlir/Interfaces/SideEffects.td", ":OpBaseTdFiles", ], @@ -2270,15 +2317,15 @@ gentbl( tbl_outs = [ ( "-gen-op-decls", - "include/mlir/Dialect/QuantOps/QuantOps.h.inc", + "include/mlir/Dialect/Quant/QuantOps.h.inc", ), ( "-gen-op-defs", - "include/mlir/Dialect/QuantOps/QuantOps.cpp.inc", + "include/mlir/Dialect/Quant/QuantOps.cpp.inc", ), ( "-gen-dialect-decls", - "include/mlir/Dialect/QuantOps/QuantOpsDialect.h.inc", + "include/mlir/Dialect/Quant/QuantOpsDialect.h.inc", ), ( "-gen-op-doc", @@ -2286,7 +2333,7 @@ gentbl( ), ], tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/QuantOps/QuantOps.td", + td_file = "include/mlir/Dialect/Quant/QuantOps.td", td_srcs = [ ":QuantizationOpsTdFiles", ], @@ -2295,23 +2342,23 @@ gentbl( cc_library( name = "QuantOps", srcs = [ - "lib/Dialect/QuantOps/IR/QuantOps.cpp", - "lib/Dialect/QuantOps/IR/QuantTypes.cpp", - "lib/Dialect/QuantOps/IR/TypeDetail.h", - "lib/Dialect/QuantOps/IR/TypeParser.cpp", - "lib/Dialect/QuantOps/Transforms/ConvertConst.cpp", - "lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp", - "lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp", - "lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp", - "lib/Dialect/QuantOps/Utils/UniformSupport.cpp", + "lib/Dialect/Quant/IR/QuantOps.cpp", + "lib/Dialect/Quant/IR/QuantTypes.cpp", + "lib/Dialect/Quant/IR/TypeDetail.h", + "lib/Dialect/Quant/IR/TypeParser.cpp", + "lib/Dialect/Quant/Transforms/ConvertConst.cpp", + "lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp", + "lib/Dialect/Quant/Utils/FakeQuantSupport.cpp", + "lib/Dialect/Quant/Utils/QuantizeUtils.cpp", + "lib/Dialect/Quant/Utils/UniformSupport.cpp", ], hdrs = [ - "include/mlir/Dialect/QuantOps/FakeQuantSupport.h", - "include/mlir/Dialect/QuantOps/Passes.h", - "include/mlir/Dialect/QuantOps/QuantOps.h", - "include/mlir/Dialect/QuantOps/QuantTypes.h", - "include/mlir/Dialect/QuantOps/QuantizeUtils.h", - "include/mlir/Dialect/QuantOps/UniformSupport.h", + "include/mlir/Dialect/Quant/FakeQuantSupport.h", + "include/mlir/Dialect/Quant/Passes.h", + "include/mlir/Dialect/Quant/QuantOps.h", + "include/mlir/Dialect/Quant/QuantTypes.h", + "include/mlir/Dialect/Quant/QuantizeUtils.h", + "include/mlir/Dialect/Quant/UniformSupport.h", ], includes = ["include"], deps = [ @@ -2328,7 +2375,7 @@ filegroup( name = "FxpMathOpsTdFiles", srcs = [ "include/mlir/Dialect/FxpMathOps/FxpMathOps.td", - "include/mlir/Dialect/QuantOps/QuantOpsBase.td", + "include/mlir/Dialect/Quant/QuantOpsBase.td", "include/mlir/Interfaces/SideEffects.td", ":OpBaseTdFiles", ], @@ -2681,7 +2728,7 @@ cc_library( filegroup( name = "VectorOpsTdFiles", srcs = [ - "include/mlir/Dialect/VectorOps/VectorOps.td", + "include/mlir/Dialect/Vector/VectorOps.td", ":AffineOpsTdFiles", ":OpBaseTdFiles", ], @@ -2693,15 +2740,15 @@ gentbl( tbl_outs = [ ( "-gen-op-decls", - "include/mlir/Dialect/VectorOps/VectorOps.h.inc", + "include/mlir/Dialect/Vector/VectorOps.h.inc", ), ( "-gen-op-defs", - 
"include/mlir/Dialect/VectorOps/VectorOps.cpp.inc", + "include/mlir/Dialect/Vector/VectorOps.cpp.inc", ), ( "-gen-dialect-decls -dialect=vector", - "include/mlir/Dialect/VectorOps/VectorOpsDialect.h.inc", + "include/mlir/Dialect/Vector/VectorOpsDialect.h.inc", ), ( "-gen-op-doc", @@ -2709,7 +2756,7 @@ gentbl( ), ], tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/VectorOps/VectorOps.td", + td_file = "include/mlir/Dialect/Vector/VectorOps.td", td_srcs = [ ":VectorOpsTdFiles", ], @@ -2718,7 +2765,7 @@ gentbl( filegroup( name = "VectorTransformPatternsTdFiles", srcs = [ - "include/mlir/Dialect/VectorOps/VectorTransformPatterns.td", + "include/mlir/Dialect/Vector/VectorTransformPatterns.td", ":AffineOpsTdFiles", ":LinalgOpsTdFiles", ":LinalgStructuredOpsTdFiles", @@ -2733,11 +2780,11 @@ gentbl( tbl_outs = [ ( "-gen-rewriters", - "include/mlir/Dialect/VectorOps/VectorTransformPatterns.h.inc", + "include/mlir/Dialect/Vector/VectorTransformPatterns.h.inc", ), ], tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/VectorOps/VectorTransformPatterns.td", + td_file = "include/mlir/Dialect/Vector/VectorTransformPatterns.td", td_srcs = [ ":VectorTransformPatternsTdFiles", ], From 9a7ad6fe1231291fb139f353b7a596a431c1b1b9 Mon Sep 17 00:00:00 2001 From: Xinan Jiang Date: Wed, 18 Mar 2020 18:35:17 +0800 Subject: [PATCH 213/492] [MLIR][XLA] Fix ops erase bug in DeadTempBufferRemoval --- .../compiler/xla/service/mlir_gpu/kernel_lowering.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 151d82fd2a1..45a9470808b 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -244,14 +244,19 @@ struct DeadTempBufferRemoval : mlir::FunctionPass { } void runOnFunction() override { + llvm::SmallVector opsToErase; getFunction().walk([&](mlir::AllocOp allocOp) { if (!operationConsideredDead(allocOp)) { return; } - // TODO(herhut): There should be a generic helper for this. - recursiveErase(allocOp); + opsToErase.push_back(allocOp); }); + + for (auto *op : opsToErase) { + // TODO(herhut): There should be a generic helper for this. 
+ recursiveErase(op); + } } }; From 91577d1d8c085971eac6e97e32e7fab47bfb8568 Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Thu, 19 Mar 2020 00:56:20 -0700 Subject: [PATCH 214/492] Add 5D support for BroadcastSub PiperOrigin-RevId: 301758471 Change-Id: Ifa9eb85543d9407cf7581eb903785808626ae1a1 --- .../delegates/nnapi/acceleration_test_list.cc | 1 + .../lite/kernels/batch_to_space_nd_test.cc | 4 +- tensorflow/lite/kernels/internal/common.h | 41 +++ .../internal/optimized/legacy_optimized_ops.h | 2 +- .../internal/optimized/optimized_ops.h | 6 +- .../lite/kernels/internal/reference/sub.h | 320 ++++++++---------- tensorflow/lite/kernels/register.cc | 2 +- tensorflow/lite/kernels/sub.cc | 12 +- tensorflow/lite/kernels/sub_test.cc | 22 +- tensorflow/lite/kernels/test_util.cc | 5 +- tensorflow/lite/micro/kernels/sub.cc | 6 +- tensorflow/lite/testing/op_tests/binary_op.py | 24 +- tensorflow/lite/toco/tflite/export_test.cc | 3 + tensorflow/lite/toco/tflite/operator.cc | 17 + tensorflow/lite/toco/tflite/operator_test.cc | 68 +++- .../lite/tools/versioning/op_version.cc | 31 +- tensorflow/lite/tools/versioning/op_version.h | 4 + 17 files changed, 367 insertions(+), 201 deletions(-) diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index ea41ea01d81..1c98ea56bbc 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -334,6 +334,7 @@ SplitOpTest/SplitOpTest/.+/0,29 FloatSqueezeOpTest/.+,29 # sub_test +-FloatSubOpModel/WithBroadcast5D FloatSubOpModel/.+ -QuantizedSubOpModel/.+Int16 -QuantizedSubOpModel/.+Int8 diff --git a/tensorflow/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/lite/kernels/batch_to_space_nd_test.cc index a279fd45f55..cffa1036c84 100644 --- a/tensorflow/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/lite/kernels/batch_to_space_nd_test.cc @@ -67,7 +67,7 @@ class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel { std::initializer_list crops, const TensorType& type = TensorType_FLOAT32) { int spatial_dims = static_cast(block_shape.size()); - input_ = AddInput(type); + input_ = AddInput({type, input_shape}); block_shape_ = AddConstInput(TensorType_INT32, block_shape, {spatial_dims}); crops_ = AddConstInput(TensorType_INT32, crops, {spatial_dims, 2}); output_ = AddOutput(type); @@ -91,7 +91,7 @@ class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel { public: BatchToSpaceNDOpDynamicModel(std::initializer_list input_shape, const TensorType& type = TensorType_FLOAT32) { - input_ = AddInput(type); + input_ = AddInput({type, input_shape}); block_shape_ = AddInput(TensorType_INT32); crops_ = AddInput(TensorType_INT32); output_ = AddOutput(type); diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index 97c2f6eae4a..73fc8590e74 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -524,6 +524,12 @@ inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, i3 * desc.strides[3]; } +inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) { + return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] + + indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] + + indexes[4] * desc.strides[4]; +} + // Given the dimensions of the operands for an element-wise binary broadcast, // adjusts them so that they can be directly iterated over with simple loops. 
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and @@ -677,6 +683,41 @@ inline void NdArrayDescsForElementwiseBroadcast( } } +// Detailed implementation of NDOpsHelper, the indexes must be a zero array. +// This implementation is equivalent to N nested loops. Ex, if N=4, it can be +// re-writen as: +// for (int b = 0; b < output.extents[0]; ++b) { +// for (int y = 0; y < output.extents[1]; ++y) { +// for (int x = 0; x < output.extents[2]; ++x) { +// for (int c = 0; c < output.extents[3]; ++c) { +// calc({b,y,x,c}); +// } +// } +// } +// } +template +typename std::enable_if::type NDOpsHelperImpl( + const NdArrayDesc& output, const Calc& calc, int indexes[N]) { + for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) { + NDOpsHelperImpl(output, calc, indexes); + } +} + +template +typename std::enable_if::type NDOpsHelperImpl( + const NdArrayDesc& output, const Calc& calc, int indexes[N]) { + for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) { + calc(indexes); + } +} + +// Execute the calc function in the innermost iteration based on the shape of +// the output. The calc function should take a single argument of type int[N]. +template +inline void NDOpsHelper(const NdArrayDesc& output, const Calc& calc) { + int indexes[N] = {0}; + NDOpsHelperImpl(output, calc, indexes); +} // Copied from gemmlowp::RoundDown when we dropped direct dependency on // gemmlowp. // diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index 325498b3f3f..f87738c34ff 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -48,7 +48,7 @@ using reference_ops::BroadcastGreaterEqual; using reference_ops::BroadcastLess; using reference_ops::BroadcastLessEqual; using reference_ops::BroadcastMul4DSlow; -using reference_ops::BroadcastSub4DSlow; +using reference_ops::BroadcastSubSlow; using reference_ops::Concatenation; using reference_ops::ConcatenationWithScaling; using reference_ops::DepthConcatenation; diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 15006d12c08..09122686db5 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -70,7 +70,7 @@ using reference_ops::Broadcast4DSlowLessEqualWithScaling; using reference_ops::Broadcast4DSlowLessWithScaling; using reference_ops::BroadcastAdd4DSlow; using reference_ops::BroadcastMul4DSlow; -using reference_ops::BroadcastSub4DSlow; +using reference_ops::BroadcastSubSlow; using reference_ops::Concatenation; using reference_ops::ConcatenationWithScaling; using reference_ops::DepthConcatenation; @@ -2959,8 +2959,8 @@ void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape, auto scalar = input2_data[0]; output_map.array() = input1_map.array() - scalar; } else { - BroadcastSub4DSlow(params, input1_shape, input1_data, input2_shape, - input2_data, output_shape, output_data); + BroadcastSubSlow(params, input1_shape, input1_data, input2_shape, + input2_data, output_shape, output_data); } } diff --git a/tensorflow/lite/kernels/internal/reference/sub.h b/tensorflow/lite/kernels/internal/reference/sub.h index 4f4a9156121..ae48491c04e 100644 --- a/tensorflow/lite/kernels/internal/reference/sub.h +++ 
b/tensorflow/lite/kernels/internal/reference/sub.h @@ -63,20 +63,24 @@ inline void SubNonBroadcast(const ArithmeticParams& params, // reference_ops.h. Once an optimized version is implemented and NdArrayDesc // is no longer referenced in this file, move NdArrayDesc from types.h to // reference_ops.h. -inline void BroadcastSub4DSlow(const ArithmeticParams& params, - const RuntimeShape& input1_shape, - const float* input1_data, - const RuntimeShape& input2_shape, - const float* input2_data, - const RuntimeShape& output_shape, - float* output_data) { - ruy::profiler::ScopeLabel label("BroadcastSub4DSlow/float"); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; +template +inline void BroadcastSubSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + ruy::profiler::ScopeLabel label("BroadcastSubSlow/float"); + TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N); + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2); - const RuntimeShape extended_output_shape = - RuntimeShape::ExtendedShape(4, output_shape); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -89,35 +93,34 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. 
- for (int b = 0; b < extended_output_shape.Dims(0); ++b) { - for (int y = 0; y < extended_output_shape.Dims(1); ++y) { - for (int x = 0; x < extended_output_shape.Dims(2); ++x) { - for (int c = 0; c < extended_output_shape.Dims(3); ++c) { - output_data[Offset(extended_output_shape, b, y, x, c)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, b, y, x, c)] - - input2_data[SubscriptToIndex(desc2, b, y, x, c)], - params.float_activation_min, params.float_activation_max); - } - } - } - } + auto sub_func = [&](int indexes[N]) { + output_data[SubscriptToIndex(output_desc, indexes)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, indexes)] - + input2_data[SubscriptToIndex(desc2, indexes)], + params.float_activation_min, params.float_activation_max); + }; + NDOpsHelper(output_desc, sub_func); } -inline void BroadcastSub4DSlow(const ArithmeticParams& params, - const RuntimeShape& input1_shape, - const uint8* input1_data, - const RuntimeShape& input2_shape, - const uint8* input2_data, - const RuntimeShape& output_shape, - uint8* output_data) { - ruy::profiler::ScopeLabel label("BroadcastSub4DSlow/uint8"); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; +template +inline void BroadcastSubSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const uint8* input1_data, + const RuntimeShape& input2_shape, + const uint8* input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + ruy::profiler::ScopeLabel label("BroadcastSubSlow/uint8"); + TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N); + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2); - const RuntimeShape extended_output_shape = - RuntimeShape::ExtendedShape(4, output_shape); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -130,58 +133,51 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. 
- for (int b = 0; b < extended_output_shape.Dims(0); ++b) { - for (int y = 0; y < extended_output_shape.Dims(1); ++y) { - for (int x = 0; x < extended_output_shape.Dims(2); ++x) { - for (int c = 0; c < extended_output_shape.Dims(3); ++c) { - const int32 input1_val = - params.input1_offset + - input1_data[SubscriptToIndex(desc1, b, y, x, c)]; - const int32 input2_val = - params.input2_offset + - input2_data[SubscriptToIndex(desc2, b, y, x, c)]; - const int32 shifted_input1_val = - input1_val * (1 << params.left_shift); - const int32 shifted_input2_val = - input2_val * (1 << params.left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, params.input1_multiplier, - params.input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, params.input2_multiplier, - params.input2_shift); - const int32 raw_sub = scaled_input1_val - scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sub, params.output_multiplier, params.output_shift) + - params.output_offset; - const int32 clamped_output = - std::min(params.quantized_activation_max, - std::max(params.quantized_activation_min, raw_output)); - output_data[Offset(extended_output_shape, b, y, x, c)] = - static_cast(clamped_output); - } - } - } - } + auto sub_func = [&](int indexes[N]) { + const int32 input1_val = + params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)]; + const int32 input2_val = + params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)]; + const int32 shifted_input1_val = input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = input2_val * (1 << params.left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, params.input1_multiplier, params.input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, params.input2_multiplier, params.input2_shift); + const int32 raw_sub = scaled_input1_val - scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[SubscriptToIndex(output_desc, indexes)] = + static_cast(clamped_output); + }; + NDOpsHelper(output_desc, sub_func); } -inline void BroadcastSub4DSlow(const ArithmeticParams& params, - const RuntimeShape& input1_shape, - const int32* input1_data, - const RuntimeShape& input2_shape, - const int32* input2_data, - const RuntimeShape& output_shape, - int32* output_data) { - ruy::profiler::ScopeLabel label("BroadcastSub4DSlow/int32"); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; +template +inline void BroadcastSubSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + ruy::profiler::ScopeLabel label("BroadcastSubSlow/int32"); + TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N); + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, 
&desc2); - const RuntimeShape extended_output_shape = - RuntimeShape::ExtendedShape(4, output_shape); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -194,36 +190,31 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < extended_output_shape.Dims(0); ++b) { - for (int y = 0; y < extended_output_shape.Dims(1); ++y) { - for (int x = 0; x < extended_output_shape.Dims(2); ++x) { - for (int c = 0; c < extended_output_shape.Dims(3); ++c) { - output_data[Offset(extended_output_shape, b, y, x, c)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, b, y, x, c)] - - input2_data[SubscriptToIndex(desc2, b, y, x, c)], - params.quantized_activation_min, - params.quantized_activation_max); - } - } - } - } + auto sub_func = [&](int indexes[N]) { + output_data[SubscriptToIndex(output_desc, indexes)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, indexes)] - + input2_data[SubscriptToIndex(desc2, indexes)], + params.quantized_activation_min, params.quantized_activation_max); + }; + NDOpsHelper(output_desc, sub_func); } -inline void BroadcastSub4DSlow(const ArithmeticParams& params, - const RuntimeShape& input1_shape, - const int8_t* input1_data, - const RuntimeShape& input2_shape, - const int8_t* input2_data, - const RuntimeShape& output_shape, - int8_t* output_data) { - ruy::profiler::ScopeLabel label("BroadcastSub4DSlow/int8"); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; +template +inline void BroadcastSubSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int8_t* input1_data, + const RuntimeShape& input2_shape, + const int8_t* input2_data, + const RuntimeShape& output_shape, + int8_t* output_data) { + ruy::profiler::ScopeLabel label("BroadcastSubSlow/int8"); + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2); - const RuntimeShape extended_output_shape = - RuntimeShape::ExtendedShape(4, output_shape); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -236,56 +227,48 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. 
- for (int b = 0; b < extended_output_shape.Dims(0); ++b) { - for (int y = 0; y < extended_output_shape.Dims(1); ++y) { - for (int x = 0; x < extended_output_shape.Dims(2); ++x) { - for (int c = 0; c < extended_output_shape.Dims(3); ++c) { - const int32_t input1_val = - params.input1_offset + - input1_data[SubscriptToIndex(desc1, b, y, x, c)]; - const int32_t input2_val = - params.input2_offset + - input2_data[SubscriptToIndex(desc2, b, y, x, c)]; - const int32_t shifted_input1_val = - input1_val * (1 << params.left_shift); - const int32_t shifted_input2_val = - input2_val * (1 << params.left_shift); - const int32_t scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, params.input1_multiplier, - params.input1_shift); - const int32_t scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, params.input2_multiplier, - params.input2_shift); - const int32_t raw_sub = scaled_input1_val - scaled_input2_val; - const int32_t raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sub, params.output_multiplier, params.output_shift) + - params.output_offset; - const int32_t clamped_output = - std::min(params.quantized_activation_max, - std::max(params.quantized_activation_min, raw_output)); - output_data[Offset(extended_output_shape, b, y, x, c)] = - static_cast(clamped_output); - } - } - } - } + auto sub_func = [&](int indexes[N]) { + const int32_t input1_val = + params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)]; + const int32_t input2_val = + params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)]; + const int32_t shifted_input1_val = input1_val * (1 << params.left_shift); + const int32_t shifted_input2_val = input2_val * (1 << params.left_shift); + const int32_t scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, params.input1_multiplier, params.input1_shift); + const int32_t scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, params.input2_multiplier, params.input2_shift); + const int32_t raw_sub = scaled_input1_val - scaled_input2_val; + const int32_t raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32_t clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[SubscriptToIndex(output_desc, indexes)] = + static_cast(clamped_output); + }; + NDOpsHelper(output_desc, sub_func); } -template -void BroadcastSub4DSlow(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const T* input1_data, - const RuntimeShape& input2_shape, const T* input2_data, - const RuntimeShape& output_shape, T* output_data) { - ruy::profiler::ScopeLabel label("BroadcastSub4DSlow/templated"); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; +template +void BroadcastSubSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const T* input1_data, + const RuntimeShape& input2_shape, const T* input2_data, + const RuntimeShape& output_shape, T* output_data) { + ruy::profiler::ScopeLabel label("BroadcastSubSlow/templated"); + TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N); + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, 
&desc2); - const RuntimeShape extended_output_shape = - RuntimeShape::ExtendedShape(4, output_shape); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -298,23 +281,16 @@ void BroadcastSub4DSlow(const ArithmeticParams& params, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < extended_output_shape.Dims(0); ++b) { - for (int y = 0; y < extended_output_shape.Dims(1); ++y) { - for (int x = 0; x < extended_output_shape.Dims(2); ++x) { - for (int c = 0; c < extended_output_shape.Dims(3); ++c) { - output_data[Offset(extended_output_shape, b, y, x, c)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, b, y, x, c)] - - input2_data[SubscriptToIndex(desc2, b, y, x, c)], - params.quantized_activation_min, - params.quantized_activation_max); - } - } - } - } + auto sub_func = [&](int indexes[N]) { + output_data[SubscriptToIndex(output_desc, indexes)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, indexes)] - + input2_data[SubscriptToIndex(desc2, indexes)], + params.quantized_activation_min, params.quantized_activation_max); + }; + NDOpsHelper(output_desc, sub_func); } - // Element-wise Sub that can often be used for inner loop of broadcast sub as // well as the non-broadcast sub. inline void SubElementwise(int size, const ArithmeticParams& params, diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index d069bc3dbf1..cf9f8b99ee4 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -139,7 +139,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DIV, Register_DIV()); AddBuiltin(BuiltinOperator_SUB, Register_SUB(), /* min_version */ 1, - /* max_version */ 2); + /* max_version */ 3); AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(), /* min_version */ 1, /* max_version */ 3); AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V(), diff --git a/tensorflow/lite/kernels/sub.cc b/tensorflow/lite/kernels/sub.cc index f2913faeb76..55a91acf1b5 100644 --- a/tensorflow/lite/kernels/sub.cc +++ b/tensorflow/lite/kernels/sub.cc @@ -237,13 +237,13 @@ void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, if (output->type == kTfLiteInt32) { if (kernel_type == kReference) { if (data->requires_broadcast) { - TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, int32_t); + TF_LITE_SUB(reference_ops, BroadcastSubSlow, int32_t); } else { TF_LITE_SUB(reference_ops, SubWithActivation, int32_t); } } else { if (data->requires_broadcast) { - TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, int32_t); + TF_LITE_SUB(optimized_ops, BroadcastSubSlow, int32_t); } else { TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t); } @@ -251,13 +251,13 @@ void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, } else if (output->type == kTfLiteFloat32) { if (kernel_type == kReference) { if (data->requires_broadcast) { - TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, float); + TF_LITE_SUB(reference_ops, BroadcastSubSlow, float); } else { TF_LITE_SUB(reference_ops, SubWithActivation, float); } } else { if (data->requires_broadcast) { - TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, float); + TF_LITE_SUB(optimized_ops, BroadcastSubSlow, 
float); } else { TF_LITE_SUB(optimized_ops, SubWithActivation, float); } @@ -321,13 +321,13 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, } else { if (kernel_type == kReference) { if (need_broadcast) { - TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, int16_t); + TF_LITE_SUB(reference_ops, BroadcastSubSlow, int16_t); } else { TF_LITE_SUB(reference_ops, Sub16, int16_t); } } else { if (need_broadcast) { - TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, int16_t); + TF_LITE_SUB(optimized_ops, BroadcastSubSlow, int16_t); } else { TF_LITE_SUB(optimized_ops, Sub16, int16_t); } diff --git a/tensorflow/lite/kernels/sub_test.cc b/tensorflow/lite/kernels/sub_test.cc index 24b554f087b..adda1b810ce 100644 --- a/tensorflow/lite/kernels/sub_test.cc +++ b/tensorflow/lite/kernels/sub_test.cc @@ -142,6 +142,22 @@ TEST(FloatSubOpModel, WithBroadcast) { } } +TEST(FloatSubOpModel, WithBroadcast5D) { + std::vector> test_shapes = {{1, 3, 1, 2, 1}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatSubOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, // always a scalar + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 1.7, 0.5, -1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.5}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-2.5, -0.3, 1.2, 0.0, -1.6, 1.5}))) + << "With shape number " << i; + } +} + TEST(IntegerSubOpModel, NoActivation) { IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, @@ -179,7 +195,7 @@ TEST(IntegerSubOpModel, VariousInputShapes) { TEST(IntegerSubOpModel, WithBroadcast) { std::vector> test_shapes = { - {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}, {1, 3, 1, 2, 1}}; for (int i = 0; i < test_shapes.size(); ++i) { IntegerSubOpModel m({TensorType_INT32, test_shapes[i]}, {TensorType_INT32, {}}, // always a scalar @@ -375,7 +391,7 @@ TEST(QuantizedSubOpModel, QuantizedTestsNoActivationBroadcastInt16) { std::numeric_limits::max(); float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); std::vector> test_shapes = { - {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}, {1, 3, 1, 2, 1}}; for (int i = 0; i < test_shapes.size(); ++i) { QuantizedSubOpModel m({TensorType_INT16, test_shapes[i], kMin, kMax}, {TensorType_INT16, {}, kMin, kMax}, @@ -398,7 +414,7 @@ TEST(QuantizedSubOpModel, QuantizedTestsReluActivationBroadcastInt16) { std::numeric_limits::max(); float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); std::vector> test_shapes = { - {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}, {1, 3, 1, 2, 1}}; for (int i = 0; i < test_shapes.size(); ++i) { QuantizedSubOpModel m({TensorType_INT16, test_shapes[i], kMin, kMax}, {TensorType_INT16, {}, kMin, kMax}, diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc index e45665faa30..9559140291b 100644 --- a/tensorflow/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -122,7 +122,7 @@ int SingleOpModel::AddOutput(const TensorData& t) { void SingleOpModel::SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type, flatbuffers::Offset builtin_options) { - opcodes_.push_back(CreateOperatorCode(builder_, type, 0)); + opcodes_.push_back(CreateOperatorCode(builder_, type, 0, 0)); operators_.push_back(CreateOperator( builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), builder_.CreateVector(outputs_), 
builtin_options_type, @@ -201,7 +201,6 @@ void SingleOpModel::BuildInterpreter(std::vector> input_shapes, void SingleOpModel::ApplyDelegate() { if (force_use_nnapi) { - // TODO(b/124505407): Check the result and fail accordingly. interpreter_->ModifyGraphWithDelegate(TestNnApiDelegate()); } @@ -350,7 +349,7 @@ void MultiOpModel::AddBuiltinOp( BuiltinOperator type, BuiltinOptions builtin_options_type, const flatbuffers::Offset& builtin_options, const std::vector& inputs, const std::vector& outputs) { - opcodes_.push_back(CreateOperatorCode(builder_, type, 0)); + opcodes_.push_back(CreateOperatorCode(builder_, type, 0, 0)); const int opcode_index = opcodes_.size() - 1; operators_.push_back(CreateOperator( builder_, opcode_index, builder_.CreateVector(inputs), diff --git a/tensorflow/lite/micro/kernels/sub.cc b/tensorflow/lite/micro/kernels/sub.cc index c1f43e9c6bc..3dbe4c202f4 100644 --- a/tensorflow/lite/micro/kernels/sub.cc +++ b/tensorflow/lite/micro/kernels/sub.cc @@ -116,7 +116,7 @@ void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, GetTensorShape(input2), GetTensorData(input2), \ GetTensorShape(output), GetTensorData(output)) if (data->requires_broadcast) { - TF_LITE_SUB(tflite::reference_ops::BroadcastSub4DSlow); + TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow); } else { TF_LITE_SUB(tflite::reference_ops::SubWithActivation); } @@ -150,13 +150,13 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node, GetTensorShape(output), GetTensorData(output)); if (output->type == kTfLiteInt8) { if (need_broadcast) { - TF_LITE_SUB(tflite::reference_ops::BroadcastSub4DSlow, int8_t); + TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow, int8_t); } else { TF_LITE_SUB(tflite::reference_ops::Sub, int8_t); } } else { if (need_broadcast) { - TF_LITE_SUB(tflite::reference_ops::BroadcastSub4DSlow, uint8_t); + TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow, uint8_t); } else { TF_LITE_SUB(tflite::reference_ops::Sub, uint8_t); } diff --git a/tensorflow/lite/testing/op_tests/binary_op.py b/tensorflow/lite/testing/op_tests/binary_op.py index 48c4296cc19..118c95dc777 100644 --- a/tensorflow/lite/testing/op_tests/binary_op.py +++ b/tensorflow/lite/testing/op_tests/binary_op.py @@ -26,10 +26,14 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function def make_binary_op_tests(options, binary_operator, allow_fully_quantize=False, - expected_tf_failures=0): + expected_tf_failures=0, + test_parameters=None): """Make a set of tests to do binary ops with and without broadcast.""" - test_parameters = [ + if test_parameters is None: + test_parameters = [] + + test_parameters = test_parameters + [ # Avoid creating all combinations to keep the test size small. 
{ "dtype": [tf.float32, tf.int32], @@ -185,7 +189,21 @@ def make_div_tests(options): @register_make_test_function() def make_sub_tests(options): - make_binary_op_tests(options, tf.subtract, allow_fully_quantize=True) + """Make zip tests for sub op with additional cases.""" + test_parameters = [ + { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 3, 3, 3]], + "input_shape_2": [[3]], + "activation": [False], + "fully_quantize": [False], + }, + ] + make_binary_op_tests( + options, + tf.subtract, + allow_fully_quantize=True, + test_parameters=test_parameters) @register_make_test_function() diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index 47b32940050..c3f378f2d78 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -74,6 +74,9 @@ class ExportTest : public ::testing::Test { input1_array.data_type = ArrayDataType::kFloat; input2_array.data_type = ArrayDataType::kFloat; output_array.data_type = ArrayDataType::kFloat; + input1_array.copy_shape({1, 2, 2, 2}); + input2_array.copy_shape({1, 2, 2, 2}); + output_array.copy_shape({1, 2, 2, 2}); input_model_.operators.emplace_back(op); } else if (name == "Assert") { auto* op = new TensorFlowAssertOperator; diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index fc408730759..76a5889948a 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -265,6 +265,23 @@ class Sub : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input1_name = op_signature.op->inputs[0]; + const string& input2_name = op_signature.op->inputs[1]; + const Array& input1_array = op_signature.model->GetArray(input1_name); + const Array& input2_array = op_signature.model->GetArray(input2_name); + ::tflite::OpSignature op_sig = + GetVersioningOpSig(builtin_op(), op_signature); + if (input1_array.has_shape() && input2_array.has_shape()) { + op_sig.options.sub.num_dims = + std::max(input1_array.shape().dimensions_count(), + input2_array.shape().dimensions_count()); + op_sig.options.sub.need_broadcast = + (input1_array.shape() != input2_array.shape()); + } + return ::tflite::GetBuiltinOperatorVersion(op_sig); + } }; class Div : public BuiltinOperator(); + StridedSliceOperator op; + op.inputs = {"input1"}; + auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/); + const BaseOperator* base_op = operator_by_type_map.at(op.type).get(); + + Model uint8_model; + Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]); + uint8_array.data_type = ArrayDataType::kUint8; + OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model}; + EXPECT_EQ(base_op->GetVersion(uint8_signature), 1); + + Model int8_model; + Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]); + int8_array.data_type = ArrayDataType::kInt8; + OperatorSignature int8_signature = {.op = &op, .model = &int8_model}; + EXPECT_EQ(base_op->GetVersion(int8_signature), 2); + + Model bool_model; + Array& bool_array = bool_model.GetOrCreateArray(op.inputs[0]); + bool_array.data_type = ArrayDataType::kBool; + OperatorSignature bool_signature = {.op = &op, .model = &bool_model}; + EXPECT_EQ(base_op->GetVersion(bool_signature), 3); + + op.start_indices = {0, 0, 0, 0, 0}; + op.stop_indices = {1, 2, 2, 2, 2}; + op.strides = {1, 1, 1, 1, 
1}; + EXPECT_EQ(base_op->GetVersion(uint8_signature), 4); + EXPECT_EQ(base_op->GetVersion(int8_signature), 4); + EXPECT_EQ(base_op->GetVersion(bool_signature), 4); } TEST_F(OperatorTest, VersioningSpaceToDepthTest) { @@ -982,8 +1010,6 @@ TEST_F(OperatorTest, VersioningSumTest) { TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest(); } -TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest(); } - void SimpleMulVersioningTest(ArrayDataType data_type, float multiplier, int version) { MulOperator op; @@ -1014,6 +1040,42 @@ TEST_F(OperatorTest, VersioningMulTest) { SimpleMulVersioningTest(ArrayDataType::kInt8, 2.0f, 3); } +void SimpleSubVersioningTest(ArrayDataType data_type, Shape shape1, + Shape shape2, int version) { + SubOperator op; + op.inputs = {"input1", "input2"}; + op.outputs = {"output"}; + auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/); + const BaseOperator* base_op = operator_by_type_map.at(op.type).get(); + + Model model; + Array& input0 = model.GetOrCreateArray(op.inputs[0]); + Array& input1 = model.GetOrCreateArray(op.inputs[1]); + Array& output = model.GetOrCreateArray(op.outputs[0]); + + input0.data_type = data_type; + input0.copy_shape(shape1); + input1.data_type = data_type; + input1.copy_shape(shape2); + output.data_type = data_type; + + OperatorSignature signature = {.op = &op, .model = &model}; + EXPECT_EQ(base_op->GetVersion(signature), version); +} + +TEST_F(OperatorTest, VersioningSubTest) { + SimpleSubVersioningTest(ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 2}, 1); + SimpleSubVersioningTest(ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 2}, 2); + SimpleSubVersioningTest(ArrayDataType::kUint8, {1, 2, 2}, {1, 2, 2}, 1); + SimpleSubVersioningTest(ArrayDataType::kInt8, {1, 2, 2}, {1, 2, 2}, 2); + SimpleSubVersioningTest(ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, 1); + SimpleSubVersioningTest(ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 1}, 2); + SimpleSubVersioningTest(ArrayDataType::kUint8, {1, 2, 2, 2, 2}, + {1, 2, 2, 2, 1}, 3); + SimpleSubVersioningTest(ArrayDataType::kInt8, {1, 2, 2, 2, 2}, + {1, 2, 2, 2, 1}, 3); +} + TEST_F(OperatorTest, VersioningPadTest) { SimpleVersioningTest(); } TEST_F(OperatorTest, VersioningPadV2Test) { diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index e1b10233733..7a74eefaa7c 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -32,6 +32,20 @@ namespace { inline int GetNumDims(const SubGraph* subgraph, const Operator* op, int idx) { return subgraph->tensors()->Get(op->inputs()->Get(idx))->shape()->size(); } + +// Compare shape of two tensors with idx1 and idx2 of an operator op, return +// true if they have the same shape. 
+inline bool HaveSameShapes(const SubGraph* subgraph, const Operator* op, + int idx1, int idx2) { + const flatbuffers::Vector* shape1 = + subgraph->tensors()->Get(op->inputs()->Get(idx1))->shape(); + const flatbuffers::Vector* shape2 = + subgraph->tensors()->Get(op->inputs()->Get(idx2))->shape(); + if (shape1->size() != shape2->size()) { + return false; + } + return std::equal(shape1->begin(), shape1->end(), shape2->begin()); +} } // namespace int GetBuiltinOperatorVersion(const OpSignature& op_sig) { @@ -342,9 +356,18 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; + case BuiltinOperator_SUB: + if (op_sig.options.sub.need_broadcast && + op_sig.options.sub.num_dims > 4) { + return 3; + } + if (op_sig.input_types.at(0) == TensorType_INT8) { + return 2; + } + return 1; + case BuiltinOperator_AVERAGE_POOL_2D: case BuiltinOperator_ADD: - case BuiltinOperator_SUB: case BuiltinOperator_CONCATENATION: case BuiltinOperator_MAX_POOL_2D: case BuiltinOperator_PAD: @@ -486,6 +509,12 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, op_sig.options.space_batch.num_dims = GetNumDims(subgraph, op, 0); } break; + case BuiltinOperator_SUB: { + op_sig.options.sub.need_broadcast = !HaveSameShapes(subgraph, op, 0, 1); + op_sig.options.sub.num_dims = + std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1)); + } break; + default: break; } diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h index e12d9aef99a..e22e5403a8a 100644 --- a/tensorflow/lite/tools/versioning/op_version.h +++ b/tensorflow/lite/tools/versioning/op_version.h @@ -55,6 +55,10 @@ typedef struct { struct { int32_t num_dims; } space_batch; + struct { + int32_t num_dims; + bool need_broadcast; + } sub; } options; } OpSignature; From 067ba5e60077708b6fd886bfb82588e68433aae7 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 19 Mar 2020 01:08:31 -0700 Subject: [PATCH 215/492] Internal change PiperOrigin-RevId: 301760226 Change-Id: I15b5d12d542267c484bed8bf2f8dcf0117298f16 --- third_party/mlir/test.BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD index 3c881ca4b2e..a27c7c9f9f3 100644 --- a/third_party/mlir/test.BUILD +++ b/third_party/mlir/test.BUILD @@ -162,6 +162,7 @@ cc_library( "@llvm-project//mlir:Analysis", "@llvm-project//mlir:EDSC", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgTransforms", From 49df241a6fa11e8029eede48c02be234d8592eeb Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Thu, 19 Mar 2020 01:08:33 -0700 Subject: [PATCH 216/492] Use MLIR based TensorFlow compiler in XLA on demand compiler if mlir bridge is enabled Added a separate python test for now but in a follow-up CL the same set of tests will run on both bridges. 
PiperOrigin-RevId: 301760228 Change-Id: I3028edf7a365f81dabd7d769e714a4eef5e93312 --- tensorflow/compiler/jit/BUILD | 1 + .../compiler/jit/xla_compilation_cache.cc | 27 ++++++- .../compiler/jit/xla_compilation_cache.h | 2 +- .../compiler/jit/xla_compile_on_demand_op.cc | 9 ++- tensorflow/compiler/mlir/tensorflow/BUILD | 5 ++ .../tensorflow/utils/compile_mlir_util.cc | 56 ++++++++++--- .../mlir/tensorflow/utils/compile_mlir_util.h | 10 +++ .../utils/compile_mlir_util_test.cc | 40 ++++++++++ tensorflow/compiler/tests/BUILD | 20 +++++ .../compiler/tests/unary_mlir_ops_test.py | 80 +++++++++++++++++++ tensorflow/tools/lib_package/BUILD | 2 + 11 files changed, 233 insertions(+), 19 deletions(-) create mode 100644 tensorflow/compiler/tests/unary_mlir_ops_test.py diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f71331af0df..8a868c3283a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -338,6 +338,7 @@ cc_library( deps = [ ":xla_activity_listener", ":xla_activity_proto_cc", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 5540fee7276..06df2da37b8 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" @@ -253,7 +255,7 @@ static xla::StatusOr> CreateGraph( Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, absl::Span args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompileOptions& compile_options, bool use_mlir_bridge, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { const NodeDef& def = ctx->op_kernel().def(); @@ -273,8 +275,27 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - return compiler->CompileGraph(compile_options, node_def.name(), - std::move(graph), args, result); + + bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) { + return arg.kind == XlaCompiler::Argument::kParameter; + }); + // Use MLIR bridge if all the arguments are parameters. + // TODO(hinsu): Support other argument types. 
+ if (!use_mlir_bridge || !are_params) { + return compiler->CompileGraph(compile_options, node_def.name(), + std::move(graph), args, result); + } + + absl::InlinedVector arg_shapes; + arg_shapes.reserve(args.size()); + for (const XlaCompiler::Argument& arg : args) { + arg_shapes.push_back(absl::get(arg.shape)); + } + GraphDebugInfo debug_info; + return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()}, + compile_options.use_tuple_arg, + *options.flib_def, debug_info, + options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 83a0bda97d5..08b2ca4b778 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -82,7 +82,7 @@ class XlaCompilationCache : public ResourceBase { Status CompileSingleOp( const XlaCompiler::Options& options, absl::Span args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompileOptions& compile_options, bool use_mlir_bridge, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 45ce68ba9c0..79c44ea30b7 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -188,7 +188,8 @@ Status XlaCompileOnDemandOp::Compile( XlaCompiler::Options options; options.device_type = metadata.jit_device_type(); options.client = metadata.client(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + FunctionLibraryRuntime* flr = ctx->function_library(); + options.flib_def = flr->GetFunctionLibraryDefinition(); options.shape_representation_fn = metadata.shape_representation_fn(); XlaCompiler::CompileOptions compile_options; @@ -204,8 +205,10 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_arguments, variable_args, ctx, &args)); - return cache->CompileSingleOp(options, args, ctx, compile_options, result, - executable); + const ConfigProto* config = flr->config_proto(); + bool use_mlir = config ? 
config->experimental().enable_mlir_bridge() : false; + return cache->CompileSingleOp(options, args, ctx, compile_options, use_mlir, + result, executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 3bed4e753e0..5087e98b038 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1046,9 +1046,11 @@ cc_library( hdrs = ["utils/compile_mlir_util.h"], deps = [ ":bridge_logger", + ":convert_graphdef", ":convert_type", ":dump_mlir_util", ":error_util", + ":mlir_roundtrip_flags", ":tensorflow_dialect_registration", ":tensorflow_passes", ":translate_utils", @@ -1059,6 +1061,7 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", @@ -1083,8 +1086,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/stream_executor/lib", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 10aad0a03ff..713afa5c214 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -28,6 +28,8 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" @@ -260,18 +262,11 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, return Status::OK(); } -Status CompileSerializedMlirToXlaHlo( - llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, +static Status CompileMlirToXlaHlo( + mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { - mlir::MLIRContext mlir_context; - mlir::OwningModuleRef mlir_module; - - TF_RETURN_IF_ERROR( - ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); - auto module_op = mlir_module.get(); - if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -292,9 +287,14 @@ Status CompileSerializedMlirToXlaHlo( GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping); auto shape_representation_fn_no_fast_memory = - [shape_representation_fn](const TensorShape& shape, DataType dtype) { - return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); - }; + [shape_representation_fn](const TensorShape& shape, + DataType dtype) -> StatusOr { + if (shape_representation_fn) + return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); + xla::Shape xla_shape; + 
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; // Compute all input shapes. TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args, @@ -316,4 +316,36 @@ Status CompileSerializedMlirToXlaHlo( return Status::OK(); } +Status CompileSerializedMlirToXlaHlo( + llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, + bool use_tuple_args, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result) { + mlir::MLIRContext mlir_context; + mlir::OwningModuleRef mlir_module; + + TF_RETURN_IF_ERROR( + ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); + return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args, + shape_representation_fn, compilation_result); +} + +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef arg_shapes, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result) { + mlir::MLIRContext context; + GraphImportConfig config; + config.graph_as_function = true; + auto module_or = + ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); + if (!module_or.ok()) return module_or.status(); + + return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, + use_tuple_args, shape_representation_fn, + compilation_result); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 41fa8b90e4f..bf95c7c0d61 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -53,6 +54,15 @@ Status CompileSerializedMlirToXlaHlo( bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); + +// Same as the above but takes input as TensorFlow Graph. +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef arg_shapes, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index b258dd68ae1..d4e5d71e525 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -20,6 +20,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -248,5 +251,42 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { ::testing::HasSubstr(expected_signature)); } +// Verify that conversion from Graph to MLIR and empty shape representation +// function is successful. +TEST(CompileGraphToXlaHlo, Basic) { + setenv("TF_DUMP_GRAPH_PREFIX", "-", 1); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + Graph graph(OpRegistry::Global()); + + Tensor dummy_tensor(DT_FLOAT, TensorShape({1})); + test::FillValues(&dummy_tensor, {-1.0}); + + Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT); + test::graph::Retval(&graph, 0, arg); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(CompileGraphToXlaHlo( + graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def, + GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); + + const xla::HloModuleConfig module_config( + result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + result.computation->proto(), module_config); + ASSERT_TRUE(status_or_hlo_module.ok()); + + string expected_hlo_module_string = R"(HloModule main.3 + +ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { + %Arg_0.1 = f32[] parameter(0) + ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1) +} + +)"; + + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 77cd3dc074c..d586b8178c5 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1354,6 +1354,26 @@ tf_xla_py_test( ], ) +# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it. +tf_xla_py_test( + name = "unary_mlir_ops_test", + size = "medium", + srcs = ["unary_mlir_ops_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fused_batchnorm_test", size = "medium", diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py new file mode 100644 index 00000000000..2b3dec3d5a7 --- /dev/null +++ b/tensorflow/compiler/tests/unary_mlir_ops_test.py @@ -0,0 +1,80 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for XLA JIT compiler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class UnaryOpsTest(xla_test.XLATestCase): + """Test cases for unary operators.""" + + def __init__(self, method_name='runTest'): + super(UnaryOpsTest, self).__init__(method_name) + context.context().enable_mlir_bridge = True + + def _assertOpOutputMatchesExpected(self, + op, + inp, + expected, + equality_test=None, + rtol=1e-3, + atol=1e-5): + """Verifies that 'op' produces 'expected' when fed input 'inp' . + + Args: + op: operator to test + inp: numpy input array to use as input to 'op'. + expected: numpy array representing the expected output of 'op'. + equality_test: either None, or a function that tests two numpy arrays for + equality. If None, self.assertAllClose is used. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. + """ + with self.session() as session: + with self.test_scope(): + pinp = array_ops.placeholder( + dtypes.as_dtype(inp.dtype), inp.shape, name='a') + output = op(pinp) + result = session.run(output, {pinp: inp}) + if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + else: + equality_test(result, expected, rtol=rtol, atol=atol) + + def testNumericOps(self): + # TODO(hinsu): Enable complex types after fixing the failure in export to + # HLOModule. + for dtype in self.numeric_types - {np.int8, np.uint8} - self.complex_types: + self._assertOpOutputMatchesExpected( + math_ops.abs, + np.array([[2, -1]], dtype=dtype), + expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 30ab95e370d..89ec2a0c7c3 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -154,6 +154,7 @@ genrule( "@icu//:icu4c/LICENSE", "@libjpeg_turbo//:LICENSE.md", "@llvm-project//llvm:LICENSE.TXT", + "@llvm-project//mlir:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@local_config_tensorrt//:LICENSE", @@ -234,6 +235,7 @@ genrule( "@icu//:icu4j/main/shared/licenses/LICENSE", "@libjpeg_turbo//:LICENSE.md", "@llvm-project//llvm:LICENSE.TXT", + "@llvm-project//mlir:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@local_config_tensorrt//:LICENSE", From cf46f78c72bbcb86aff64d06e6fbbb54f2f31d8a Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 19 Mar 2020 01:14:32 -0700 Subject: [PATCH 217/492] Handle bitcast indexing better. We may not need to recompute the linear index if we already have one. 
Also do some code cleanup in ir_array.cc PiperOrigin-RevId: 301760879 Change-Id: I001c503d7874c07a000fa1a55b37180a84bcc277 --- .../xla/service/elemental_ir_emitter.cc | 2 +- .../xla/service/gpu/tests/gpu_index_test.cc | 26 +++++++++++++++++++ .../compiler/xla/service/llvm_ir/ir_array.cc | 22 +++++++++------- .../compiler/xla/service/llvm_ir/ir_array.h | 6 ++--- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 1d18b2c65a8..3eb6dab3129 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -2376,7 +2376,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( &operand_to_generator](const IrArray::Index& target_index) { return operand_to_generator.at(hlo->operand(0))( target_index.SourceIndexOfTranspose( - hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), b_)); + hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions())); }; case HloOpcode::kPad: return [this, hlo, &operand_to_generator]( diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index 67b291c8fcb..871692a7b26 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -149,5 +149,31 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithSizeOneDimensions) { /*match_optimized_ir=*/false); } +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithTranspose) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithTranspose { + x = f32[2,1024,3,256]{3,2,1,0} parameter(0) + y = f32[1024,2,256,3]{2,3,0,1} parameter(1) + transpose = f32[1024,2,256,3]{3,2,1,0} transpose(x), dimensions={1,0,3,2} + ROOT gte = pred[1024,2,256,3]{2,3,0,1} compare(transpose, y), direction=GE + })", + config) + .ValueOrDie(); + // Check the optimized IR contains no udiv and urem. 
+ CompileAndVerifyIr(std::move(module), + R"( +; CHECK-NOT: udiv +; CHECK-NOT: urem + )", + /*match_optimized_ir=*/true); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 781ba9b980d..396fcf9e92e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -188,16 +188,13 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice( std::vector source_multi_index(multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; - auto type = multidim_[i]->getType(); - if (stride != 1) { source_multi_index[i] = builder->CreateAdd( - builder->CreateMul(multidim_[i], - llvm::ConstantInt::get(type, stride)), - llvm::ConstantInt::get(type, starts[i])); + builder->CreateMul(multidim_[i], GetConstantWithIndexType(stride)), + GetConstantWithIndexType(starts[i])); } else { - source_multi_index[i] = builder->CreateAdd( - multidim_[i], llvm::ConstantInt::get(type, starts[i])); + source_multi_index[i] = + builder->CreateAdd(multidim_[i], GetConstantWithIndexType(starts[i])); } } return Index(source_multi_index, operand_shape, index_type_); @@ -205,8 +202,7 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice( IrArray::Index IrArray::Index::SourceIndexOfTranspose( const Shape& shape, const Shape& operand_shape, - absl::Span dimension_mapping, - llvm::IRBuilder<>* builder) const { + absl::Span dimension_mapping) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); @@ -223,6 +219,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( const Shape& shape, const Shape& operand_shape, llvm::IRBuilder<>* builder) const { CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape)); + // In case the bitcast is just a reshape, we can use SourceIndexOfReshape() // instead. This will reuse linear() if possible, so we don't have to build a // new 'linear_index'. @@ -230,6 +227,13 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( return SourceIndexOfReshape(shape, operand_shape, builder); } + // If we have a linear index, we can definitely use it because we know the + // operation is a bitcast. This will recompute the multi-dimensional index for + // the operand based on the linear index. + if (linear() != nullptr) { + return Index(linear(), operand_shape, builder); + } + // First linearize the index coming from the output of the bitcast. We want // the physical index of the element in the buffer. This is like Linearize, // but takes the layout into account. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index a1af1477de0..e838c4a0534 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -135,9 +135,9 @@ class IrArray { // Given that "this" is the target index of a transpose from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. - Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape, - absl::Span dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfTranspose( + const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping) const; // Given that "this" is the target index of a bitcast from `operand_shape` // to `shape`, returns the source index. From 171d55a7ec79d47fcbe469f29f70cc8429837f74 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 19 Mar 2020 01:29:21 -0700 Subject: [PATCH 218/492] Use MLIR based TensorFlow compiler in XLA on demand compiler if mlir bridge is enabled Added a separate python test for now but in a follow-up CL the same set of tests will run on both bridges. PiperOrigin-RevId: 301762541 Change-Id: Iee8bb81ecc9695cb4eb389977eaa7c72322820af --- tensorflow/compiler/jit/BUILD | 1 - .../compiler/jit/xla_compilation_cache.cc | 27 +------ .../compiler/jit/xla_compilation_cache.h | 2 +- .../compiler/jit/xla_compile_on_demand_op.cc | 9 +-- tensorflow/compiler/mlir/tensorflow/BUILD | 5 -- .../tensorflow/utils/compile_mlir_util.cc | 56 +++---------- .../mlir/tensorflow/utils/compile_mlir_util.h | 10 --- .../utils/compile_mlir_util_test.cc | 40 ---------- tensorflow/compiler/tests/BUILD | 20 ----- .../compiler/tests/unary_mlir_ops_test.py | 80 ------------------- tensorflow/tools/lib_package/BUILD | 2 - 11 files changed, 19 insertions(+), 233 deletions(-) delete mode 100644 tensorflow/compiler/tests/unary_mlir_ops_test.py diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 8a868c3283a..f71331af0df 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -338,7 +338,6 @@ cc_library( deps = [ ":xla_activity_listener", ":xla_activity_proto_cc", - "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 06df2da37b8..5540fee7276 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -41,7 +40,6 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" @@ -255,7 +253,7 @@ static xla::StatusOr> CreateGraph( Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, absl::Span args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, bool use_mlir_bridge, + const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { const NodeDef& def = ctx->op_kernel().def(); @@ -275,27 +273,8 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - - bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) { - return arg.kind == XlaCompiler::Argument::kParameter; - }); - // Use MLIR bridge if all the arguments are parameters. - // TODO(hinsu): Support other argument types. 
- if (!use_mlir_bridge || !are_params) { - return compiler->CompileGraph(compile_options, node_def.name(), - std::move(graph), args, result); - } - - absl::InlinedVector arg_shapes; - arg_shapes.reserve(args.size()); - for (const XlaCompiler::Argument& arg : args) { - arg_shapes.push_back(absl::get(arg.shape)); - } - GraphDebugInfo debug_info; - return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()}, - compile_options.use_tuple_arg, - *options.flib_def, debug_info, - options.shape_representation_fn, result); + return compiler->CompileGraph(compile_options, node_def.name(), + std::move(graph), args, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 08b2ca4b778..83a0bda97d5 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -82,7 +82,7 @@ class XlaCompilationCache : public ResourceBase { Status CompileSingleOp( const XlaCompiler::Options& options, absl::Span args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, bool use_mlir_bridge, + const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 79c44ea30b7..45ce68ba9c0 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -188,8 +188,7 @@ Status XlaCompileOnDemandOp::Compile( XlaCompiler::Options options; options.device_type = metadata.jit_device_type(); options.client = metadata.client(); - FunctionLibraryRuntime* flr = ctx->function_library(); - options.flib_def = flr->GetFunctionLibraryDefinition(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.shape_representation_fn = metadata.shape_representation_fn(); XlaCompiler::CompileOptions compile_options; @@ -205,10 +204,8 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_arguments, variable_args, ctx, &args)); - const ConfigProto* config = flr->config_proto(); - bool use_mlir = config ? 
config->experimental().enable_mlir_bridge() : false; - return cache->CompileSingleOp(options, args, ctx, compile_options, use_mlir, - result, executable); + return cache->CompileSingleOp(options, args, ctx, compile_options, result, + executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 5087e98b038..3bed4e753e0 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1046,11 +1046,9 @@ cc_library( hdrs = ["utils/compile_mlir_util.h"], deps = [ ":bridge_logger", - ":convert_graphdef", ":convert_type", ":dump_mlir_util", ":error_util", - ":mlir_roundtrip_flags", ":tensorflow_dialect_registration", ":tensorflow_passes", ":translate_utils", @@ -1061,7 +1059,6 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", @@ -1086,10 +1083,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", "//tensorflow/stream_executor/lib", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 713afa5c214..10aad0a03ff 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -28,8 +28,6 @@ limitations under the License. 
#include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" @@ -262,11 +260,18 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, return Status::OK(); } -static Status CompileMlirToXlaHlo( - mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, +Status CompileSerializedMlirToXlaHlo( + llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { + mlir::MLIRContext mlir_context; + mlir::OwningModuleRef mlir_module; + + TF_RETURN_IF_ERROR( + ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); + auto module_op = mlir_module.get(); + if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -287,14 +292,9 @@ static Status CompileMlirToXlaHlo( GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping); auto shape_representation_fn_no_fast_memory = - [shape_representation_fn](const TensorShape& shape, - DataType dtype) -> StatusOr { - if (shape_representation_fn) - return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); - return xla_shape; - }; + [shape_representation_fn](const TensorShape& shape, DataType dtype) { + return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); + }; // Compute all input shapes. 
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args, @@ -316,36 +316,4 @@ static Status CompileMlirToXlaHlo( return Status::OK(); } -Status CompileSerializedMlirToXlaHlo( - llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, - bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result) { - mlir::MLIRContext mlir_context; - mlir::OwningModuleRef mlir_module; - - TF_RETURN_IF_ERROR( - ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); - return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args, - shape_representation_fn, compilation_result); -} - -Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef arg_shapes, - bool use_tuple_args, const FunctionLibraryDefinition& flib_def, - const GraphDebugInfo& debug_info, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result) { - mlir::MLIRContext context; - GraphImportConfig config; - config.graph_as_function = true; - auto module_or = - ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); - if (!module_or.ok()) return module_or.status(); - - return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, - use_tuple_args, shape_representation_fn, - compilation_result); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index bf95c7c0d61..41fa8b90e4f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -21,7 +21,6 @@ limitations under the License. #include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -54,15 +53,6 @@ Status CompileSerializedMlirToXlaHlo( bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); - -// Same as the above but takes input as TensorFlow Graph. -Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef arg_shapes, - bool use_tuple_args, const FunctionLibraryDefinition& flib_def, - const GraphDebugInfo& debug_info, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index d4e5d71e525..b258dd68ae1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -20,9 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -251,42 +248,5 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { ::testing::HasSubstr(expected_signature)); } -// Verify that conversion from Graph to MLIR and empty shape representation -// function is successful. -TEST(CompileGraphToXlaHlo, Basic) { - setenv("TF_DUMP_GRAPH_PREFIX", "-", 1); - FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); - Graph graph(OpRegistry::Global()); - - Tensor dummy_tensor(DT_FLOAT, TensorShape({1})); - test::FillValues(&dummy_tensor, {-1.0}); - - Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT); - test::graph::Retval(&graph, 0, arg); - - XlaCompiler::CompilationResult result; - TF_ASSERT_OK(CompileGraphToXlaHlo( - graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def, - GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); - - const xla::HloModuleConfig module_config( - result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); - - string expected_hlo_module_string = R"(HloModule main.3 - -ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { - %Arg_0.1 = f32[] parameter(0) - ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1) -} - -)"; - - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index d586b8178c5..77cd3dc074c 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1354,26 +1354,6 @@ tf_xla_py_test( ], ) -# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it. -tf_xla_py_test( - name = "unary_mlir_ops_test", - size = "medium", - srcs = ["unary_mlir_ops_test.py"], - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - ], - deps = [ - ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", - "//tensorflow/python:platform_test", - ], -) - tf_xla_py_test( name = "fused_batchnorm_test", size = "medium", diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py deleted file mode 100644 index 2b3dec3d5a7..00000000000 --- a/tensorflow/compiler/tests/unary_mlir_ops_test.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for XLA JIT compiler.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.compiler.tests import xla_test -from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import googletest - - -class UnaryOpsTest(xla_test.XLATestCase): - """Test cases for unary operators.""" - - def __init__(self, method_name='runTest'): - super(UnaryOpsTest, self).__init__(method_name) - context.context().enable_mlir_bridge = True - - def _assertOpOutputMatchesExpected(self, - op, - inp, - expected, - equality_test=None, - rtol=1e-3, - atol=1e-5): - """Verifies that 'op' produces 'expected' when fed input 'inp' . - - Args: - op: operator to test - inp: numpy input array to use as input to 'op'. - expected: numpy array representing the expected output of 'op'. - equality_test: either None, or a function that tests two numpy arrays for - equality. If None, self.assertAllClose is used. - rtol: relative tolerance for equality test. - atol: absolute tolerance for equality test. - """ - with self.session() as session: - with self.test_scope(): - pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name='a') - output = op(pinp) - result = session.run(output, {pinp: inp}) - if equality_test is None: - self.assertEqual(output.dtype, expected.dtype) - self.assertAllCloseAccordingToType( - expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) - else: - equality_test(result, expected, rtol=rtol, atol=atol) - - def testNumericOps(self): - # TODO(hinsu): Enable complex types after fixing the failure in export to - # HLOModule. - for dtype in self.numeric_types - {np.int8, np.uint8} - self.complex_types: - self._assertOpOutputMatchesExpected( - math_ops.abs, - np.array([[2, -1]], dtype=dtype), - expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 89ec2a0c7c3..30ab95e370d 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -154,7 +154,6 @@ genrule( "@icu//:icu4c/LICENSE", "@libjpeg_turbo//:LICENSE.md", "@llvm-project//llvm:LICENSE.TXT", - "@llvm-project//mlir:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@local_config_tensorrt//:LICENSE", @@ -235,7 +234,6 @@ genrule( "@icu//:icu4j/main/shared/licenses/LICENSE", "@libjpeg_turbo//:LICENSE.md", "@llvm-project//llvm:LICENSE.TXT", - "@llvm-project//mlir:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@local_config_tensorrt//:LICENSE", From fff9858dca6136c4dbbc6db7bc5a6ab9aaec5e63 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Thu, 19 Mar 2020 01:56:52 -0700 Subject: [PATCH 219/492] Fixed the platform_profiler dep issue. 
PiperOrigin-RevId: 301765638 Change-Id: Ie9bb835a01fe50820357ea00d659bb3cf8897258 --- tensorflow/lite/BUILD | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index e6164c395e3..70b1566600d 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -250,7 +250,12 @@ cc_library( "//tensorflow/lite/experimental/resource", "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/schema:schema_fbs", - ], + ] + select({ + ":enable_default_profiler": [ + "//tensorflow/lite/profiling:platform_profiler", + ], + "//conditions:default": [], + }), alwayslink = 1, ) @@ -285,12 +290,7 @@ cc_library( "//tensorflow/lite/experimental/resource", "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/schema:schema_fbs", - ] + select({ - ":enable_default_profiler": [ - "//tensorflow/lite/profiling:platform_profiler", - ], - "//conditions:default": [], - }) + tflite_experimental_runtime_linkopts(), + ] + tflite_experimental_runtime_linkopts(), alwayslink = 1, ) From 2decf5694a7115cbf36c10ca124b38889f5622f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 02:02:40 -0700 Subject: [PATCH 220/492] compat: Update forward compatibility horizon to 2020-03-19 PiperOrigin-RevId: 301766393 Change-Id: I1634bc96a04abcd3d92f893ad1b8b7288bd04265 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 17b23b616d1..6121c71a404 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 18) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 19) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From b9b739285e46ca0bdc54641bf7351463cf8aa928 Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Thu, 19 Mar 2020 05:03:41 -0700 Subject: [PATCH 221/492] Fix Makefile build Since XNNPACK isn't buildable with Makefile, it should be disabled for Makefile build. PiperOrigin-RevId: 301790434 Change-Id: Ieff4304f26abb0ec9c7dc90e0be9803aa58e8e98 --- tensorflow/lite/tools/evaluation/utils.cc | 6 +++++- tensorflow/lite/tools/evaluation/utils.h | 6 +++--- tensorflow/lite/tools/make/Makefile | 9 ++++++++- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc index f86a7316ecf..69f962d7e9a 100644 --- a/tensorflow/lite/tools/evaluation/utils.cc +++ b/tensorflow/lite/tools/evaluation/utils.cc @@ -163,7 +163,11 @@ TfLiteDelegatePtr CreateHexagonDelegate( } // TODO(b/149248802): include XNNPACK delegate when the issue is resolved. 
-#if !defined(__Fuchsia__) +#if defined(__Fuchsia__) || defined(TFLITE_WITHOUT_XNNPACK) +TfLiteDelegatePtr CreateXNNPACKDelegate(int num_threads) { + return CreateNullDelegate(); +} +#else TfLiteDelegatePtr CreateXNNPACKDelegate() { TfLiteXNNPackDelegateOptions xnnpack_options = TfLiteXNNPackDelegateOptionsDefault(); diff --git a/tensorflow/lite/tools/evaluation/utils.h b/tensorflow/lite/tools/evaluation/utils.h index d1717f92e5f..e7c8246b340 100644 --- a/tensorflow/lite/tools/evaluation/utils.h +++ b/tensorflow/lite/tools/evaluation/utils.h @@ -28,7 +28,7 @@ limitations under the License. #endif // TODO(b/149248802): include XNNPACK delegate when the issue is resolved. -#if !defined(__Fuchsia__) +#if !defined(__Fuchsia__) || defined(TFLITE_WITHOUT_XNNPACK) #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #endif @@ -73,12 +73,12 @@ TfLiteDelegatePtr CreateHexagonDelegate( const std::string& library_directory_path, bool profiling); // TODO(b/149248802): include XNNPACK delegate when the issue is resolved. -#if !defined(__Fuchsia__) +#if !defined(__Fuchsia__) || defined(TFLITE_WITHOUT_XNNPACK) TfLiteDelegatePtr CreateXNNPACKDelegate(); TfLiteDelegatePtr CreateXNNPACKDelegate( const TfLiteXNNPackDelegateOptions* options); -TfLiteDelegatePtr CreateXNNPACKDelegate(int num_threads); #endif +TfLiteDelegatePtr CreateXNNPACKDelegate(int num_threads); } // namespace evaluation } // namespace tflite diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index 2684d1fd05b..e0aa625744b 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -172,6 +172,9 @@ ifeq ($(BUILD_WITH_RUY),true) CXXFLAGS += -DTFLITE_WITH_RUY endif +# Not to include XNNPACK. +CXXFLAGS += -DTFLITE_WITHOUT_XNNPACK + BUILD_WITH_NNAPI ?= false ifeq ($(BUILD_WITH_NNAPI),true) CORE_CC_ALL_SRCS += tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -212,7 +215,11 @@ BENCHMARK_LIB_SRCS := $(filter-out \ $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc) \ $(BENCHMARK_MAIN_SRC) \ $(BENCHMARK_PERF_OPTIONS_SRC) \ - $(BENCHMARK_SRCS_DIR)/benchmark_plus_flex_main.cc, \ + $(BENCHMARK_SRCS_DIR)/benchmark_plus_flex_main.cc \ + $(BENCHMARK_SRCS_DIR)/gpu_delegate_provider.cc \ + $(BENCHMARK_SRCS_DIR)/hexagon_delegate_provider.cc \ + $(BENCHMARK_SRCS_DIR)/nnapi_delegate_provider.cc \ + $(BENCHMARK_SRCS_DIR)/xnnpack_delegate_provider.cc, \ $(BENCHMARK_ALL_SRCS)) # These target-specific makefiles should modify or replace options like From e2992a4d7baceca99897a93dbb820f8d88833a93 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 05:48:49 -0700 Subject: [PATCH 222/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301795889 Change-Id: Ia7878dfadc1a5585d9865c87866274b1a80e5f78 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 7be0c66548c..3d05bb08fa3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11815,7 +11815,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12072,7 +12072,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12083,7 +12083,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12301,7 +12301,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12312,7 +12312,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19153,7 +19153,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20224,7 +20224,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21396,7 +21396,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22104,7 +22104,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22300,7 +22300,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22484,7 +22484,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22543,7 +22543,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22717,7 +22717,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23098,7 +23098,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25441,7 +25441,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25504,7 +25504,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25747,7 +25747,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26370,7 +26370,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45499,7 +45499,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46287,7 +46287,7 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46350,7 +46350,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From f5712e28c7d4b1c0114745c0359ce87fcc77f82d Mon Sep 17 00:00:00 2001 From: Tiezhen WANG Date: Thu, 19 Mar 2020 05:57:11 -0700 Subject: [PATCH 223/492] TFLM: Use scratch buffer in FC. This will reduce the inference latency on FullyConnected layers as we no longer need to re-calculate OpData for each inference. In the same time, the arena requirement will only slightly increase. This CL also contains some changes on the testing infra so that the single op test can also take account the kernel memory allocations. This is based a dumb implementation so easy to understand and debug. PiperOrigin-RevId: 301796803 Change-Id: I5330c634bb322be7bcd175c4f0fd03c737714af6 --- .../lite/micro/kernels/fully_connected.cc | 25 ++++- .../micro/kernels/fully_connected_test.cc | 1 - tensorflow/lite/micro/test_helpers.h | 3 - tensorflow/lite/micro/testing/BUILD | 1 + tensorflow/lite/micro/testing/test_utils.cc | 103 ++++++++++++++++-- tensorflow/lite/micro/testing/test_utils.h | 63 ++++++++++- 6 files changed, 175 insertions(+), 21 deletions(-) diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index 64bf788f538..91df80b328c 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -71,18 +71,35 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, } // namespace void* Init(TfLiteContext* context, const char* buffer, size_t length) { - return nullptr; + OpData* data = nullptr; + TfLiteStatus status = context->AllocatePersistentBuffer( + context, sizeof(OpData), reinterpret_cast(&data)); + if (status != kTfLiteOk || data == nullptr) { + return nullptr; + } + return data; } void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + auto* params = + reinterpret_cast(node->builtin_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, input->type, output->type); TF_LITE_ENSURE_MSG(context, input->type == filter->type, "Hybrid models are not supported on TFLite Micro."); + + TfLiteType data_type = input->type; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); + return kTfLiteOk; } @@ -178,11 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteType data_type = input->type; - OpData local_data_object; - OpData* data = &local_data_object; - TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, - filter, bias, output, data)); + OpData* data = reinterpret_cast(node->user_data); // Checks in Prepare ensure input, output and filter types are all the same. 
switch (input->type) { diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index 0859e4af591..4687ae89108 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -49,7 +49,6 @@ void TestFullyConnectedFloat( TfLiteContext context; PopulateContext(tensors, tensors_size, micro_test::reporter, &context); - ::tflite::ops::micro::AllOpsResolver resolver; const TfLiteRegistration* registration = resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1); diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index 76919526d81..498ce9c53da 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -57,9 +57,6 @@ CreateFlatbufferBuffers(); // Performs a simple string comparison without requiring standard C library. int TestStrcmp(const char* a, const char* b); -// Wrapper to forward kernel errors to the interpreter's error reporter. -void ReportOpError(struct TfLiteContext* context, const char* format, ...); - void PopulateContext(TfLiteTensor* tensors, int tensors_size, TfLiteContext* context); diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD index 01bdffc6892..42f25f0e8b0 100644 --- a/tensorflow/lite/micro/testing/BUILD +++ b/tensorflow/lite/micro/testing/BUILD @@ -17,6 +17,7 @@ cc_library( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", ], diff --git a/tensorflow/lite/micro/testing/test_utils.cc b/tensorflow/lite/micro/testing/test_utils.cc index 9f7803fcf62..5fd0161d621 100644 --- a/tensorflow/lite/micro/testing/test_utils.cc +++ b/tensorflow/lite/micro/testing/test_utils.cc @@ -15,24 +15,107 @@ limitations under the License. #include "tensorflow/lite/micro/testing/test_utils.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" + namespace tflite { namespace testing { +TfLiteStatus FakeAllocator::AllocatePersistentBuffer(size_t bytes, void** ptr) { + uint8_t* addr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); + *ptr = addr; + return kTfLiteOk; +} + +TfLiteStatus FakeAllocator::RequestScratchBufferInArena(int node_idx, + size_t bytes, + int* buffer_idx) { + if (scratch_buffers_count_ >= max_scratch_buffers_count_) { + return kTfLiteError; + } + uint8_t* ptr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); + scratch_buffers_[scratch_buffers_count_] = ptr; + *buffer_idx = scratch_buffers_count_; + scratch_buffers_count_++; + return kTfLiteOk; +} + +void FakeAllocator::Reset() { + // Get A fresh memory allocator. + memory_allocator_ = CreateInPlaceSimpleMemoryAllocator(arena_, arena_size_); + TFLITE_DCHECK_NE(memory_allocator_, nullptr); + + // Allocate enough space holding pointers to the scrtach buffers. 
+ scratch_buffers_ = + reinterpret_cast(memory_allocator_->AllocateFromTail( + sizeof(uint8_t*) * max_scratch_buffers_count_, alignof(uint8_t*))); + TFLITE_DCHECK_NE(scratch_buffers_, nullptr); + + scratch_buffers_count_ = 0; +} + +void* FakeAllocator::GetScratchBuffer(int buffer_idx) { + if (buffer_idx < 0 || buffer_idx >= scratch_buffers_count_) { + return nullptr; + } + return scratch_buffers_[buffer_idx]; +} + +TfLiteStatus FakeContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx, + size_t bytes, + void** ptr) { + return reinterpret_cast(ctx->impl_) + ->allocator_->AllocatePersistentBuffer(bytes, ptr); +} + +TfLiteStatus FakeContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx, + size_t bytes, + int* buffer_idx) { + FakeContextHelper* helper = reinterpret_cast(ctx->impl_); + // FakeAllocator doesn't do memory reusing so it doesn't need node_idx to + // calculate the lifetime of the scratch buffer. + int node_idx = -1; + return helper->allocator_->RequestScratchBufferInArena(node_idx, bytes, + buffer_idx); +} + +void* FakeContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) { + return reinterpret_cast(ctx->impl_) + ->allocator_->GetScratchBuffer(buffer_idx); +} + +void FakeContextHelper::ReportOpError(struct TfLiteContext* context, + const char* format, ...) { + FakeContextHelper* helper = static_cast(context->impl_); + va_list args; + va_start(args, format); + TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args); + va_end(args); +} + +namespace { +constexpr size_t kArenaSize = 10000; +constexpr int kMaxScratchBufferCount = 32; +uint8_t arena[kArenaSize]; +} // namespace + // TODO(b/141330728): Move this method elsewhere as part clean up. void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context) { + // This should be a large enough arena for each test cases. + static FakeAllocator allocator(arena, kArenaSize, kMaxScratchBufferCount); + static FakeContextHelper helper(error_reporter, &allocator); + // Reset the allocator so that it's ready for another test. + allocator.Reset(); + + *context = {}; + context->recommended_num_threads = 1; context->tensors_size = tensors_size; context->tensors = tensors; - context->impl_ = static_cast(error_reporter); - context->GetExecutionPlan = nullptr; - context->ResizeTensor = nullptr; - context->ReportError = ReportOpError; - context->AddTensors = nullptr; - context->GetNodeAndRegistration = nullptr; - context->ReplaceNodeSubsetsWithDelegateKernels = nullptr; - context->recommended_num_threads = 1; - context->GetExternalContext = nullptr; - context->SetExternalContext = nullptr; + context->impl_ = static_cast(&helper); + context->AllocatePersistentBuffer = helper.AllocatePersistentBuffer; + context->RequestScratchBufferInArena = helper.RequestScratchBufferInArena; + context->GetScratchBuffer = helper.GetScratchBuffer; + context->ReportError = helper.ReportOpError; for (int i = 0; i < tensors_size; ++i) { if (context->tensors[i].is_variable) { diff --git a/tensorflow/lite/micro/testing/test_utils.h b/tensorflow/lite/micro/testing/test_utils.h index 7aa1e9d488f..f7f5dff6bb1 100644 --- a/tensorflow/lite/micro/testing/test_utils.h +++ b/tensorflow/lite/micro/testing/test_utils.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/tensor_utils.h" #include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -95,7 +96,67 @@ inline int32_t F2Q32(const float value, const float scale) { return static_cast(quantized); } -// TODO(b/141330728): Move this method elsewhere as part clean up. +// A fake version of MemoryAllocator that allocates everything from the tail +// without static memory planning or reusing. +// TODO(b/150260678): Consider splitting this into its own file and inherit from +// the same public interface as MicroAllocator. +class FakeAllocator { + public: + FakeAllocator(uint8_t* arena, size_t arena_size, + size_t max_scratch_buffers_count) + : arena_(arena), + arena_size_(arena_size), + max_scratch_buffers_count_(max_scratch_buffers_count) { + Reset(); + } + + TfLiteStatus AllocatePersistentBuffer(size_t bytes, void** ptr); + TfLiteStatus RequestScratchBufferInArena(int node_idx, size_t bytes, + int* buffer_idx); + void* GetScratchBuffer(int buffer_idx); + + // Reset the allocator to the intial state. + void Reset(); + + private: + uint8_t* arena_; + size_t arena_size_; + size_t max_scratch_buffers_count_; + + SimpleMemoryAllocator* memory_allocator_; + // An array of buffer pointers. + uint8_t** scratch_buffers_; + size_t scratch_buffers_count_ = 0; + static constexpr size_t kBufferAlignment = 16; +}; + +// A fake implementation of ContextHelper. Instead of forwarding requests to +// MicroAllocator, it calls into FakeAllocator. +// PopulateContext will point context->impl_ to an instance of this class. +// TODO(b/150260678): Consider moving this into the same file as FakeAllocator. +class FakeContextHelper { + public: + explicit FakeContextHelper(ErrorReporter* error_reporter, + FakeAllocator* allocator) + : allocator_(allocator), error_reporter_(error_reporter) {} + + static TfLiteStatus AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes, + void** ptr); + + static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx, + size_t bytes, + int* buffer_idx); + + static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx); + + static void ReportOpError(struct TfLiteContext* context, const char* format, + ...); + + private: + FakeAllocator* allocator_; + ErrorReporter* error_reporter_; +}; + void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context); From 3dc37733523f2b0df52965ec2e68023a491a39d7 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 19 Mar 2020 06:42:04 -0700 Subject: [PATCH 224/492] Adapt to upstream MLIR SPIR-V target environment changes Fixups for adding DCE to the inliner. Bump tensorflow open source LLVM revision to b72e13c242d9bbe1a4c7e471da98718bde85fa78 PiperOrigin-RevId: 301803044 Change-Id: I5bff3a55f2e6f4609d62d3c684c31ad54c2a3efa --- tensorflow/workspace.bzl | 4 ++-- third_party/mlir/BUILD | 18 ------------------ 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 1bdcc0abd4f..9de46e711ba 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -597,8 +597,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. 
- LLVM_COMMIT = "c5b81466c2bcc194e5563f39f5be3638760b4849" - LLVM_SHA256 = "f623a7e9585e76831abc967547dfbcd5a6ecd148ed5c4e088bdae94dc7d8bda7" + LLVM_COMMIT = "b72e13c242d9bbe1a4c7e471da98718bde85fa78" + LLVM_SHA256 = "d7823d08ac835f5ca587aee8e252ffdedfac5e72b62defa1fffd214b9c649841" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index e9fef46c4df..dbe0b53002e 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -1144,23 +1144,6 @@ gentbl( ], ) -gentbl( - name = "StandardToSPIRVGen", - strip_include_prefix = "lib/Conversion/StandardToSPIRV", - tbl_outs = [ - ( - "-gen-rewriters", - "lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp.inc", - ), - ], - tblgen = ":mlir-tblgen", - td_file = "lib/Conversion/StandardToSPIRV/StandardToSPIRV.td", - td_srcs = [ - ":SPIRVOpsTdFiles", - ":StdOpsTdFiles", - ], -) - gentbl( name = "SPIRVAvailabilityIncGen", strip_include_prefix = "include", @@ -1327,7 +1310,6 @@ cc_library( ":SPIRVDialect", ":SPIRVLowering", ":StandardOps", - ":StandardToSPIRVGen", ":Support", ":Transforms", "@llvm-project//llvm:support", From 6bc6835154b42cb81e6b8e325368402fa8c3590f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 07:46:06 -0700 Subject: [PATCH 225/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301812448 Change-Id: Ia9dd27d33285b246f85f5f69f37e9587587af740 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 3d05bb08fa3..7be0c66548c 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11815,7 +11815,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12072,7 +12072,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12083,7 +12083,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12301,7 +12301,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12312,7 +12312,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19153,7 +19153,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20224,7 +20224,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21396,7 +21396,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22104,7 +22104,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22300,7 +22300,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22484,7 +22484,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22543,7 +22543,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22717,7 +22717,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23098,7 +23098,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25441,7 +25441,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25504,7 +25504,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25747,7 +25747,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26370,7 +26370,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45499,7 +45499,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46287,7 +46287,7 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46350,7 +46350,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 0cfab2b1fa250d5460136580000eaf15f55927cb Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Thu, 19 Mar 2020 08:44:54 -0700 Subject: [PATCH 226/492] Private implementation of Adam using XLA for fusion. Because we can rely on tf.function I also refactored and cleaned up the updates a bit; removing stray control dependencies and using methods to update variables and avoiding chaining assignment operations. 
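Roughly, the pattern is the one sketched below (an illustrative sketch only, not the
code added in this change: the names are placeholders, and bias correction and
amsgrad handling are left out):

    import tensorflow as tf

    @tf.function(experimental_compile=True)  # XLA fuses the element-wise ops
    def dense_adam_step(var, m, v, grad, lr, beta_1, beta_2, epsilon):
      # Sketch only: bias correction and amsgrad are omitted.
      # Variable methods express the update directly, so no explicit
      # control dependencies or chained assign ops are needed.
      m.assign_add((grad - m) * (1.0 - beta_1))
      v.assign_add((tf.square(grad) - v) * (1.0 - beta_2))
      var.assign_sub(lr * m / (tf.sqrt(v) + epsilon))

Compiling each per-variable update as a single function gives XLA one fusion
region per apply step.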
PiperOrigin-RevId: 301822384 Change-Id: If4cb54e3d7b27c916912d39e5a01c1ff7905b4ba --- tensorflow/python/keras/optimizer_v2/adam.py | 224 +++++++++ .../python/keras/optimizer_v2/adam_test.py | 430 ++++++++++++++++++ 2 files changed, 654 insertions(+) diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py index 6783d9324f6..94eab7db6eb 100644 --- a/tensorflow/python/keras/optimizer_v2/adam.py +++ b/tensorflow/python/keras/optimizer_v2/adam.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import def_function from tensorflow.python.framework import ops from tensorflow.python.keras import backend_config from tensorflow.python.keras.optimizer_v2 import optimizer_v2 @@ -278,3 +279,226 @@ class Adam(optimizer_v2.OptimizerV2): 'amsgrad': self.amsgrad, }) return config + + +class NonFusedAdam(optimizer_v2.OptimizerV2): + r"""Optimizer that implements the Adam algorithm without fused kernels. + + Adam optimization is a stochastic gradient descent method that is based on + adaptive estimation of first-order and second-order moments. + According to the paper + [Adam: A Method for Stochastic Optimization. Kingma et al., + 2014](http://arxiv.org/abs/1412.6980), the method is "*computationally + efficient, has little memory requirement, invariant to diagonal rescaling of + gradients, and is well suited for problems that are large in terms of + data/parameters*". + + For AMSGrad see [On The Convergence Of Adam And Beyond. + Reddi et al., 5-8](https://openreview.net/pdf?id=ryQu7f-RZ). + + **If amsgrad = False**: + + initialize $m_0$ as 1st moment vector + initialize $v_0$ as 2nd moment vector + + The update rule for $\theta$ with gradient $g$ uses an optimization + described at the end of section 2 of the paper: + + $$lr_t = \mathrm{learning\_rate} * + \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$ + $$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ + $$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$ + $$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + + **If amsgrad = True**: + + initialize $m_0$ as 1st moment vector + initialize $v_0$ as 2nd moment vector + initialize $\hat{v}_0$ as 2nd moment vector + + The update rule for $\theta$ with gradient $g$ uses an optimization + described at the end of section 2 of the paper: + + $$lr_t = \mathrm{learning\_rate} * + \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$ + + $$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ + $$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$ + $$\hat{v}_t = \max(\hat{v}_{t-1}, v_t)$$ + $$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$ + + The default value of 1e-7 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. + + The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. 
This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + + Usage: + + >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1) + >>> var1 = tf.Variable(10.0) + >>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) == var1 + >>> step_count = opt.minimize(loss, [var1]).numpy() + >>> # The first step is `-learning_rate*sign(grad)` + >>> var1.numpy() + 9.9 + """ + + _HAS_ALL_REDUCE_SUM_GRAD = True + + def __init__(self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + amsgrad=False, + name='Adam', + **kwargs): + """Construct a new Adam optimizer. + + Args: + learning_rate: A `Tensor`, floating point value, or a schedule that is a + `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable that + takes no arguments and returns the actual value to use, The learning + rate. Defaults to 0.001. + beta_1: A float value or a constant float tensor, or a callable that takes + no arguments and returns the actual value to use. The exponential decay + rate for the 1st moment estimates. Defaults to 0.9. + beta_2: A float value or a constant float tensor, or a callable that takes + no arguments and returns the actual value to use, The exponential decay + rate for the 2nd moment estimates. Defaults to 0.999. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to + 1e-7. + amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from + the paper "On the Convergence of Adam and beyond". Defaults to `False`. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, + `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip + gradients by value, `decay` is included for backward compatibility to + allow time inverse decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. + """ + + super(NonFusedAdam, self).__init__(name, **kwargs) + self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) + self._set_hyper('decay', self._initial_decay) + self._set_hyper('beta_1', beta_1) + self._set_hyper('beta_2', beta_2) + self.epsilon = epsilon or backend_config.epsilon() + self.amsgrad = amsgrad + + def _create_slots(self, var_list): + # Create slots for the first and second moments. + # Separate for-loops to respect the ordering of slot variables from v1. 
+ for var in var_list: + self.add_slot(var, 'm') + for var in var_list: + self.add_slot(var, 'v') + if self.amsgrad: + for var in var_list: + self.add_slot(var, 'vhat') + + def _prepare_local(self, var_device, var_dtype, apply_state): + super(NonFusedAdam, self)._prepare_local(var_device, var_dtype, apply_state) + + local_step = math_ops.cast(self.iterations + 1, var_dtype) + beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype)) + beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype)) + beta_1_power = math_ops.pow(beta_1_t, local_step) + beta_2_power = math_ops.pow(beta_2_t, local_step) + lr = ( + apply_state[(var_device, var_dtype)]['lr_t'] * + (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))) + apply_state[(var_device, var_dtype)].update( + dict( + lr=lr, + epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), + beta_1_t=beta_1_t, + beta_1_power=beta_1_power, + one_minus_beta_1_t=1 - beta_1_t, + beta_2_t=beta_2_t, + beta_2_power=beta_2_power, + one_minus_beta_2_t=1 - beta_2_t)) + + def set_weights(self, weights): + params = self.weights + # If the weights are generated by Keras V1 optimizer, it includes vhats + # even without amsgrad, i.e, V1 optimizer has 3x + 1 variables, while V2 + # optimizer has 2x + 1 variables. Filter vhats out for compatibility. + num_vars = int((len(params) - 1) / 2) + if len(weights) == 3 * num_vars + 1: + weights = weights[:len(params)] + super(NonFusedAdam, self).set_weights(weights) + + @def_function.function(experimental_compile=True) + def _resource_apply_dense(self, grad, var, apply_state=None): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = ((apply_state or {}).get((var_device, var_dtype)) or + self._fallback_apply_state(var_device, var_dtype)) + + m = self.get_slot(var, 'm') + v = self.get_slot(var, 'v') + + alpha = ( + coefficients['lr_t'] * math_ops.sqrt(1 - coefficients['beta_2_power']) / + (1 - coefficients['beta_1_power'])) + m.assign_add((grad - m) * (1 - coefficients['beta_1_t'])) + v.assign_add((math_ops.square(grad) - v) * (1 - coefficients['beta_2_t'])) + if self.amsgrad: + vhat = self.get_slot(var, 'vhat') + vhat.assign(math_ops.maximum(vhat, v)) + v = vhat + var.assign_sub( + (m * alpha) / (math_ops.sqrt(v) - coefficients['epsilon'])) + + @def_function.function(experimental_compile=True) + def _resource_apply_sparse(self, grad, var, indices, apply_state=None): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = ((apply_state or {}).get((var_device, var_dtype)) or + self._fallback_apply_state(var_device, var_dtype)) + + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, 'm') + m_scaled_g_values = grad * coefficients['one_minus_beta_1_t'] + m.assign(m * coefficients['beta_1_t']) + m.scatter_add(ops.IndexedSlices(m_scaled_g_values, indices)) + + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, 'v') + v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t'] + v.assign(v * coefficients['beta_2_t']) + v.scatter_add(ops.IndexedSlices(v_scaled_g_values, indices)) + + if not self.amsgrad: + var.assign_sub(coefficients['lr'] * m / + (math_ops.sqrt(v) + coefficients['epsilon'])) + else: + v_hat = self.get_slot(var, 'vhat') + v_hat.assign(math_ops.maximum(v_hat, v)) + var.assign_sub(coefficients['lr'] * m / + (math_ops.sqrt(v_hat) + coefficients['epsilon'])) + + def get_config(self): + config = super(NonFusedAdam, self).get_config() + config.update({ + 'learning_rate': self._serialize_hyperparameter('learning_rate'), 
+ 'decay': self._serialize_hyperparameter('decay'), + 'beta_1': self._serialize_hyperparameter('beta_1'), + 'beta_2': self._serialize_hyperparameter('beta_2'), + 'epsilon': self.epsilon, + 'amsgrad': self.amsgrad, + }) + return config diff --git a/tensorflow/python/keras/optimizer_v2/adam_test.py b/tensorflow/python/keras/optimizer_v2/adam_test.py index 83ffc87d792..0b1ae51b08c 100644 --- a/tensorflow/python/keras/optimizer_v2/adam_test.py +++ b/tensorflow/python/keras/optimizer_v2/adam_test.py @@ -569,5 +569,435 @@ class AdamOptimizerTest(test.TestCase, parameterized.TestCase): self.assertAllClose(self.evaluate(opt_3.lr), (0.1)) +class NonFusedAdamOptimizerTest(test.TestCase, parameterized.TestCase): + + def testSparse(self): + # TODO(tanzheny, omalleyt): Fix test in eager mode. + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with ops.Graph().as_default(), self.cached_session(use_gpu=True): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0_np_indices = np.array([0, 2], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np[grads0_np_indices]), + constant_op.constant(grads0_np_indices), constant_op.constant([3])) + grads1_np_indices = np.array([0, 2], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np[grads1_np_indices]), + constant_op.constant(grads1_np_indices), constant_op.constant([3])) + opt = adam.NonFusedAdam() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 3.0, 4.0], self.evaluate(var1)) + + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + # Run 3 steps of NonFusedAdam + for t in range(3): + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSparseDevicePlacement(self): + # TODO(tanzheny, omalleyt): Fix test in eager mode. + for index_dtype in [dtypes.int32, dtypes.int64]: + with ops.Graph().as_default(), self.cached_session( + force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) + g_sum = lambda: math_ops.reduce_sum(array_ops.gather(var, indices)) # pylint: disable=cell-var-from-loop + optimizer = adam.NonFusedAdam(3.0) + minimize_op = optimizer.minimize(g_sum, var_list=[var]) + variables.global_variables_initializer().run() + minimize_op.run() + + def testSparseRepeatedIndices(self): + # TODO(tanzheny, omalleyt): Fix test in eager mode. 
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with ops.Graph().as_default(), self.cached_session(): + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update = adam.NonFusedAdam().apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + aggregated_update = adam.NonFusedAdam().apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + + def doTestBasic(self, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.cached_session(use_gpu=True): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = adam.NonFusedAdam(learning_rate=learning_rate) + if not context.executing_eagerly(): + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.evaluate(variables.global_variables_initializer()) + # Run 3 steps of NonFusedAdam + for t in range(3): + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) + if not context.executing_eagerly(): + self.evaluate(update) + else: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType( + var0_np, self.evaluate(var0), rtol=1e-4, atol=1e-4) + self.assertAllCloseAccordingToType( + var1_np, self.evaluate(var1), rtol=1e-4, atol=1e-4) + + @combinations.generate(combinations.combine(mode=["graph", "eager"])) + def testResourceBasic(self): + self.doTestBasic() + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_callable_params=True) + + @combinations.generate(combinations.combine(mode=["graph", "eager"])) + def testBasicWithAmsgrad(self): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with 
self.cached_session(use_gpu=True): + # Initialize variables for numpy implementation. + m0, v0, v0hat, m1, v1, v1hat = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = adam.NonFusedAdam(amsgrad=True) + if not context.executing_eagerly(): + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.evaluate(variables.global_variables_initializer()) + # Run 3 steps of NonFusedAdam + for t in range(3): + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) + if not context.executing_eagerly(): + self.evaluate(update) + else: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0, v0hat = adam_update_numpy_amsgrad( + var0_np, grads0_np, t, m0, v0, v0hat) + var1_np, m1, v1, v1hat = adam_update_numpy_amsgrad( + var1_np, grads1_np, t, m1, v1, v1hat) + + # Validate updated params + self.assertAllCloseAccordingToType( + var0_np, self.evaluate(var0), rtol=1e-4, atol=1e-4) + self.assertAllCloseAccordingToType( + var1_np, self.evaluate(var1), rtol=1e-4, atol=1e-4) + + @combinations.generate(combinations.combine(mode=["graph", "eager"])) + def testSparseWithAmsgrad(self): + # dtypes.half does not work on gpu + eager. 
+ for dtype in [dtypes.float32, dtypes.float64]: + with self.cached_session(): + m0 = np.array([[0.0], [0.0]]) + v0 = np.array([[0.0], [0.0]]) + v0hat = np.array([[0.0], [0.0]]) + indices_np = np.array([1]) + indices = constant_op.constant(indices_np, dtype=dtypes.int32) + var0_np = np.array([[1.0], [2.0]], dtype=dtype.as_numpy_dtype) + repeated_index_update_var = variables.Variable(var0_np, dtype=dtype) + aggregated_update_var = variables.Variable(var0_np, dtype=dtype) + grads0_np = np.array([[0.2]], dtype=dtype.as_numpy_dtype) + grad_repeated_index = ops.IndexedSlices( + constant_op.constant([0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices(grads0_np, indices, + constant_op.constant([2, 1])) + opt_repeated = adam.NonFusedAdam(amsgrad=True) + opt_aggregated = adam.NonFusedAdam(amsgrad=True) + if not context.executing_eagerly(): + repeated_update = opt_repeated.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + aggregated_update = opt_aggregated.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + self.evaluate(variables.global_variables_initializer()) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) + for t in range(3): + if not context.executing_eagerly(): + self.evaluate(repeated_update) + self.evaluate(aggregated_update) + else: + opt_repeated.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + opt_aggregated.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + + var0_np, m0, v0, v0hat = adam_sparse_update_numpy_amsgrad( + var0_np, indices_np, grads0_np, t, m0, v0, v0hat) + + # Validate updated params + self.assertAllCloseAccordingToType( + var0_np, self.evaluate(aggregated_update_var)) + self.assertAllCloseAccordingToType( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) + + def testBasicWithLearningRateDecay(self): + # TODO(tanzheny, omalleyt): Fix test in eager mode. + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with ops.Graph().as_default(), self.cached_session(use_gpu=True): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = 0.001 + beta_1 = 0.9 + beta_2 = 0.999 + epsilon = 1e-7 + decay = 0.5 + + opt = adam.NonFusedAdam( + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + decay=decay) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.evaluate(variables.global_variables_initializer()) + # Run 3 steps of NonFusedAdam + for t in range(3): + self.evaluate(update) + lr_np = learning_rate / (1 + decay * t) + + var0_np, m0, v0 = adam_update_numpy( + var0_np, grads0_np, t, m0, v0, lr=lr_np) + var1_np, m1, v1 = adam_update_numpy( + var1_np, grads1_np, t, m1, v1, lr=lr_np) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testBasicWithLearningRateInverseTimeDecay(self): + # TODO(tanzheny, omalleyt): Fix test in eager mode. + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with ops.Graph().as_default(), self.cached_session(use_gpu=True): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = 0.001 + decay = 0.5 + lr_schedule = learning_rate_schedule.InverseTimeDecay( + learning_rate, decay_steps=1.0, decay_rate=decay) + beta_1 = 0.9 + beta_2 = 0.999 + epsilon = 1e-7 + + opt = adam.NonFusedAdam( + learning_rate=lr_schedule, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.evaluate(variables.global_variables_initializer()) + # Run 3 steps of NonFusedAdam + for t in range(3): + self.evaluate(update) + + lr_np = learning_rate / (1 + decay * t) + + var0_np, m0, v0 = adam_update_numpy( + var0_np, grads0_np, t, m0, v0, lr=lr_np) + var1_np, m1, v1 = adam_update_numpy( + var1_np, grads1_np, t, m1, v1, lr=lr_np) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testTensorLearningRate(self): + # TODO(tanzheny, omalleyt): Fix test in eager mode. + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with ops.Graph().as_default(), self.cached_session(use_gpu=True): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam.NonFusedAdam(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + # Run 3 steps of NonFusedAdam + for t in range(3): + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSharing(self): + # TODO(tanzheny, omalleyt): Fix test in eager mode. + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with ops.Graph().as_default(), self.cached_session(use_gpu=True): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam.NonFusedAdam() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of intertwined NonFusedAdam1 and NonFusedAdam2. + for t in range(3): + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + if __name__ == "__main__": test.main() From 2739b46ebdbd09e22deffcf9b05e6b7447c9d8ae Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Thu, 19 Mar 2020 08:52:01 -0700 Subject: [PATCH 227/492] Add 'xla_hlo.sharding' attributes to function arguments and results. 
By having these as function arguments and results, when converting from MLIR HLO to HLO proto and creating arg/root tuples, these can be looked up and then added to the tuple as tuple shardings. PiperOrigin-RevId: 301823654 Change-Id: I6aac771705703830a1dcaa3f9d50f44e46ea5bbd --- .../tests/tpu_sharding_identification.mlir | 22 ++++ .../tpu_sharding_identification_pass.cc | 119 +++++++++--------- 2 files changed, 85 insertions(+), 56 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index 87eb02eda94..17180490270 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -10,6 +10,7 @@ func @check_sharding_attrs_exists_for_empty_launch_func() { return } +// CHECK-LABEL: func @empty_func() { func @empty_func() { return } @@ -28,6 +29,9 @@ func @check_default_sharding_for_block_arg_inputs_outputs(%arg0: tensor<*xi32>) return } +// CHECK-LABEL: func @func_without_sharding +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) +// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { return %arg0 : tensor<*xi32> } @@ -46,6 +50,9 @@ func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) { return } +// CHECK-LABEL: func @func_without_sharding +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) +// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { %0 = "tf.A"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> @@ -64,6 +71,9 @@ func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) { return } +// CHECK-LABEL: func @inputs_with_sharding_func +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}) +// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> { %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>) @@ -83,6 +93,9 @@ func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: te return } +// CHECK-LABEL: func @func_with_sharding +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) { %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> %1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1> @@ -105,6 +118,9 @@ func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { return } +// CHECK-LABEL: func @func_with_sharding_after_identity +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) func 
@func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) { %0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> @@ -128,6 +144,9 @@ func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi return } +// CHECK-LABEL: func @func_with_sharding_after_read_variable +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*x!tf.resource>> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*x!tf.resource>> {xla_hlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) -> (tensor<*xi32>, tensor<*xi1>) { %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32> @@ -153,6 +172,9 @@ func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { return } +// CHECK-LABEL: func @func_with_sharding_after_cast +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) { %0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> %1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index c9838ff9651..a88675f1557 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -13,28 +13,32 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/UseDefLists.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace mlir { namespace TFTPU { namespace { +constexpr char kShardingAttr[] = "xla_hlo.sharding"; + struct TPUShardingIdentificationPass : public ModulePass { void runOnModule() override; @@ -71,9 +75,9 @@ void GetAdjacentToXlaShardingOp( // // TODO(hongjunchoi): Add logic to parse XlaSharding op inside a // Call op or if/while op. -llvm::Optional ParseInputSharding(const FuncOp func, - const int arg_index, - const Value& arg) { +llvm::Optional ParseInputSharding(const FuncOp func, + const int arg_index, + const Value& arg) { llvm::Optional parsed_sharding_op; for (auto user : arg.getUsers()) { if (parsed_sharding_op) continue; @@ -86,7 +90,7 @@ llvm::Optional ParseInputSharding(const FuncOp func, GetAdjacentToXlaShardingOp(read_variable_user, &parsed_sharding_op); } - if (!parsed_sharding_op) return llvm::Optional(); + if (!parsed_sharding_op) return llvm::Optional(); return tensorflow::ParseShardingAttribute(parsed_sharding_op->getOperation()); } @@ -103,20 +107,10 @@ llvm::Optional ParseReturnValueSharding(FuncOp func, return llvm::Optional(); } -// Add parsed sharding configuration to tf_device.LaunchFunc op attribute. -void SetShardingConfigurationAsAttribute( - tf_device::LaunchFuncOp launch_func, const std::string& attr_name, - const llvm::SmallVector& sharding_config) { - auto input_sharding_array_ref = llvm::SmallVector( - sharding_config.begin(), sharding_config.end()); - launch_func.setAttr(attr_name, - mlir::Builder(launch_func.getContext()) - .getStrArrayAttr(input_sharding_array_ref)); -} - // If XlaSharding op is connected to input/output of the tf_device.LaunchFuncOp, // then add attributes to the op specifying the sharding configurations. -void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) { +void IdentifyXlaShardingForTPUComputation(Builder* builder, + tf_device::LaunchFuncOp launch_func) { // Look up function definition from module. FuncOp func = launch_func.getParentOfType().lookupSymbol( launch_func.func()); @@ -124,55 +118,68 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) { // By default inputs have maximal sharding and inputs are assigned to // logical core 0 if no sharding is defined. 
- llvm::SmallVector sharding_for_args( - func_entry_block.getNumArguments(), - xla::sharding_builder::AssignDevice(0).SerializeAsString()); + const std::string logical_core_0_sharding = + xla::sharding_builder::AssignDevice(0).SerializeAsString(); + auto logical_core_0_sharding_attr = + builder->getStringAttr(logical_core_0_sharding); + + llvm::SmallVector sharding_for_args( + func_entry_block.getNumArguments(), logical_core_0_sharding); // Iterate through input arguments to the entry block of tf_device.LaunchFunc. - // For input ops, look for following XlaSharding ops. XlaSharding ops can - // 1) Directly follow the input argument if input argument has non-resource - // types. - // 2) Follow ReadVariableOp if the input type is of resource type. - // 3) Follow IdentityOp or CastOp after above cases (1), (2). - for (auto& arg_index_and_value : - llvm::enumerate(func_entry_block.getArguments())) { - const int arg_index = arg_index_and_value.index(); - auto& arg = arg_index_and_value.value(); - auto input_arg_sharding = ParseInputSharding(func, arg_index, arg); + // For input ops, look for following XlaSharding ops. XlaSharding ops can: + // 1) Directly follow the input argument if input argument has non-resource + // types. + // 2) Follow ReadVariableOp if the input type is of resource type. + // 3) Follow IdentityOp or CastOp after above cases (1), (2). + // + // Sharding configurations are added to the tf_device.LaunchFunc as an + // attribute and the function as an argument attribute. + for (auto& arg : func_entry_block.getArguments()) { + const int index = arg.getArgNumber(); + auto arg_sharding = ParseInputSharding(func, index, arg); - if (!input_arg_sharding.hasValue()) continue; - sharding_for_args[arg_index] = input_arg_sharding->str(); + if (arg_sharding) { + sharding_for_args[index] = arg_sharding.getValue(); + func.setArgAttr(index, kShardingAttr, + builder->getStringAttr(arg_sharding.getValue())); + } else { + func.setArgAttr(index, kShardingAttr, logical_core_0_sharding_attr); + } } - SetShardingConfigurationAsAttribute( - launch_func, tensorflow::kInputShardingAttr, sharding_for_args); + launch_func.setAttr(tensorflow::kInputShardingAttr, + builder->getStrArrayAttr(sharding_for_args)); // By default return values from logical core 0 is used if no sharding // configuration is defined. - llvm::SmallVector sharding_for_return_values( - func_entry_block.getTerminator()->getNumOperands(), - xla::sharding_builder::AssignDevice(0).SerializeAsString()); + Operation* terminator = func_entry_block.getTerminator(); + llvm::SmallVector sharding_for_rets( + terminator->getNumOperands(), logical_core_0_sharding); - // Iterate through operands of the terminator, if the preceding op is - // XlaShardingOp, then add provided sharding configuration to launch func + // Iterate through operands of the terminator. If the preceding op is + // XlaShardingOp, then the provided sharding configuration is added to the + // tf_device.LaunchFunc as an attribute and the function as a result // attribute. 
- for (auto& return_value_and_index : - llvm::enumerate(func_entry_block.getTerminator()->getOpOperands())) { - int return_value_index = return_value_and_index.index(); - const auto& return_value = return_value_and_index.value(); - auto return_val_sharding = - ParseReturnValueSharding(func, return_value_index, return_value); + for (auto& ret : terminator->getOpOperands()) { + const int index = ret.getOperandNumber(); + auto ret_sharding = ParseReturnValueSharding(func, index, ret); - if (return_val_sharding) - sharding_for_return_values[return_value_index] = - return_val_sharding->str(); + if (ret_sharding) { + sharding_for_rets[index] = ret_sharding.getValue(); + func.setResultAttr(index, kShardingAttr, + builder->getStringAttr(ret_sharding.getValue())); + } else { + func.setResultAttr(index, kShardingAttr, logical_core_0_sharding_attr); + } } - SetShardingConfigurationAsAttribute( - launch_func, tensorflow::kOutputShardingAttr, sharding_for_return_values); + launch_func.setAttr(tensorflow::kOutputShardingAttr, + builder->getStrArrayAttr(sharding_for_rets)); } void TPUShardingIdentificationPass::runOnModule() { + Builder builder(getModule().getContext()); getModule().walk([&](tf_device::LaunchFuncOp launch_func) { - IdentifyXlaShardingForTPUComputation(launch_func); + IdentifyXlaShardingForTPUComputation(&builder, launch_func); }); } From 5c90d76d41605c895370e0bc475429dc62d51bd2 Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Thu, 19 Mar 2020 16:56:51 +0100 Subject: [PATCH 228/492] Add -lrt to linkflags Fixes compilation on e.g. older CentOS, see for details #15129 --- tensorflow/tensorflow.bzl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 390acacefe8..d10650479d6 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -537,7 +537,7 @@ def tf_cc_shared_object( srcs = [], deps = [], data = [], - linkopts = [], + linkopts = if_not_windows(["-lrt"]), framework_so = tf_binary_additional_srcs(), soversion = None, kernels = [], @@ -641,7 +641,7 @@ def tf_cc_binary( srcs = [], deps = [], data = [], - linkopts = [], + linkopts = if_not_windows(["-lrt"]), copts = tf_copts(), kernels = [], per_os_targets = False, # Generate targets with SHARED_LIBRARY_NAME_PATTERNS @@ -737,7 +737,7 @@ def tf_gen_op_wrapper_cc( tf_cc_binary( name = tool, copts = tf_copts(), - linkopts = if_not_windows(["-lm", "-Wl,-ldl"]), + linkopts = if_not_windows(["-lm", "-Wl,-ldl", "-lrt"]), linkstatic = 1, # Faster to link this one-time-use binary dynamically deps = [op_gen] + deps, ) @@ -924,7 +924,7 @@ def tf_gen_op_wrapper_py( tf_cc_binary( name = tool_name, copts = tf_copts(), - linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + cc_linkopts, + linkopts = if_not_windows(["-lm", "-Wl,-ldl", "-lrt"]) + cc_linkopts, linkstatic = 1, # Faster to link this one-time-use binary dynamically visibility = [clean_dep("//tensorflow:internal")], deps = ([ @@ -1221,7 +1221,7 @@ def tf_cc_tests( tags = [], size = "medium", args = None, - linkopts = [], + linkopts = if_not_windows(["-lrt"]), kernels = [], create_named_test_suite = False, visibility = None): From 0a0600a8b2f4528a2a4f361db6d92f2d21fed168 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 19 Mar 2020 08:58:39 -0700 Subject: [PATCH 229/492] Bump tensorflow open source LLVM revision to 4a7f2032a350bc7eefd26709563f65216df3e2ce PiperOrigin-RevId: 301825006 Change-Id: I7156efcf76d12743bf78b7f3b06c1173d3e8959a --- tensorflow/workspace.bzl | 4 ++-- 
third_party/mlir/BUILD | 23 +++++++++++++++++++++++ third_party/mlir/test.BUILD | 1 + 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 9de46e711ba..10f5eb3dd63 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -597,8 +597,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "b72e13c242d9bbe1a4c7e471da98718bde85fa78" - LLVM_SHA256 = "d7823d08ac835f5ca587aee8e252ffdedfac5e72b62defa1fffd214b9c649841" + LLVM_COMMIT = "4a7f2032a350bc7eefd26709563f65216df3e2ce" + LLVM_SHA256 = "e43e9067427a331542733d5863b2e94369ed95b59af9999dcabdd5315ff46373" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index dbe0b53002e..fbbdb73cecc 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -1316,6 +1316,27 @@ cc_library( ], ) +cc_library( + name = "StandardToStandard", + srcs = glob([ + "lib/Conversion/StandardToStandard/*.cpp", + "lib/Conversion/StandardToStandard/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/StandardToStandard/*.h", + ]), + includes = [ + "include", + "lib/Conversion/StandardToStandard", + ], + deps = [ + ":IR", + ":Pass", + ":StandardOps", + ":Transforms", + ], +) + cc_library( name = "SPIRVSerialization", srcs = glob( @@ -1943,6 +1964,7 @@ cc_library( ":Pass", ":QuantizerTransforms", ":StandardToSPIRVConversions", + ":StandardToStandard", ":Support", ":Transforms", ":VectorToLLVM", @@ -2035,6 +2057,7 @@ cc_library( ":Shape", ":StandardOps", ":StandardToSPIRVConversions", + ":StandardToStandard", ":Transforms", ":VectorOps", ], diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD index a27c7c9f9f3..3ba2bbf0fb1 100644 --- a/third_party/mlir/test.BUILD +++ b/third_party/mlir/test.BUILD @@ -111,6 +111,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:StandardToStandard", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], From 7b649ad65c50e692ffc125495d6cfae6470757b4 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Thu, 19 Mar 2020 09:18:38 -0700 Subject: [PATCH 230/492] Disable tsan test for multi_worker_continuous_run_test. 
PiperOrigin-RevId: 301829016 Change-Id: I243e024c136db19e5c653e9909493fb256ff62a3 --- tensorflow/python/distribute/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 459cfb6b1bf..0667913bc76 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -414,6 +414,9 @@ cuda_py_test( srcs = [ "multi_worker_continuous_run_test.py", ], + tags = [ + "notsan", # TODO(b/151841995) + ], deps = [ ":collective_all_reduce_strategy", ":multi_process_runner", From ce3925cf580435b09ee7cc4156e998dba3a8d2d9 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 19 Mar 2020 09:18:57 -0700 Subject: [PATCH 231/492] Remove test to un-red nightly builds as suggested in b/151378056 PiperOrigin-RevId: 301829065 Change-Id: I88a9d159faa9169de2854e69cc116af29943c74e --- .../compatibility/testdata/test_file_v1_12.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_12.py b/tensorflow/tools/compatibility/testdata/test_file_v1_12.py index 36176e7d568..c74e076e4c3 100644 --- a/tensorflow/tools/compatibility/testdata/test_file_v1_12.py +++ b/tensorflow/tools/compatibility/testdata/test_file_v1_12.py @@ -80,31 +80,6 @@ class TestUpgrade(test_util.TensorFlowTestCase): logits=[0.1, 0.8], labels=[0, 1]) self.assertAllClose(out, 0.40318608) - def testLinearClassifier(self): - if _TEST_VERSION == 2 and self._tf_api_version == 1: - # Skip if we converted this file to v2 but running with tf v1. - # In this case, conversion script adds reference to - # tf.keras.losses.Reduction which is not available in v1. - self.skipTest( - 'After converting to 2.0, this test does not work with ' - 'TensorFlow 1.x.') - return - feature_column = tf.feature_column.numeric_column( - 'feature', shape=(1,)) - - classifier = tf.estimator.LinearClassifier( - n_classes=2, feature_columns=[feature_column]) - - data = {'feature': [1, 20, 3]} - target = [0, 1, 0] - classifier.train( - input_fn=lambda: (data, target), - steps=100) - scores = classifier.evaluate( - input_fn=lambda: (data, target), - steps=100) - self.assertGreater(scores['accuracy'], 0.99) - def testUniformUnitScalingInitializer(self): init = tf.initializers.uniform_unit_scaling(0.5, seed=1) self.assertArrayNear( From 905ed55a28bb6ff539bb421ae5e979ccd32618f1 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 19 Mar 2020 09:22:24 -0700 Subject: [PATCH 232/492] Update package visibility for all keras sub packages. 
PiperOrigin-RevId: 301829736 Change-Id: If1cd07bd48ff4c2f2ec850cc71460f70718a5dc0 --- tensorflow/python/keras/api/BUILD | 4 +++- tensorflow/python/keras/applications/BUILD | 6 +++++- tensorflow/python/keras/datasets/BUILD | 4 +++- tensorflow/python/keras/distribute/BUILD | 7 ++++++- tensorflow/python/keras/engine/BUILD | 7 ++++++- tensorflow/python/keras/layers/BUILD | 10 +++++++++- tensorflow/python/keras/layers/preprocessing/BUILD | 5 ++++- .../python/keras/mixed_precision/experimental/BUILD | 8 +++++++- tensorflow/python/keras/optimizer_v2/BUILD | 8 +++++++- tensorflow/python/keras/premade/BUILD | 7 ++++--- tensorflow/python/keras/preprocessing/BUILD | 7 ++++++- tensorflow/python/keras/saving/BUILD | 6 +++++- tensorflow/python/keras/tests/BUILD | 4 +++- tensorflow/python/keras/type/BUILD | 2 +- tensorflow/python/keras/utils/BUILD | 6 +++++- tensorflow/python/keras/wrappers/BUILD | 2 +- 16 files changed, 75 insertions(+), 18 deletions(-) diff --git a/tensorflow/python/keras/api/BUILD b/tensorflow/python/keras/api/BUILD index 19ad03a09bf..32c5e87a8f9 100644 --- a/tensorflow/python/keras/api/BUILD +++ b/tensorflow/python/keras/api/BUILD @@ -6,7 +6,9 @@ load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "KERAS_API_IN load("//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "KERAS_API_INIT_FILES_V1") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + "//tensorflow:tensorflow_py", + ], licenses = ["notice"], # Apache 2.0 License ) diff --git a/tensorflow/python/keras/applications/BUILD b/tensorflow/python/keras/applications/BUILD index 39c562c66c3..992010a809e 100644 --- a/tensorflow/python/keras/applications/BUILD +++ b/tensorflow/python/keras/applications/BUILD @@ -4,7 +4,11 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + # Remove this deps to integration test. + "//tensorflow/lite/experimental/tf_runtime:__pkg__", + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/datasets/BUILD b/tensorflow/python/keras/datasets/BUILD index 307ba24fa18..63d9826f5ec 100644 --- a/tensorflow/python/keras/datasets/BUILD +++ b/tensorflow/python/keras/datasets/BUILD @@ -2,7 +2,9 @@ # Contains the Keras datasets package (internal TensorFlow version). package( - default_visibility = ["//visibility:public"], + default_visibility = [ + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 515d201d1bd..7f2339f2ff9 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -7,7 +7,12 @@ load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( - default_visibility = ["//visibility:public"], + # TODO(scottzhu): Remove this deps when distribute test are converted to integration test. 
+ default_visibility = [ + "//tensorflow/python/distribute:__pkg__", + "//tensorflow/python/keras:__subpackages__", + "//tensorflow/tools/pip_package:__pkg__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 54cce9e6486..51666a0f34f 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -5,7 +5,12 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( - default_visibility = ["//visibility:public"], + # TODO(scottzhu): Remove non-keras deps from TF. + default_visibility = [ + "//tensorflow/python:__pkg__", + "//tensorflow/python/feature_column:__pkg__", + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index 1482e747a42..ad0cdb20f44 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -5,7 +5,15 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( - default_visibility = ["//visibility:public"], + # TODO(scottzhu): Remove non-keras deps from TF. + default_visibility = [ + "//tensorflow/python/distribute:__pkg__", + "//tensorflow/python/feature_column:__pkg__", + "//tensorflow/python/keras:__subpackages__", + "//tensorflow/python/kernel_tests:__pkg__", + "//tensorflow/python/training/tracking:__pkg__", + "//tensorflow/tools/pip_package:__pkg__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD index e0dd9114755..288dfa1cf84 100644 --- a/tensorflow/python/keras/layers/preprocessing/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/BUILD @@ -5,7 +5,10 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + "//tensorflow/python/keras:__subpackages__", + "//tensorflow/tools/pip_package:__pkg__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD index c672bd51cc7..09f6970db31 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/BUILD +++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD @@ -20,7 +20,13 @@ load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + # TODO(scottzhu): Remove these two deps and convert the test to integration test. + "//tensorflow/python:__pkg__", # For loss_scale_optimizer_test + "//tensorflow/python/distribute:__pkg__", # For collective_all_reduce_strategy_test + "//tensorflow/python/keras:__subpackages__", + "//tensorflow/tools/pip_package:__pkg__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD index 03d9a0070f4..afdb8bc04b3 100644 --- a/tensorflow/python/keras/optimizer_v2/BUILD +++ b/tensorflow/python/keras/optimizer_v2/BUILD @@ -4,7 +4,13 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( - default_visibility = ["//visibility:public"], + # TODO(scottzhu): Remove non-keras deps from TF. 
+ default_visibility = [ + "//tensorflow/python:__pkg__", + "//tensorflow/python/distribute:__pkg__", + "//tensorflow/python/keras:__subpackages__", + "//tensorflow/python/training/tracking:__pkg__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/premade/BUILD b/tensorflow/python/keras/premade/BUILD index 2892dfbb0fb..8c30cdbe2a6 100644 --- a/tensorflow/python/keras/premade/BUILD +++ b/tensorflow/python/keras/premade/BUILD @@ -1,15 +1,16 @@ # Description: # Contains the Keras Premade Models (internal TensorFlow version). +load("//tensorflow:tensorflow.bzl", "py_test") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") - py_library( name = "premade", srcs = [ diff --git a/tensorflow/python/keras/preprocessing/BUILD b/tensorflow/python/keras/preprocessing/BUILD index 7c75e45fc58..1bfbef38ac9 100644 --- a/tensorflow/python/keras/preprocessing/BUILD +++ b/tensorflow/python/keras/preprocessing/BUILD @@ -4,7 +4,12 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + # TODO(scottzhu): Remove non-keras deps from TF. + "//tensorflow/lite/experimental/tf_runtime:__pkg__", + "//tensorflow/tools/docs:__pkg__", + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index eda4df9b742..8220a951b41 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -4,7 +4,11 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") package( - default_visibility = ["//visibility:public"], + # TODO(scottzhu): Remove non-keras deps from TF. + default_visibility = [ + "//tensorflow/python/distribute:__pkg__", + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD index 94f5624bd4e..bcbb7a375d0 100644 --- a/tensorflow/python/keras/tests/BUILD +++ b/tensorflow/python/keras/tests/BUILD @@ -4,7 +4,9 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + "//tensorflow/tools/pip_package:__pkg__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/type/BUILD b/tensorflow/python/keras/type/BUILD index cc82b1b4b7f..bb612301dd1 100644 --- a/tensorflow/python/keras/type/BUILD +++ b/tensorflow/python/keras/type/BUILD @@ -1,7 +1,7 @@ load("//tensorflow:tensorflow.bzl", "py_strict_library") package( - default_visibility = ["//tensorflow:__subpackages__"], + default_visibility = ["//tensorflow/python/keras:__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index 681dec5932e..8e84a789c66 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -5,7 +5,11 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( - default_visibility = ["//visibility:public"], + # TODO(scottzhu): Remove non-keras deps from TF. 
+ default_visibility = [ + "//tensorflow/python/feature_column:__pkg__", + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/python/keras/wrappers/BUILD b/tensorflow/python/keras/wrappers/BUILD index 5f8d6bd8780..446dac2697f 100644 --- a/tensorflow/python/keras/wrappers/BUILD +++ b/tensorflow/python/keras/wrappers/BUILD @@ -4,7 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") package( - default_visibility = ["//visibility:public"], + default_visibility = ["//tensorflow/python/keras:__subpackages__"], licenses = ["notice"], # Apache 2.0 ) From 7b1ed0adaef5985932283d45450efe134f5f4226 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Thu, 19 Mar 2020 09:24:43 -0700 Subject: [PATCH 233/492] Moves Interpreter, InterpreterBuilder, and Subgraph to `tflite::impl` namespace. PiperOrigin-RevId: 301830159 Change-Id: If6f07b0dba54eab4ea1abc4410a5ca60c1d6bc35 --- tensorflow/lite/core/subgraph.cc | 4 ++++ tensorflow/lite/core/subgraph.h | 6 ++++++ tensorflow/lite/interpreter.cc | 4 ++++ tensorflow/lite/interpreter.h | 10 +++++++++- tensorflow/lite/model.cc | 4 ++++ tensorflow/lite/model.h | 6 ++++++ tensorflow/lite/python/interpreter_wrapper/BUILD | 1 + .../python/interpreter_wrapper/interpreter_wrapper.h | 2 +- tensorflow/lite/python/optimize/BUILD | 1 + tensorflow/lite/python/optimize/calibration_wrapper.h | 3 ++- 10 files changed, 38 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 09e70390f0f..d057b2adc6e 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -29,6 +29,8 @@ limitations under the License. namespace tflite { +namespace impl { + namespace { struct TfLiteQuantizationDeleter { @@ -1349,4 +1351,6 @@ TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { return status; } +} // namespace impl + } // namespace tflite diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 4380feda283..d2d5eaf2cbf 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -31,6 +31,8 @@ limitations under the License. namespace tflite { +namespace impl { + // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -675,5 +677,9 @@ class Subgraph { resource::ResourceMap* resources_ = nullptr; }; +} // namespace impl + +using Subgraph = impl::Subgraph; + } // namespace tflite #endif // TENSORFLOW_LITE_CORE_SUBGRAPH_H_ diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index d333fa736e3..55abd6c148c 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -51,6 +51,8 @@ static_assert(sizeof(TfLiteFloat16) == sizeof(uint16_t), namespace tflite { +namespace impl { + namespace { // Gets the current TfLiteQuantization from the legacy TfLiteQuantizationParams. @@ -371,4 +373,6 @@ Profiler* Interpreter::GetProfiler() { return primary_subgraph().GetProfiler(); } +} // namespace impl + } // namespace tflite diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index 093390afbb7..dd183b2a98f 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -36,6 +36,10 @@ limitations under the License. namespace tflite { +class InterpreterTest; + +namespace impl { + /// An interpreter for a graph of nodes that input and output from tensors. /// Each node of the graph processes a set of input tensors and produces a /// set of output Tensors. 
All inputs/output tensors are referenced by index. @@ -494,7 +498,7 @@ class Interpreter { private: friend class InterpreterBuilder; - friend class InterpreterTest; + friend class tflite::InterpreterTest; /// Set the value of an external context. static void SetExternalContext(struct TfLiteContext* context, @@ -542,5 +546,9 @@ class Interpreter { resource::ResourceMap resources_; }; +} // namespace impl + +using Interpreter = impl::Interpreter; + } // namespace tflite #endif // TENSORFLOW_LITE_INTERPRETER_H_ diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc index bb08976c73e..25f196d272b 100644 --- a/tensorflow/lite/model.cc +++ b/tensorflow/lite/model.cc @@ -270,6 +270,8 @@ FlatBufferModel::FlatBufferModel(std::unique_ptr allocation, FlatBufferModel::~FlatBufferModel() {} +namespace impl { + InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, const OpResolver& op_resolver) : model_(model.GetModel()), @@ -783,4 +785,6 @@ TfLiteStatus InterpreterBuilder::operator()( return kTfLiteOk; } +} // namespace impl + } // namespace tflite diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h index 159f8002ddb..fd196c049e9 100644 --- a/tensorflow/lite/model.h +++ b/tensorflow/lite/model.h @@ -174,6 +174,8 @@ class FlatBufferModel { std::unique_ptr allocation_; }; +namespace impl { + /// Build an interpreter capable of interpreting `model`. /// /// model: A model whose lifetime must be at least as long as any @@ -238,6 +240,10 @@ class InterpreterBuilder { bool has_flex_op_ = false; }; +} // namespace impl + +using InterpreterBuilder = impl::InterpreterBuilder; + } // namespace tflite #endif // TENSORFLOW_LITE_MODEL_H_ diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index 14dbc553257..c1778e7b12d 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -71,6 +71,7 @@ pybind_extension( module_name = "_pywrap_tensorflow_interpreter_wrapper", deps = [ ":interpreter_wrapper_lib", + "//tensorflow/lite:framework_lib", "//tensorflow/lite/experimental/tflite_api_dispatcher", "//tensorflow/python:pybind11_lib", "//third_party/python_runtime:headers", diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index 8a5ff215f3a..b509c1ca199 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -27,6 +27,7 @@ limitations under the License. 
#include #include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h" +#include "tensorflow/lite/interpreter.h" struct TfLiteDelegate; @@ -39,7 +40,6 @@ class BuiltinOpResolver; } // namespace ops class FlatBufferModel; -class Interpreter; namespace interpreter_wrapper { diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD index 93af9fc1e9a..ba75dca9362 100644 --- a/tensorflow/lite/python/optimize/BUILD +++ b/tensorflow/lite/python/optimize/BUILD @@ -51,6 +51,7 @@ pybind_extension( module_name = "_pywrap_tensorflow_lite_calibration_wrapper", deps = [ ":calibration_wrapper_lib", + "//tensorflow/lite:framework_lib", "//tensorflow/python:pybind11_lib", "//third_party/python_runtime:headers", "@pybind11", diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h b/tensorflow/lite/python/optimize/calibration_wrapper.h index 4fcbfea6fea..7b5ae50e657 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.h +++ b/tensorflow/lite/python/optimize/calibration_wrapper.h @@ -26,6 +26,8 @@ limitations under the License. // automatically move before . #include +#include "tensorflow/lite/interpreter.h" + // We forward declare TFLite classes here to avoid exposing them to SWIG. namespace tflite { namespace ops { @@ -35,7 +37,6 @@ class BuiltinOpResolver; } // namespace ops class FlatBufferModel; -class Interpreter; namespace interpreter_wrapper { class PythonErrorReporter; From 2dc1efeb1ba4b911c053768fa25bdb56932656d2 Mon Sep 17 00:00:00 2001 From: Tiezhen WANG Date: Thu, 19 Mar 2020 09:29:12 -0700 Subject: [PATCH 234/492] Upgrade to Flatbuffer 1.12 PiperOrigin-RevId: 301831001 Change-Id: I5883d2e65e189dbdf6d2ed0788ad4842661b0758 --- tensorflow/lite/micro/tools/make/third_party_downloads.inc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/micro/tools/make/third_party_downloads.inc b/tensorflow/lite/micro/tools/make/third_party_downloads.inc index c5ac9f9ec1a..ca544d1371e 100644 --- a/tensorflow/lite/micro/tools/make/third_party_downloads.inc +++ b/tensorflow/lite/micro/tools/make/third_party_downloads.inc @@ -3,8 +3,8 @@ GEMMLOWP_URL := "https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37f7f98adcc7fc9f425.zip" GEMMLOWP_MD5 := "7e8191b24853d75de2af87622ad293ba" -FLATBUFFERS_URL := "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.11.0.tar.gz" -FLATBUFFERS_MD5 := "02c64880acb89dbd57eebacfd67200d8" +FLATBUFFERS_URL := "https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz" +FLATBUFFERS_MD5 := "c62ffefb3d4548b127cca14ce047f16c" ifeq ($(HOST_OS),osx) GCC_EMBEDDED_URL := "https://developer.arm.com/-/media/Files/downloads/gnu-rm/7-2018q2/gcc-arm-none-eabi-7-2018-q2-update-mac.tar.bz2" From 6535f8645c094a95b367f1b19d851f7121103a87 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 09:37:20 -0700 Subject: [PATCH 235/492] Fix a typo in runlit.site.cfg.py PiperOrigin-RevId: 301832660 Change-Id: Ief4ae5e7892d02b6e237864d21feb064ed46c38a --- tensorflow/compiler/mlir/runlit.site.cfg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 6c369a5a24c..b623ca8e849 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -22,7 +22,7 @@ import platform import lit.llvm # Handle the test srcdir for platforms. On windows, things are weird with bazel. 
-if platform.system == 'Windows': +if platform.system() == 'Windows': srcdir = os.environ['TEST_SRCDIR'] real_test_srcdir = srcdir[:srcdir.find('tensorflow/compiler/mlir')] external_srcdir = os.path.join(real_test_srcdir, 'external') @@ -56,7 +56,7 @@ test_dir = test_dir.strip('/').rsplit(':', 1)[0] config.mlir_test_dir = os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], test_dir) -if platform.system == 'Windows': +if platform.system() == 'Windows': # Configure this to work with msys2, TF's preferred windows bash. config.lit_tools_dir = '/usr/bin' From 3820a4ac5d017b5c8ea179b2c5b99d7b89004f02 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 09:46:20 -0700 Subject: [PATCH 236/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301834543 Change-Id: I7f757d8744456363c1e068bbc94721c8b40e270d --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 7be0c66548c..3d05bb08fa3 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11815,7 +11815,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -12072,7 +12072,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12083,7 +12083,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12301,7 +12301,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12312,7 +12312,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19153,7 +19153,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20224,7 +20224,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21396,7 +21396,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22104,7 +22104,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22300,7 +22300,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22484,7 +22484,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22543,7 +22543,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22717,7 +22717,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -23098,7 +23098,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25441,7 +25441,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25504,7 +25504,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25747,7 +25747,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26370,7 +26370,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45499,7 +45499,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46287,7 +46287,7 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46350,7 +46350,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 703b996f3369746b900af19c207894a9284e0266 Mon Sep 17 00:00:00 2001 From: Michael Banfield Date: Thu, 19 Mar 2020 09:52:54 -0700 Subject: [PATCH 237/492] Support restartType in cloud tpu client. PiperOrigin-RevId: 301835905 Change-Id: Iaa677be84aedfd1cbb68a467703ac10d458073c2 --- tensorflow/python/tpu/client/client.py | 14 +++++++++--- tensorflow/python/tpu/client/client_test.py | 25 +++++++++++++++------ 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/tpu/client/client.py b/tensorflow/python/tpu/client/client.py index cd747b87004..fdfda90f7d0 100644 --- a/tensorflow/python/tpu/client/client.py +++ b/tensorflow/python/tpu/client/client.py @@ -280,8 +280,15 @@ class Client(object): logging.warning('TPU "%s" is healthy.', self.name()) - def configure_tpu_version(self, version): - """Configure TPU software version.""" + def configure_tpu_version(self, version, restart_type='always'): + """Configure TPU software version. + + Args: + version (string): Version of software to configure the TPU with. + restart_type (string): Restart behaviour when switching versions, + defaults to always restart. Options are 'always', 'ifNeeded'. + + """ def configure_worker(worker): """Configure individual TPU worker. 
@@ -291,7 +298,8 @@ class Client(object): be sent. """ ip_address = worker['ipAddress'] - url = 'http://{}:8475/requestversion/{}'.format(ip_address, version) + url = 'http://{}:8475/requestversion/{}?restartType={}'.format( + ip_address, version, restart_type) req = request.Request(url, data=b'') try: request.urlopen(req) diff --git a/tensorflow/python/tpu/client/client_test.py b/tensorflow/python/tpu/client/client_test.py index 25d4da345e7..09dcdcd86f8 100644 --- a/tensorflow/python/tpu/client/client_test.py +++ b/tensorflow/python/tpu/client/client_test.py @@ -396,8 +396,7 @@ class CloudTpuClientTest(test.TestCase): 'Timed out waiting for TPU .* to become healthy'): c.wait_for_healthy(timeout_s=80, interval=5) - @mock.patch.object(request, 'urlopen') - def testConfigureTpuVersion(self, urlopen): + def baseConfigureTpuVersion(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/tpu_name': { 'state': @@ -412,18 +411,30 @@ class CloudTpuClientTest(test.TestCase): ] } } - c = client.Client( + return client.Client( tpu='tpu_name', project='test-project', zone='us-central1-c', service=self.mock_service_client(tpu_map=tpu_map)) + + @mock.patch.object(request, 'urlopen') + def testConfigureTpuVersion(self, urlopen): + c = self.baseConfigureTpuVersion() c.configure_tpu_version('1.15') - paths = [call[0][0].full_url for call in urlopen.call_args_list] - self.assertEqual([ - 'http://1.2.3.4:8475/requestversion/1.15', - 'http://5.6.7.8:8475/requestversion/1.15' + 'http://1.2.3.4:8475/requestversion/1.15?restartType=always', + 'http://5.6.7.8:8475/requestversion/1.15?restartType=always' + ], sorted(paths)) + + @mock.patch.object(request, 'urlopen') + def testConfigureTpuVersionRestartIfneeded(self, urlopen): + c = self.baseConfigureTpuVersion() + c.configure_tpu_version('1.15', restart_type='ifNeeded') + paths = [call[0][0].full_url for call in urlopen.call_args_list] + self.assertEqual([ + 'http://1.2.3.4:8475/requestversion/1.15?restartType=ifNeeded', + 'http://5.6.7.8:8475/requestversion/1.15?restartType=ifNeeded' ], sorted(paths)) From 005ae406cbc8ed5e6fb9c869c6d6a815507aad84 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 10:01:07 -0700 Subject: [PATCH 238/492] Fix `get_replicated_var_handle` when variables-to-replicate are created manually with tf.device. PiperOrigin-RevId: 301837676 Change-Id: I7a916fe2dc0607942578294aad0254c51499f6b8 --- tensorflow/python/tpu/tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index fe8fac794db..768ef072052 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -314,7 +314,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # Note that the order of devices for replicas for the variable and the # device assignment might not match. job_name = pydev.DeviceSpec.from_string(vars_[0].device).job - devices_to_vars = {v.device: v for v in vars_} + devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_} replicated_vars = [] for replica_id in range(device_assignment.num_replicas): for logical_core in range(device_assignment.num_cores_per_replica): From f9291316f2d1574a41155f99b688233cb3b959fc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 10:08:01 -0700 Subject: [PATCH 239/492] trace "equation" for einsum op. 
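To illustrate the format this change adds to profiler traces: for an op instance named "einsum" of type "Einsum" running the equation "ij,jk->ik", the non-verbose trace string comes out as einsum:Einsum#equation=(ij,jk->ik)#. The snippet below is only a self-contained sketch of that string formatting under those assumed example values; EinsumTraceString is a hypothetical stand-in for the kernel's TraceString method, not a TensorFlow API.

#include <iostream>
#include <string>

// Hypothetical helper mirroring the non-verbose branch added in this patch:
// "<op name>:<op type>#equation=(<equation attribute>)#".
std::string EinsumTraceString(const std::string& name, const std::string& type,
                              const std::string& equation) {
  return name + ":" + type + "#equation=(" + equation + ")#";
}

int main() {
  // Prints: einsum:Einsum#equation=(ij,jk->ik)#
  std::cout << EinsumTraceString("einsum", "Einsum", "ij,jk->ik") << "\n";
  return 0;
}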
PiperOrigin-RevId: 301839531 Change-Id: I25e48e101dc0183a04c4d0d61f198671af51fcd2 --- tensorflow/core/kernels/einsum_op_impl.h | 25 +++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/einsum_op_impl.h b/tensorflow/core/kernels/einsum_op_impl.h index a7afaad97bf..679a3de9b82 100644 --- a/tensorflow/core/kernels/einsum_op_impl.h +++ b/tensorflow/core/kernels/einsum_op_impl.h @@ -610,13 +610,12 @@ template class EinsumOp : public OpKernel { public: explicit EinsumOp(OpKernelConstruction* c) : OpKernel(c) { - string equation; - OP_REQUIRES_OK(c, c->GetAttr("equation", &equation)); - OP_REQUIRES_OK(c, - EinsumHelper::ParseEquation( - equation, &input_labels_, &output_labels_, &label_types_, - &input_label_counts_, &output_label_counts_, - &input_has_ellipsis_, &output_has_ellipsis_)); + OP_REQUIRES_OK(c, c->GetAttr("equation", &equation_)); + OP_REQUIRES_OK( + c, EinsumHelper::ParseEquation( + equation_, &input_labels_, &output_labels_, &label_types_, + &input_label_counts_, &output_label_counts_, + &input_has_ellipsis_, &output_has_ellipsis_)); } void Compute(OpKernelContext* ctx) override { @@ -735,7 +734,19 @@ class EinsumOp : public OpKernel { ctx->set_output(0, output); } + string TraceString(OpKernelContext* ctx, bool verbose) override { + if (!verbose) { + return strings::StrCat(name_view(), ":", type_string_view(), + "#equation=(", equation_, ")#"); + } else { + string trace_args = GetTraceArgument(ctx); + return strings::StrCat(name_view(), ":", type_string_view(), + "#equation=(", equation_, "),", trace_args, "#"); + } + } + private: + string equation_; OperandLabels input_labels_; Labels output_labels_; std::vector label_types_; From 4a53b4af9a3c9df29ae3759404d5c9b2340515c4 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Thu, 19 Mar 2020 10:16:28 -0700 Subject: [PATCH 240/492] Disable cost_analyzer_test on windows. It fails with heap corruption. PiperOrigin-RevId: 301841560 Change-Id: I23aacc0e73e7259b0649263ef949dd94506b5823 --- tensorflow/python/BUILD | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index d932899ab0d..d84ba11cf8b 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -7723,8 +7723,9 @@ tf_py_test( tags = [ "grappler", "no_cuda_on_cpu_tap", + "no_mac", "no_pip", - "nomac", + "no_windows", # TODO(b/151942037) ], deps = [ ":array_ops", From df950173f48b8fa39687c3aca8d314b60c79ae2d Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Thu, 19 Mar 2020 10:40:52 -0700 Subject: [PATCH 241/492] Fix OSS TPUClusterResolver `get_tpu_system_metadata` method. PiperOrigin-RevId: 301846924 Change-Id: Ia82b8328a8e48a9ffd91a08be583c5b92162c3c3 --- .../cluster_resolver/tpu_cluster_resolver.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py index a1e95fc380d..79ec0bc13d1 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py @@ -221,7 +221,20 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver): return self.task_type def get_tpu_system_metadata(self): - """Retrieves TPU system metadata given a TPUClusterResolver.""" + """Returns the metadata of the TPU system. 
+ + Users can call this method to get some facts of the TPU system, like + total number of cores, number of TPU workers and the devices. E.g. + ```python + + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + tpu_system_medata = resolver.get_tpu_system_metadata() + num_hosts = tpu_system_medata.num_hosts + ``` + + Returns: + A `tf.tpu.experimental.TPUSystemMetadata` object. + """ cluster_spec = self.cluster_spec() cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None tpu_system_metadata = ( @@ -230,6 +243,8 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver): cluster_def=cluster_def, query_topology=False)) + return tpu_system_metadata + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. From e777af90e878cb3618e583bc1b8ec796903ccbdb Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Wed, 18 Mar 2020 17:47:45 -0700 Subject: [PATCH 242/492] Separate filesystem registration from DSO loading. As we need to execute the same registration code when filesystems are registered statically, we move the DSO loading code to the `RegisterFilesystemPlugin` function and keep `RegisterFilesystemPluginImpl` available to be called regardless of the registration type (plugin or static dependency). --- .../filesystem/modular_filesystem.cc | 20 ++++++++- .../modular_filesystem_registration.cc | 41 +++++-------------- .../modular_filesystem_registration.h | 3 +- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.cc b/tensorflow/c/experimental/filesystem/modular_filesystem.cc index 8645d3186c8..58541ea2b36 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.cc @@ -440,7 +440,25 @@ Status ModularWritableFile::Tell(int64* position) { } Status RegisterFilesystemPlugin(const std::string& dso_path) { - return filesystem_registration::RegisterFilesystemPluginImpl(dso_path); + // Step 1: Load plugin + Env* env = Env::Default(); + void* dso_handle; + TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle)); + + // Step 2: Load symbol for `TF_InitPlugin` + void* dso_symbol; + TF_RETURN_IF_ERROR( + env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol)); + + // Step 3: Call `TF_InitPlugin` + TF_FilesystemPluginInfo info; + memset(&info, 0, sizeof(info)); + auto TF_InitPlugin = + reinterpret_cast(dso_symbol); + TF_InitPlugin(&info); + + // Step 4: Do the actual registration + return filesystem_registration::RegisterFilesystemPluginImpl(&info); } } // namespace tensorflow diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc index 5f6c2048e56..174665f8927 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc @@ -14,7 +14,6 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h" -#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/modular_filesystem.h" #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/platform/env.h" @@ -304,40 +303,22 @@ static Status ValidatePluginMemoryRoutines( namespace filesystem_registration { -Status RegisterFilesystemPluginImpl(const std::string& dso_path) { - // Step 1: Load plugin - Env* env = Env::Default(); - void* dso_handle; - TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle)); +Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info) { + TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(info)); - // Step 2: Load symbol for `TF_InitPlugin` - void* dso_symbol; - TF_RETURN_IF_ERROR( - env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol)); - - // Step 3: Call `TF_InitPlugin` - TF_FilesystemPluginInfo info; - memset(&info, 0, sizeof(info)); - auto TF_InitPlugin = - reinterpret_cast(dso_symbol); - TF_InitPlugin(&info); - - // Step 4: Ensure plugin provides the memory management functions. - TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info)); - - // Step 5: Validate and register all filesystems + // Validate and register all filesystems // Try to register as many filesystems as possible. // Free memory once we no longer need it Status status; - for (int i = 0; i < info.num_schemes; i++) { - status.Update(ValidateAndRegisterFilesystems(&info, i)); - info.plugin_memory_free(info.ops[i].scheme); - info.plugin_memory_free(info.ops[i].filesystem_ops); - info.plugin_memory_free(info.ops[i].random_access_file_ops); - info.plugin_memory_free(info.ops[i].writable_file_ops); - info.plugin_memory_free(info.ops[i].read_only_memory_region_ops); + for (int i = 0; i < info->num_schemes; i++) { + status.Update(ValidateAndRegisterFilesystems(info, i)); + info->plugin_memory_free(info->ops[i].scheme); + info->plugin_memory_free(info->ops[i].filesystem_ops); + info->plugin_memory_free(info->ops[i].random_access_file_ops); + info->plugin_memory_free(info->ops[i].writable_file_ops); + info->plugin_memory_free(info->ops[i].read_only_memory_region_ops); } - info.plugin_memory_free(info.ops); + info->plugin_memory_free(info->ops); return status; } diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h index 4df063d560c..5b1a7d40556 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h +++ b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h @@ -15,12 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_ #define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_ +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { namespace filesystem_registration { -Status RegisterFilesystemPluginImpl(const std::string& dso_path); +Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info); } // namespace filesystem_registration } // namespace tensorflow From ce298391eaa951a637d4155212e4927ea576e003 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Wed, 18 Mar 2020 17:55:57 -0700 Subject: [PATCH 243/492] Add a header file for the POSIX filesystem. 
This header only exposes `TF_InitPlugin` as we need to call this function
when registering the local filesystems statically.
---
 .../plugins/posix/posix_filesystem.h          | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)
 create mode 100644 tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h

diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h
new file mode 100644
index 00000000000..b2f169cdeab
--- /dev/null
+++ b/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
+#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
+
+// Initialize the POSIX filesystem.
+//
+// In general, the `TF_InitPlugin` symbol doesn't need to be exposed in a header
+// file, since the plugin registration will look for the symbol in the DSO file
+// that provides the filesystem functionality. However, the POSIX filesystem
+// needs to be statically registered in some tests and utilities for building
+// the API files at the time of creating the pip package. Hence, we need to
+// expose this function so that this filesystem can be statically registered
+// when needed.
+void TF_InitPlugin(TF_FilesystemPluginInfo* info);
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_

From fe1d911498fac05155d93e72c11195591fb7d77b Mon Sep 17 00:00:00 2001
From: Mihai Maruseac
Date: Wed, 18 Mar 2020 17:57:13 -0700
Subject: [PATCH 244/492] Add the `posix_filesystem.h` header to the `BUILD` target

---
 tensorflow/c/experimental/filesystem/plugins/posix/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD
index 3707dafe518..395cbb75914 100644
--- a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD
@@ -19,6 +19,7 @@ tf_cc_shared_object(
 cc_library(
     name = "posix_filesystem_impl",
     srcs = ["posix_filesystem.cc"],
+    hdrs = ["posix_filesystem.h"],
     deps = [
         ":posix_filesystem_helper",
         "//tensorflow/c:tf_status",

From d0259641873acc2dbe1e2d65c91d4a761bd1f42f Mon Sep 17 00:00:00 2001
From: Mihai Maruseac
Date: Thu, 19 Mar 2020 10:33:20 -0700
Subject: [PATCH 245/492] Make visibility of modular_filesystem public

We will need to depend on this to do the static registration of the
filesystems. This is because a part of the registration is common between
the plugin registration and the static registration.
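As an illustration of how a dependent target is expected to use this (a minimal
sketch under assumptions, not code added by this change: the function name is
hypothetical, and it presumes the POSIX plugin's `TF_InitPlugin` from the header
added in PATCH 243 is linked into the same binary), static registration boils
down to skipping the DSO-loading step and calling the shared registration path
directly:

    #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
    #include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
    #include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"

    // Sketch: fill the plugin info from the statically linked plugin, then reuse
    // the common (now publicly visible) registration implementation.
    tensorflow::Status RegisterPosixFilesystemStatically() {
      TF_FilesystemPluginInfo info;
      TF_InitPlugin(&info);
      return tensorflow::filesystem_registration::RegisterFilesystemPluginImpl(&info);
    }

A later patch in this series adds essentially this logic as
`posix_filesystem_static.cc`.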
--- tensorflow/c/experimental/filesystem/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/c/experimental/filesystem/BUILD b/tensorflow/c/experimental/filesystem/BUILD index 602494aa087..46d094cecc4 100644 --- a/tensorflow/c/experimental/filesystem/BUILD +++ b/tensorflow/c/experimental/filesystem/BUILD @@ -27,6 +27,9 @@ cc_library( "modular_filesystem_registration.h", ], hdrs = ["modular_filesystem.h"], + # TODO(mihaimaruseac): Visibility should be more restrictive once we + # convert to modular filesystems everywhere + visibility = ["//visibility:public"], deps = [ ":filesystem_interface", "//tensorflow/c:tf_status_helper", From 9e1c010111a4ed05e2d6c34c7ed8bfe1482f25d4 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Thu, 19 Mar 2020 10:35:47 -0700 Subject: [PATCH 246/492] Add `posix_filesystem_static` target and source. This implements the static registration of a filesystem, without needing to load the plugin DSO. Simply link this target in the binary which needs the filesystem to have the static filesystem imported. --- .../filesystem/plugins/posix/BUILD | 13 ++++++++ .../plugins/posix/posix_filesystem_static.cc | 33 +++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD index 395cbb75914..49a412dfb6a 100644 --- a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD @@ -27,6 +27,19 @@ cc_library( ], ) +# Since building pip package and API tests require a filesystem, we provide a +# static registration target that they should link against. +cc_library( + name = "posix_filesystem_static", + srcs = ["posix_filesystem_static.cc"], + visibility = ["//visibility:public"], + deps = [ + ":posix_filesystem_impl", + "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//tensorflow/c/experimental/filesystem:modular_filesystem", + ], +) + # Library implementing helper functionality, so that the above only contains # the API implementation for modular filesystems. cc_library( diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc b/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc new file mode 100644 index 00000000000..355cfce213b --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h" +#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h" + +namespace tensorflow { + +// Register the POSIX filesystems statically. +// Return value will be unused +Status StaticallyRegisterLocalFilesystems() { + TF_FilesystemPluginInfo info; + TF_InitPlugin(&info); + return filesystem_registration::RegisterFilesystemPluginImpl(&info); +} + +// Perform the actual registration +static Status unused = StaticallyRegisterLocalFilesystems(); + +} // namespace tensorflow From 6a15e2946748a20f7d23005e1a06fe594e072f8e Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 19 Mar 2020 11:03:15 -0700 Subject: [PATCH 247/492] Remove reference to forwarding header. PiperOrigin-RevId: 301852309 Change-Id: I1a04f85e28668c1fb454a2a2efe7edc494cbca32 --- .../kernels/data/experimental/matching_files_dataset_op.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index db2ab927993..9ba44aaf909 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include + +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" From afb11783a50af6e5a39cb3fbba10e7fdea34d1a3 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 19 Mar 2020 11:05:11 -0700 Subject: [PATCH 248/492] Standardize name scopes used during model construction in v2. 
PiperOrigin-RevId: 301852763 Change-Id: I3b4281f64ec4f3fe8e5f25901a1581ebc63de057 --- tensorflow/python/keras/engine/BUILD | 1 + tensorflow/python/keras/engine/base_layer.py | 14 ++++++++++- .../python/keras/engine/base_layer_test.py | 24 +++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 51666a0f34f..7555228c20f 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -111,6 +111,7 @@ py_library( ":base_layer_utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:constant_op", + "//tensorflow/python:tf2", "//tensorflow/python/data", "//tensorflow/python/distribute:distribute_coordinator", "//tensorflow/python/distribute:distribute_lib", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 8ae529fbfcb..5a37826d761 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -30,6 +30,7 @@ from six.moves import zip # pylint: disable=redefined-builtin from google.protobuf import json_format from tensorflow.core.framework import node_def_pb2 +from tensorflow.python import tf2 from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.distribute import distribution_strategy_context as ds_context @@ -2083,7 +2084,18 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._dtype_policy = policy.Policy(value) def _name_scope(self): - return self.name + if not tf2.enabled(): + return self.name + name_scope = self.name + current_name_scope = ops.get_name_scope() + if current_name_scope: + name_scope = current_name_scope + '/' + name_scope + if name_scope: + # Note that the trailing `/` prevents autogenerated + # numerical suffixes to get appended. It will also fully reset + # nested name scope (i.e. the outer name scope has no effect). 
+ name_scope += '/' + return name_scope def _init_set_name(self, name, zero_based=True): if not name: diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 94766fe177a..1999f313d6b 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -936,6 +936,30 @@ class NameScopingTest(keras_parameterized.TestCase): self.assertEqual(layer.bias.name, 'MyName/bias:0') self.assertEqual(layer.kernel.name, 'MyName/kernel:0') + def test_name_scope_functional_api(self): + inputs = input_layer.Input((3,)) + layer = layers.Dense(10, name='MyName') + _ = layer(inputs) + self.assertEqual(layer.bias.name, 'MyName/bias:0') + self.assertEqual(layer.kernel.name, 'MyName/kernel:0') + + def test_name_scope_functional_api_nested(self): + + class NestedLayer(base_layer.Layer): + + def __init__(self, name='OuterName'): + super(NestedLayer, self).__init__(name=name) + self.dense = layers.Dense(10, name='InnerName') + + def call(self, inputs): + return self.dense(inputs) + + inputs = input_layer.Input((3,)) + layer = NestedLayer() + _ = layer(inputs) + self.assertEqual(layer.dense.bias.name, 'OuterName/InnerName/bias:0') + self.assertEqual(layer.dense.kernel.name, 'OuterName/InnerName/kernel:0') + def test_name_scope_sublayer(self): class NameScopeTracker(base_layer.Layer): From b148a36030486390771ea440774a91e1adaa85b7 Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Thu, 19 Mar 2020 11:17:52 -0700 Subject: [PATCH 249/492] Disable the thread sanitizer for some tests until they get fixed. PiperOrigin-RevId: 301855529 Change-Id: Ib42dd1820c36c1d6a365b6f5b2fb96fd515477d3 --- tensorflow/compiler/tests/BUILD | 10 ++++++++++ tensorflow/compiler/xrt/tests/BUILD | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 77cd3dc074c..3018fb5f857 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -346,6 +346,8 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + # TODO(b/151948649): Fails on 2020-03-19. + "notsan", ], deps = [ ":xla_test", @@ -912,6 +914,8 @@ tf_xla_py_test( shard_count = 10, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + # TODO(b/151948649): Fails on 2020-03-19. + "notsan", ], deps = [ ":xla_test", @@ -1548,6 +1552,8 @@ cuda_py_test( tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", + # TODO(b/151948649): Fails on 2020-03-19. + "notsan", ], xla_enable_strict_auto_jit = False, xla_enabled = True, @@ -1573,6 +1579,8 @@ cuda_py_test( tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", + # TODO(b/151948649): Fails on 2020-03-19. + "notsan", ], xla_enable_strict_auto_jit = False, xla_enabled = True, @@ -1764,6 +1772,8 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + # TODO(b/151948649): Fails on 2020-03-19. 
+ "notsan", ], deps = [ ":xla_test", diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index 2f1faf1cdf1..918c802604a 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -58,6 +58,10 @@ tf_cc_test( "--xla_test_device=XLA_CPU", "--xla_platform=CPU", ], + tags = [ + # TODO(b/151948649): Fails on 2020-03-19. + "notsan", + ], deps = [ ":raw_api_test_lib", "//tensorflow/compiler/jit:xla_cpu_device", From 111879a5bc7a2317a92421b0c07c4be27b96b094 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 19 Mar 2020 12:43:10 -0700 Subject: [PATCH 250/492] Do not propagate parent name_scope in v2 control flow when inside v1 graph since that would break existing TF2 estimator models. PiperOrigin-RevId: 301874559 Change-Id: Ie42fdd40ce52ada1c4a1307a920625b489d5db67 --- tensorflow/python/framework/op_callbacks_test.py | 5 ++++- tensorflow/python/kernel_tests/cond_v2_test.py | 1 + .../kernel_tests/control_flow_ops_py_test.py | 15 +++++++++++---- tensorflow/python/kernel_tests/while_v2_test.py | 1 + .../python/ops/control_flow_v2_func_graphs.py | 16 ++++++++++------ 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/framework/op_callbacks_test.py b/tensorflow/python/framework/op_callbacks_test.py index f04d85bba21..14304536f65 100644 --- a/tensorflow/python/framework/op_callbacks_test.py +++ b/tensorflow/python/framework/op_callbacks_test.py @@ -632,7 +632,10 @@ class OpCallbacksTest(test_util.TensorFlowTestCase): greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP] self.assertEqual(len(greater_op_outputs), 1) self.assertAllClose(greater_op_outputs[0], False) - pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"] + # This was needed for backwards compatibility with TF2 Estimators which + # rely on variable names. + prefix = b"cond/" if context.executing_eagerly() else b"" + pow_op_outputs = instrument.graph_internal_ndarrays[b"%spow" % prefix] self.assertEqual(len(pow_op_outputs), 1) self.assertAllClose(pow_op_outputs[0], -64.0) diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index de8ea8d89d7..1682f2275c1 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -260,6 +260,7 @@ class CondV2Test(test.TestCase): self.assertRegexpMatches( cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*") + @test_util.run_v2_only def testInheritParentNameScope(self): @def_function.function diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 99fff136314..2533cf0a645 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -809,9 +809,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): return control_flow_ops.cond( pred, lambda: true_fn(inputs), lambda: false_fn(inputs)) + # This was needed for backwards compatibility with TF2 Estimators which + # rely on variable names. + prefix = "cond/" if context.executing_eagerly() else "" + with self.assertRaisesRegexp( ValueError, - "Tensor cond/true_branch:0 in true_fn is accessed from false_fn."): + "Tensor %strue_branch:0 in true_fn is accessed from false_fn." 
% + prefix): f() def testSwitchCaseAccessBranch1TensorInBranch4Raises(self): @@ -836,10 +841,12 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): [other_fn, lambda: br1_fn(inputs), other_fn, other_fn, lambda: br4_fn(inputs)]) + # This was needed for backwards compatibility with TF2 Estimators which + # rely on variable names. + prefix = "switch_case/indexed_case/" if context.executing_eagerly() else "" with self.assertRaisesRegexp( - ValueError, - "Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is " - "accessed from branch 4."): + ValueError, "Tensor %sbr1_identity:0 in branch 1 is " + "accessed from branch 4." % prefix): f() def testCondListOutput(self): diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 1fa6c179e7a..3f53f49fc30 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -1175,6 +1175,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): Fn() + @test_util.run_v2_only def testInheritParentNameScope(self): @def_function.function diff --git a/tensorflow/python/ops/control_flow_v2_func_graphs.py b/tensorflow/python/ops/control_flow_v2_func_graphs.py index 537ad2b4b8a..97e04f8d73d 100644 --- a/tensorflow/python/ops/control_flow_v2_func_graphs.py +++ b/tensorflow/python/ops/control_flow_v2_func_graphs.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import func_graph +from tensorflow.python.framework import ops class CondBranchFuncGraph(func_graph.FuncGraph): @@ -29,8 +30,9 @@ class CondBranchFuncGraph(func_graph.FuncGraph): def __init__(self, *args, **kwargs): super(CondBranchFuncGraph, self).__init__(*args, **kwargs) - func_graph.override_func_graph_name_scope(self, - self.outer_graph.get_name_scope()) + if ops.executing_eagerly_outside_functions(): + func_graph.override_func_graph_name_scope( + self, self.outer_graph.get_name_scope()) class WhileCondFuncGraph(func_graph.FuncGraph): @@ -41,8 +43,9 @@ class WhileCondFuncGraph(func_graph.FuncGraph): def __init__(self, *args, **kwargs): super(WhileCondFuncGraph, self).__init__(*args, **kwargs) - func_graph.override_func_graph_name_scope(self, - self.outer_graph.get_name_scope()) + if ops.executing_eagerly_outside_functions(): + func_graph.override_func_graph_name_scope( + self, self.outer_graph.get_name_scope()) class WhileBodyFuncGraph(func_graph.FuncGraph): @@ -53,5 +56,6 @@ class WhileBodyFuncGraph(func_graph.FuncGraph): def __init__(self, *args, **kwargs): super(WhileBodyFuncGraph, self).__init__(*args, **kwargs) - func_graph.override_func_graph_name_scope(self, - self.outer_graph.get_name_scope()) + if ops.executing_eagerly_outside_functions(): + func_graph.override_func_graph_name_scope( + self, self.outer_graph.get_name_scope()) From b65b2404745f64d5aa38a9565c2f71a3fe52c0ca Mon Sep 17 00:00:00 2001 From: Yunlu Li Date: Thu, 19 Mar 2020 13:07:34 -0700 Subject: [PATCH 251/492] Update sparse_tensor model. 
PiperOrigin-RevId: 301879775 Change-Id: I5a82a592f80f2af6624328a3596f27cf9a8c9959 --- tensorflow/lite/model_test.cc | 2 +- tensorflow/lite/testdata/sparse_tensor.bin | Bin 508 -> 592 bytes tensorflow/lite/testdata/sparse_tensor.json | 10 +++++++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc index b9efdf676a8..9a2c68e7c91 100644 --- a/tensorflow/lite/model_test.cc +++ b/tensorflow/lite/model_test.cc @@ -386,7 +386,7 @@ TEST(BasicFlatBufferModel, TestParseModelWithSparseTensor) { ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter), kTfLiteOk); ASSERT_NE(interpreter, nullptr); - ASSERT_EQ(interpreter->tensors_size(), 1); + ASSERT_EQ(interpreter->tensors_size(), 2); TfLiteTensor* t1 = interpreter->tensor(0); ASSERT_EQ(t1->allocation_type, kTfLiteMmapRo); diff --git a/tensorflow/lite/testdata/sparse_tensor.bin b/tensorflow/lite/testdata/sparse_tensor.bin index c035e02441d68e6cb6dca9bfbbf47c2924335d0a..ef02328088720ccfb1946e9a798e306233bfd5eb 100644 GIT binary patch delta 196 zcmeyve1RoZg@J(~#LdSTNFf0}1`!4p1`Y-upqKyyACP1Q;y;WG3>H9|1&En|SOd!L z`uG1oNDQPJsDy@4hD8G f3uGV=aLND)kOjOzoL^c}P+Bs1AEVgDBNmJRg+di@ delta 128 zcmcb>@`pK4f`Ne{#LdT;fq{=fgn@;DgMkMqA^_wuFfar07DfgJ10c-;#7sb}0A*kL u_y0dg45S*UgoO#Hn3aKnZOOz`f5rbm02E|^(Lg3h5g34IiOr&n){Fprk`vYd diff --git a/tensorflow/lite/testdata/sparse_tensor.json b/tensorflow/lite/testdata/sparse_tensor.json index d23c0d0a64b..3c6a742a4e8 100644 --- a/tensorflow/lite/testdata/sparse_tensor.json +++ b/tensorflow/lite/testdata/sparse_tensor.json @@ -40,11 +40,19 @@ } ] } + }, + { + "shape": [ + 4, + 4 + ], + "name": "output_tensor", + "type": "INT8" } ], "inputs": [0], "outputs": [0], - "operators": [{"inputs":[-1], "outputs":[-1]}] + "operators": [{"inputs":[-1], "outputs":[1]}] } ], "buffers": [ From f13d9b12f9a7aeef1cfa5cf460cd56b212f110e6 Mon Sep 17 00:00:00 2001 From: Karim Nosir Date: Thu, 19 Mar 2020 13:09:11 -0700 Subject: [PATCH 252/492] Add experimental directory under mlir/lite and add new interface for hardwares. This is experimental and can be deleted at any time. 
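For illustration only (a sketch under assumptions, not code added by this
change: `MyConvOp` is a placeholder for a generated TFL op class, the cost
value is made up, and `estimator.h` is assumed to be included), a per-op,
per-hardware estimate would be supplied by specializing the cost estimator
template introduced here:

    // Hypothetical specialization: mark MyConvOp as supported on GPU and return
    // a placeholder cost. The GPU interface methods generated from the .td file
    // dispatch to these static methods.
    template <>
    class TFLiteCostEstimator<MyConvOp, hardware::GPU> {
     public:
      static double GetCost(mlir::Operation* op) { return 1.0; }  // assumed cost
      static bool IsSupported(mlir::Operation* op) { return true; }
    };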
PiperOrigin-RevId: 301880113 Change-Id: I7e7df4e83d3219b36712626b6afbc0bf72a79d41 --- tensorflow/compiler/mlir/lite/BUILD | 3 + .../lite/experimental/estimators/estimator.h | 56 ++++++++++++++ .../experimental/tfl_hardware_interfaces.td | 76 +++++++++++++++++++ .../mlir/lite/ir/tfl_op_interfaces.td | 1 + tensorflow/compiler/mlir/lite/ir/tfl_ops.h | 1 + tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 5 +- 6 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h create mode 100644 tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index c917af71f92..03cf9265f3b 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -26,6 +26,7 @@ package_group( filegroup( name = "tensorflow_lite_ops_td_files", srcs = [ + "experimental/tfl_hardware_interfaces.td", "ir/tfl_op_interfaces.td", "ir/tfl_ops.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", @@ -204,6 +205,7 @@ cc_library( cc_library( name = "tensorflow_lite", srcs = [ + "experimental/estimators/estimator.h", "ir/tfl_ops.cc", "ir/tfl_ops.cc.inc", "ir/tfl_ops.h.inc", @@ -439,6 +441,7 @@ genrule( srcs = [ "ir/tfl_ops.td", "ir/tfl_op_interfaces.td", + "experimental/tfl_hardware_interfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h b/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h new file mode 100644 index 00000000000..26f6b0f3428 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_ + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" + +namespace hardware { +// Empty classes that represents hardware types. +class CPU {}; +class GPU {}; +} // namespace hardware + +template +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { + llvm::errs() << "No defined support for op: " + << op->getName().getStringRef().str(); + return false; + } +}; + +// All ops on CPU are supported. +// TODO(karimnosseir): Only allow TFL ops in the "TFL_OP" param. 
+template +class TFLiteCostEstimator { + public: + // TODO(karimnosseir): Update and use table based method and lookup + // cost from a loadable table ? + static double GetCost(mlir::Operation* op) { return 0.0; } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td b/tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td new file mode 100644 index 00000000000..5c3ec6c206c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// WARNING: This Interface is experimental, DO NOT USE. + +// This is the Target Hardware operation interfacea definition file +// for TensorFlow Lite. + +#ifndef TFL_TARGET_HARDWARE_OP_INTERFACES +#define TFL_TARGET_HARDWARE_OP_INTERFACES + +def TFL_CpuTargetOp : OpInterface<"CpuOpTargetInterface"> { + let description = [{ + Interface for ops to run on CPU. + }]; + + let methods = [ + InterfaceMethod< + [{Returns the cost of running this op on CPU.}], + // TODO(karimnosseir): Change to return Cost object instead. + "double", "GetOpCost", (ins "mlir::Operation*":$op_to_check), [{ + // TODO(karimnosseir): Consider changing to another way that doesn't + // rely on template param name. + return TFL::TFLiteCostEstimator::GetCost(op_to_check); + }] + >, + InterfaceMethod< + [{Returns whether this op can be run on CPU.}], + "bool", "IsSupported", (ins "mlir::Operation*":$op_to_check), [{ + // TODO(karimnosseir): Consider changing to another way that doesn't + // rely on template param name. + return TFL::TFLiteCostEstimator::IsSupported(op_to_check); + }] + >, + ]; +} + +def TFL_GpuTargetOp : OpInterface<"GpuOpTargetInterface"> { + let description = [{ + Interface for ops to run on GPU. + }]; + + let methods = [ + InterfaceMethod< + [{Returns the cost of running this op on GPU.}], + // TODO(karimnosseir): Change to return Cost object instead. + "double", "GetOpCost", (ins "Operation*":$op_to_check), [{ + // TODO(karimnosseir): Consider changing to another way that doesn't + // rely on template param name. + return TFL::TFLiteCostEstimator::GetCost(op_to_check); + }] + >, + InterfaceMethod< + [{Returns whether this op can be run on GPU.}], + "bool", "IsSupported", (ins "Operation*":$op_to_check), [{ + // TODO(karimnosseir): Consider changing to another way that doesn't + // rely on template param name. 
+ return TFL::TFLiteCostEstimator::IsSupported(op_to_check); + }] + >, + ]; +} + +#endif // TFL_TARGET_HARDWARE_OP_INTERFACES diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index 8e100538659..db0bef39358 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -19,6 +19,7 @@ limitations under the License. #define TFL_OP_INTERFACES include "mlir/IR/OpBase.td" +include "tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td" //===----------------------------------------------------------------------===// // TFL op interface for stateful operands. diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index ffdafc1844f..a9b89c2bb64 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -49,6 +49,7 @@ class TensorFlowLiteDialect : public Dialect { Location loc) override; }; +#include "tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 96eb69f7c8f..53bec976186 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -285,7 +285,10 @@ def TFL_ComparisonBinaryBuilder : OpBuilder< class TFL_Op traits = []> : Op])> { + [DeclareOpInterfaceMethods, + // All TFL ops are supported on CPU. + DeclareOpInterfaceMethods + ])> { // FlatBuffer generation specific information. // ------------------------------------------- // When generating the FlatBuffer output some operations have From 71313bcb180f2220adee93d964bc4721ad6bf575 Mon Sep 17 00:00:00 2001 From: Anna R Date: Thu, 19 Mar 2020 13:12:52 -0700 Subject: [PATCH 253/492] Place keras API tree at the root of pip package to get autocomplete working for deeper imports (for e.g. from tensorflow.keras.losses import KLD). PiperOrigin-RevId: 301880779 Change-Id: I332e0570750dfdd61712688a4f327f7993019520 --- tensorflow/tools/pip_package/build_pip_package.sh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index 090ce22b718..7a070938045 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -208,16 +208,18 @@ function prepare_src() { rm -f ${TMPDIR}/tensorflow/libtensorflow_framework.so rm -f ${TMPDIR}/tensorflow/libtensorflow_framework.so.[0-9].* - # Create a keras/__init__.pyi file so that autocomplete for imports - # such as `from tensorflow.keras import losses` works. # TODO(annarev): copy over API files from tensorflow/api/_vN to tensorflow/ # except tensorflow/api/_vN/lite/. - mkdir ${TMPDIR}/tensorflow/keras/ + + # Copy over keras API folder to the root directory + # so that autocomplete works as expected for all keras subimports. 
if [ -d "${TMPDIR}/tensorflow/_api/v1/" ] then - echo "from tensorflow.python.keras.api._v1.keras import *" > ${TMPDIR}/tensorflow/keras/__init__.pyi + cp -r ${TMPDIR}/tensorflow/python/keras/api/_v1/keras/ ${TMPDIR}/tensorflow/keras/ + sed -i'.original' -e 's/.python.keras.api._v1/tensorflow/g' ${TMPDIR}/tensorflow/__init__.py else - echo "from tensorflow.python.keras.api._v2.keras import *" > ${TMPDIR}/tensorflow/keras/__init__.pyi + cp -r ${TMPDIR}/tensorflow/python/keras/api/_v2/keras/ ${TMPDIR}/tensorflow/keras/ + sed -i'.original' -e 's/.python.keras.api._v2/tensorflow/g' ${TMPDIR}/tensorflow/__init__.py fi } From 0c53d830010ba1a75f77cfbe431f2df926e6a2c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 13:21:09 -0700 Subject: [PATCH 254/492] Supported tiled input sharding for model parallelism. PiperOrigin-RevId: 301882389 Change-Id: I7da33977da3f05881f7dc66742cd0fd9cb89d358 --- .../mlir/tensorflow/tests/tpu_rewrite.mlir | 301 ++++++++++++++++++ .../tensorflow/transforms/tpu_rewrite_pass.cc | 26 +- .../tensorflow/utils/xla_sharding_util.cc | 173 +++++++++- .../mlir/tensorflow/utils/xla_sharding_util.h | 8 +- 4 files changed, 480 insertions(+), 28 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 7ee20d23df3..f6eb08bb58c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1395,3 +1395,304 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc return %1, %3 : tensor<*xi32>, tensor<*xi1> } } + +// ----- + +// Tests inputs are correctly split and fed into TPU computation for +// tiled input sharding. + +// The following OpSharding is used for TPU computation inputs in below test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_tiled_input + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>) + func @parallel_execute_with_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = 
["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // + // CHECK: %[[CONST_SPLIT_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[RI_0]]) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1) + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: device = "TPU_REPLICATED_CORE_1" + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + +// ----- + +// The following OpSharding is used for TPU computation inputs in below test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 4 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" +// +// ----- + + +// Tests tile sharding of inputs with number of splits that does not evenly divide +// the input results in an error. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}} + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + +// ----- + +// The following topology is used in subsequent test cases: +// Proto debug string: +// mesh_shape: 2 +// mesh_shape: 1 +// mesh_shape: 2 +// num_tasks: 2 +// num_tpu_devices_per_task: 2 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 + +// The following OpSharding is used for TPU computation inputs in below test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_shape { +// element_type: F32 +// dimensions: 2 +// dimensions: 2 +// layout { +// minor_to_major: 1 +// minor_to_major: 0 +// format: DENSE +// } +// is_dynamic_dimension: false +// is_dynamic_dimension: false +// } +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// tile_assignment_devices: 2 +// tile_assignment_devices: 3 +// Serialized string: +// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03" +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// 
"\08\01\1A\01\01\22\01\01" + +// Tests inputs to TPUComputation that are tiled in multiple dimensions. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_multi_dimension_tiled_input + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>) + func @parallel_execute_with_multi_dimension_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]]) + // CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0) + // CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#1) + // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[COMPILE]]#3) + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4) + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 
\00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + + +// ----- + +// Tests inputs device assignment order is well preserved for tiled input sharding. + +// The following OpSharding is used for TPU computation inputs in below test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_shape { +// element_type: F32 +// dimensions: 2 +// dimensions: 2 +// layout { +// minor_to_major: 1 +// minor_to_major: 0 +// format: DENSE +// } +// is_dynamic_dimension: false +// is_dynamic_dimension: false +// } +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 3 +// tile_assignment_devices: 2 +// tile_assignment_devices: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00" +// +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" +// +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @tiled_input_sharding_with_device_assignment_order + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>) + func @tiled_input_sharding_with_device_assignment_order(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" 
+ // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]]) + // CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0) + // CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#1) + // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[RI_1]], %[[COMPILE]]#2) + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[COMPILE]]#3) + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#4) + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 50b6555076d..e20e78a243c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -488,11 +488,11 @@ Operation* BuildExecuteOp( // Creates a tf_device.parallel_execute op that wraps TPUExecute op to // represent execution of TPU program in multiple logical cores. 
-tf_device::ParallelExecuteOp BuildParallelExecuteOp( +LogicalResult BuildParallelExecuteOp( llvm::ArrayRef> execution_devices, llvm::ArrayRef output_sharding_config, Operation* compile_op, tf_device::LaunchFuncOp launch_func, - OpBuilder* builder) { + OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) { const int num_cores_per_replica = execution_devices.front().size(); // parallel_execute op returns concatenated list of return values of // all its regions. @@ -510,20 +510,23 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp( for (Type t : output_types) concatenated_output_types.emplace_back(t); } - auto parallel_execute_op = builder->create( + *parallel_execute_op = builder->create( launch_func.getLoc(), num_cores_per_replica, concatenated_output_types); // Extract inputs for each region of the parallel_execute op. The i-th // element in the list represents the input lists to TPU computation for // i-th logical core. - auto input_list = tensorflow::ExtractInputsForLogicalDevices( - num_cores_per_replica, launch_func); + llvm::SmallVector, 4> input_list; + builder->setInsertionPoint(*parallel_execute_op); + auto result = tensorflow::ExtractInputsForLogicalDevices( + num_cores_per_replica, launch_func, builder, &input_list); + if (failed(result)) return failure(); const bool replicated = execution_devices.size() != 1; // For each logical core, create a region with TPUExecute op. assert(input_list.size() == num_cores_per_replica); for (int core = 0; core < num_cores_per_replica; ++core) { - auto& region = parallel_execute_op.GetRegionBlockWithIndex(core); + auto& region = parallel_execute_op->GetRegionBlockWithIndex(core); builder->setInsertionPointToEnd(®ion); // Create Execute op. @@ -551,7 +554,7 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp( region_launch_op.getResults()); } - return parallel_execute_op; + return success(); } tf_device::LaunchOp AssignDevicesToReplicatedExecute( @@ -703,9 +706,12 @@ LogicalResult Rewrite( if (num_cores_per_replica > 1) { // For model parallelism, tf_device.parallel_execute is used to express // concurrent device execution across multiple logical devices. - tf_device::ParallelExecuteOp execute_op = BuildParallelExecuteOp( - tpu_device_assignment.execution_devices, output_shardings, compile_op, - launch_func, builder); + + tf_device::ParallelExecuteOp execute_op; + result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices, + output_shardings, compile_op, launch_func, + builder, &execute_op); + if (failed(result)) return failure(); // As tf_device.parallel_execute wraps # logical cores number of TPUExecute // ops, the number of return values of parallel_execute op exceeds that of diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index bbe91054b3b..bcf6e1b3496 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -16,10 +16,19 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -36,14 +45,135 @@ llvm::Optional ParseShardingAttribute( return sharding_attr.getValue(); } -llvm::SmallVector, 4> -ExtractInputsForLogicalDevices(int num_logical_cores, - mlir::tf_device::LaunchFuncOp launch_func) { +namespace { + +constexpr char kNumSplitAttr[] = "num_split"; + +// Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways +// in 'split_dimension' dimension and returns the split values. +mlir::LogicalResult CreateSplitOp(const int num_split, + const int split_dimension, + const mlir::Location& location, + mlir::Value src_input, + mlir::OpBuilder* builder, + mlir::TF::SplitOp* split_op) { + // Creates a const op to hold split dimension value. + auto split_dim_type = + mlir::RankedTensorType::get({}, builder->getIntegerType(32)); + auto split_dimension_attr = + mlir::DenseElementsAttr::get(split_dim_type, split_dimension); + auto split_dimension_op = builder->create( + location, split_dim_type, split_dimension_attr); + + // Correctly set output shapes of split op output if input shape is statically + // known. + mlir::Type output_type; + auto input_type = src_input.getType().cast(); + + if (input_type.hasRank()) { + if (input_type.getShape()[split_dimension] == + mlir::ShapedType::kDynamicSize) { + output_type = input_type; + } else { + auto shape = llvm::to_vector<4>(input_type.getShape()); + if (shape[split_dimension] % num_split != 0) { + return mlir::emitError( + location, + llvm::formatv( + "incorrect input sharding configuration received. " + "{0}-th dimension of the input must be evenly divisible by {1}", + split_dimension, num_split)); + } + + shape[split_dimension] = shape[split_dimension] / num_split; + output_type = + mlir::RankedTensorType::get(shape, input_type.getElementType()); + } + } else { + output_type = input_type; + } + + // Creates a split op that splits |src_input| along |split_dimension|. + llvm::SmallVector output_types(num_split, output_type); + *split_op = builder->create( + location, output_types, split_dimension_op.output(), src_input); + split_op->setAttr(kNumSplitAttr, builder->getIntegerAttr( + builder->getIntegerType(32), num_split)); + return mlir::success(); +} + +// For tile sharded inputs to TPU computation, inject split op between the +// input values and TPU computation so that tiled input values are passed in +// as inputs to TPU computations. If more than one dimension is sharded, then +// a tree of connected split ops are added before tf_device.parallel_execute op. 
+mlir::LogicalResult HandleTileShardedInputs( + const mlir::Location& location, const xla::OpSharding& input_sharding, + const mlir::Value& original_source, mlir::OpBuilder* builder, + llvm::SmallVectorImpl* tiled_inputs) { + llvm::SmallVector split_ops_for_tiled_input; + split_ops_for_tiled_input.reserve( + input_sharding.tile_assignment_devices_size()); + + // Creates a tree of split nodes for sharding tiled inputs. Splits nodes + // are created such that input data is sharded in row major order. + // Split nodes at ith depth from the original input node represent nodes + // that split the input data at i-th dimension. + const auto& dimension_splits = input_sharding.tile_assignment_dimensions(); + for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) { + const int num_splits = num_splits_and_index.value(); + const int dimension_index = num_splits_and_index.index(); + if (num_splits == 1) continue; + + // Creates root split op. + if (split_ops_for_tiled_input.empty()) { + mlir::TF::SplitOp root_split_op; + auto result = CreateSplitOp(num_splits, dimension_index, location, + original_source, builder, &root_split_op); + if (mlir::failed(result)) return mlir::failure(); + + split_ops_for_tiled_input.emplace_back(root_split_op); + continue; + } + + llvm::SmallVector new_split_ops; + new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits); + + for (auto split_op : split_ops_for_tiled_input) { + for (auto parent_split_output_value : split_op.getResults()) { + mlir::TF::SplitOp child_split_op; + auto result = + CreateSplitOp(num_splits, dimension_index, location, + parent_split_output_value, builder, &child_split_op); + if (mlir::failed(result)) return mlir::failure(); + + new_split_ops.emplace_back(child_split_op); + } + } + + std::swap(new_split_ops, split_ops_for_tiled_input); + } + + // `split_ops_for_tiled_input` now includes final split nodes + // from which sharded data will be fed into TPUExcute ops -- sorted by + // row major order. + tiled_inputs->reserve(input_sharding.tile_assignment_devices_size()); + for (auto split_op : split_ops_for_tiled_input) + tiled_inputs->append(split_op.getResults().begin(), + split_op.getResults().end()); + + return mlir::success(); +} + +} // namespace + +mlir::LogicalResult ExtractInputsForLogicalDevices( + int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func, + mlir::OpBuilder* builder, + llvm::SmallVectorImpl>* input_list) { // Initialize the input list for each logical devices. - llvm::SmallVector, 4> input_list; - input_list.reserve(num_logical_cores); + input_list->reserve(num_logical_cores); for (int i = 0; i < num_logical_cores; ++i) - input_list.emplace_back(llvm::SmallVector()); + input_list->emplace_back(llvm::SmallVector()); llvm::SmallVector launch_func_inputs( launch_func.getOperands()); @@ -53,8 +183,8 @@ ExtractInputsForLogicalDevices(int num_logical_cores, // If sharding attribute does not exist, then all inputs are placed on 0th // logical core by default. if (!sharding_attrs) { - input_list[0] = launch_func_inputs; - return input_list; + (*input_list)[0] = launch_func_inputs; + return mlir::success(); } // Enumerate sharding configuration for each inputs. 
If input has replicate @@ -71,19 +201,32 @@ ExtractInputsForLogicalDevices(int num_logical_cores, sharding_attr.cast().getValue().str()); const auto input_sharing_type = sharding.type(); - if (input_sharing_type == xla::OpSharding::OTHER) - launch_func.emitError( - "tiled inputs are not yet supported for model parallelism"); + if (input_sharing_type == xla::OpSharding::OTHER) { + llvm::SmallVector tiled_inputs; + auto result = HandleTileShardedInputs( + launch_func.getLoc(), sharding, input_value, builder, &tiled_inputs); + if (mlir::failed(result)) return mlir::failure(); - if (input_sharing_type == xla::OpSharding::REPLICATED) { - for (auto inputs : input_list) inputs.emplace_back(input_value); + if (tiled_inputs.size() != num_logical_cores) + launch_func.emitError(llvm::formatv( + "incorrect {0}-th tiled input sharding received. " + "Product of tile sharding splits({1}) must be equal to " + "number of logical devices : {2}", + input_index, tiled_inputs.size(), num_logical_cores)); + + for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) { + const int assigned_logical_device = sharding.tile_assignment_devices(i); + (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]); + } + } else if (input_sharing_type == xla::OpSharding::REPLICATED) { + for (auto inputs : *input_list) inputs.emplace_back(input_value); } else { assert(input_sharing_type == xla::OpSharding::MAXIMAL); const int logical_device_id = sharding.tile_assignment_devices(0); - input_list[logical_device_id].emplace_back(input_value); + (*input_list)[logical_device_id].emplace_back(input_value); } } - return input_list; + return mlir::success(); } mlir::LogicalResult ParseAndValidateOutputSharding( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 4f548ca95aa..f7a9dbf2c81 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/Value.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project @@ -38,9 +39,10 @@ llvm::Optional ParseShardingAttribute( // i-th element is a list of mlir::Value's which represent inputs for the // TPU computation correponding to i-th logical device. If the attribute // does not exist, the all inputs are placed on logical core 0. -llvm::SmallVector, 4> -ExtractInputsForLogicalDevices(int num_logical_cores, - mlir::tf_device::LaunchFuncOp launch_func); +mlir::LogicalResult ExtractInputsForLogicalDevices( + int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func, + mlir::OpBuilder* builder, + llvm::SmallVectorImpl>* input_list); // Extracts a list of OpSharding that represent output sharding configuration // of `tf_device.launch`. From f29c62f405fec7b6b8d2ef518e4918857cfaca5e Mon Sep 17 00:00:00 2001 From: Artem Belevich Date: Thu, 19 Mar 2020 13:26:23 -0700 Subject: [PATCH 255/492] Regenerated wrapper includes for all CUDA versions & libraries. 
PiperOrigin-RevId: 301883437 Change-Id: I60eb5e45b6eec404c0694a95e091b5e17dd02585 --- .../hexagon/hexagon_nn/hexagon_nn_init.cc | 52 + tensorflow/stream_executor/cuda/BUILD | 9 +- .../stream_executor/cuda/cublas_10_2.inc | 5220 +++++++++++ .../stream_executor/cuda/cublas_stub.cc | 11 +- tensorflow/stream_executor/cuda/cuda_10_0.inc | 301 + tensorflow/stream_executor/cuda/cuda_10_1.inc | 2166 +++++ tensorflow/stream_executor/cuda/cuda_10_2.inc | 2328 +++++ tensorflow/stream_executor/cuda/cuda_9_0.inc | 1718 ++++ tensorflow/stream_executor/cuda/cuda_blas.cc | 27 +- .../cuda/cuda_runtime_10_0.inc | 317 + .../cuda/cuda_runtime_10_2.inc | 1896 ++++ tensorflow/stream_executor/cuda/cuda_stub.cc | 15 +- .../stream_executor/cuda/cudart_stub.cc | 10 +- .../stream_executor/cuda/cufft_10_0.inc | 1 + tensorflow/stream_executor/cuda/cufft_9_0.inc | 307 + tensorflow/stream_executor/cuda/cufft_stub.cc | 5 + .../cuda/cusolver_dense_10_2.inc | 3677 ++++++++ .../cuda/cusolver_dense_9_0.inc | 2185 +++++ .../stream_executor/cuda/cusolver_stub.cc | 13 +- .../stream_executor/cuda/cusparse_10_1.inc | 32 - .../stream_executor/cuda/cusparse_10_2.inc | 8226 +++++++++++++++++ .../stream_executor/cuda/cusparse_9_0.inc | 9 +- .../stream_executor/cuda/cusparse_stub.cc | 11 +- 23 files changed, 28474 insertions(+), 62 deletions(-) create mode 100644 tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.cc create mode 100644 tensorflow/stream_executor/cuda/cublas_10_2.inc create mode 100644 tensorflow/stream_executor/cuda/cuda_10_1.inc create mode 100644 tensorflow/stream_executor/cuda/cuda_10_2.inc create mode 100644 tensorflow/stream_executor/cuda/cuda_9_0.inc create mode 100644 tensorflow/stream_executor/cuda/cuda_runtime_10_2.inc create mode 100644 tensorflow/stream_executor/cuda/cufft_9_0.inc create mode 100644 tensorflow/stream_executor/cuda/cusolver_dense_10_2.inc create mode 100644 tensorflow/stream_executor/cuda/cusolver_dense_9_0.inc create mode 100644 tensorflow/stream_executor/cuda/cusparse_10_2.inc diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.cc b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.cc new file mode 100644 index 00000000000..d1607c4a6d7 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.cc @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h" + +#include +#include +#include +#include + +#include "remote.h" // NOLINT +#include "rpcmem.h" // NOLINT +#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/soc_model.h" + +#ifdef __cplusplus +extern "C" { +#endif +// Version 1.14 +static const int kHexagonNNVersion = 136193; +#pragma weak remote_handle_control // Declare it as a weak symbol +void hexagon_nn_global_init() { + rpcmem_init(); + // Non-domains QoS invocation + struct remote_rpc_control_latency data; + data.enable = 1; + if (remote_handle_control) { // Check if API is available before invoking + remote_handle_control(DSPRPC_CONTROL_LATENCY, (void*)&data, sizeof(data)); + } +} + +void hexagon_nn_global_teardown() { rpcmem_deinit(); } + +bool hexagon_nn_is_device_supported() { + return tflite::delegates::getsoc_model().mode != UNSPECIFIED_MODE; +} + +int hexagon_nn_hexagon_interface_version() { return kHexagonNNVersion; } + +#ifdef __cplusplus +} +#endif diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index 1789abadde8..1457a36beaf 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -1,7 +1,7 @@ # Description: # CUDA-platform specific StreamExecutor support code. -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts", "tf_cuda_cc_test") load( "//tensorflow/stream_executor:build_defs.bzl", "stream_executor_friends", @@ -9,7 +9,6 @@ load( "tf_additional_cuda_platform_deps", "tf_additional_cudnn_plugin_deps", ) -load("//tensorflow:tensorflow.bzl", "tf_copts") load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -90,7 +89,7 @@ cc_library( cc_library( name = "cuda_stub", srcs = if_cuda_is_configured(["cuda_stub.cc"]), - textual_hdrs = ["cuda_10_0.inc"], + textual_hdrs = glob(["cuda_*.inc"]), deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tensorflow/stream_executor/lib", @@ -271,7 +270,7 @@ cc_library( cc_library( name = "cufft_stub", srcs = if_cuda_is_configured(["cufft_stub.cc"]), - textual_hdrs = ["cufft_10_0.inc"], + textual_hdrs = glob(["cufft_*.inc"]), deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tensorflow/stream_executor/lib", @@ -426,7 +425,7 @@ cc_library( cc_library( name = "cusolver_stub", srcs = if_cuda_is_configured(["cusolver_stub.cc"]), - textual_hdrs = ["cusolver_dense_10_0.inc"], + textual_hdrs = glob(["cusolver_dense_*.inc"]), deps = if_cuda_is_configured([ # LINT.IfChange "@local_config_cuda//cuda:cublas_headers", diff --git a/tensorflow/stream_executor/cuda/cublas_10_2.inc b/tensorflow/stream_executor/cuda/cublas_10_2.inc new file mode 100644 index 00000000000..42c4e5fef3b --- /dev/null +++ b/tensorflow/stream_executor/cuda/cublas_10_2.inc @@ -0,0 +1,5220 @@ +// Auto-generated, do not edit. 
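Every wrapper in this generated file follows the same lazy-binding shape: resolve the symbol once through LoadSymbol, cache the function pointer in a function-local static, and either forward the call or report the missing symbol. A minimal standalone sketch of that pattern (illustrative only; it uses plain dlfcn.h and libm on Linux rather than the LoadSymbol/GetSymbolNotFoundError helpers defined in the accompanying *_stub.cc files):

#include <dlfcn.h>
#include <cstdio>

using CosFn = double (*)(double);

// Lazily open the library and resolve a symbol, the same way the generated
// cuBLAS stubs resolve entry points from the CUDA libraries at first use.
// (Linux-specific soname; link with -ldl on older glibc.)
static void* LoadLibmSymbol(const char* name) {
  static void* handle = dlopen("libm.so.6", RTLD_LAZY);
  return handle ? dlsym(handle, name) : nullptr;
}

// Shape of every generated wrapper: cache the pointer once, fail gracefully if
// the library or symbol is absent, otherwise forward the arguments.
double my_cos(double x) {
  static auto func_ptr = reinterpret_cast<CosFn>(LoadLibmSymbol("cos"));
  if (!func_ptr) {
    std::fprintf(stderr, "symbol not found: cos\n");  // stubs return an error status instead
    return 0.0;
  }
  return func_ptr(x);
}

int main() { std::printf("cos(0) = %f\n", my_cos(0.0)); return 0; }

Because the cached pointer is a function-local static, the lookup happens on first call and is thread-safe under C++11 static initialization, which is what lets each generated wrapper stay a thin forwarding shim.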
+ +extern "C" { + +cublasStatus_t CUBLASWINAPI cublasCreate_v2 (cublasHandle_t *handle) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t *); + static auto func_ptr = LoadSymbol("cublasCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cublasStatus_t CUBLASWINAPI cublasDestroy_v2 (cublasHandle_t handle) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t); + static auto func_ptr = LoadSymbol("cublasDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cublasStatus_t CUBLASWINAPI cublasGetVersion_v2(cublasHandle_t handle, int *version) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int *); + static auto func_ptr = LoadSymbol("cublasGetVersion_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, version); +} + +cublasStatus_t CUBLASWINAPI cublasGetProperty(libraryPropertyType type, int *value) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol("cublasGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +size_t CUBLASWINAPI cublasGetCudartVersion(void) { + using FuncPtr = size_t(CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol("cublasGetCudartVersion"); + if (!func_ptr) LogFatalSymbolNotFound("cublasGetCudartVersion"); + return func_ptr(); +} + +cublasStatus_t CUBLASWINAPI cublasSetStream_v2 (cublasHandle_t handle, cudaStream_t streamId) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cublasSetStream_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cublasStatus_t CUBLASWINAPI cublasGetStream_v2 (cublasHandle_t handle, cudaStream_t *streamId) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cudaStream_t *); + static auto func_ptr = LoadSymbol("cublasGetStream_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cublasStatus_t CUBLASWINAPI cublasGetPointerMode_v2 (cublasHandle_t handle, cublasPointerMode_t *mode) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t *); + static auto func_ptr = LoadSymbol("cublasGetPointerMode_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasSetPointerMode_v2 (cublasHandle_t handle, cublasPointerMode_t mode) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t); + static auto func_ptr = LoadSymbol("cublasSetPointerMode_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasGetAtomicsMode(cublasHandle_t handle, cublasAtomicsMode_t *mode) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t *); + static auto func_ptr = LoadSymbol("cublasGetAtomicsMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasSetAtomicsMode(cublasHandle_t handle, cublasAtomicsMode_t mode) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t); + static auto func_ptr = LoadSymbol("cublasSetAtomicsMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasGetMathMode(cublasHandle_t handle, 
cublasMath_t *mode) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasMath_t *); + static auto func_ptr = LoadSymbol("cublasGetMathMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasSetMathMode(cublasHandle_t handle, cublasMath_t mode) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasMath_t); + static auto func_ptr = LoadSymbol("cublasSetMathMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasLoggerConfigure(int logIsOn, int logToStdOut, + int logToStdErr, + const char *logFileName) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, int, const char *); + static auto func_ptr = LoadSymbol("cublasLoggerConfigure"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(logIsOn, logToStdOut, logToStdErr, logFileName); +} + +cublasStatus_t CUBLASWINAPI +cublasSetLoggerCallback(cublasLogCallback userCallback) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLogCallback); + static auto func_ptr = LoadSymbol("cublasSetLoggerCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(userCallback); +} + +cublasStatus_t CUBLASWINAPI +cublasGetLoggerCallback(cublasLogCallback *userCallback) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLogCallback *); + static auto func_ptr = LoadSymbol("cublasGetLoggerCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(userCallback); +} + +cublasStatus_t CUBLASWINAPI cublasSetVector (int n, int elemSize, const void *x, + int incx, void *devicePtr, int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int); + static auto func_ptr = LoadSymbol("cublasSetVector"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, x, incx, devicePtr, incy); +} + +cublasStatus_t CUBLASWINAPI cublasGetVector (int n, int elemSize, const void *x, + int incx, void *y, int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int); + static auto func_ptr = LoadSymbol("cublasGetVector"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSetMatrix (int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int); + static auto func_ptr = LoadSymbol("cublasSetMatrix"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rows, cols, elemSize, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasGetMatrix (int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int); + static auto func_ptr = LoadSymbol("cublasGetMatrix"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rows, cols, elemSize, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasSetVectorAsync (int n, int elemSize, + const void *hostPtr, int incx, + void *devicePtr, int incy, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int, cudaStream_t); + static auto func_ptr = LoadSymbol("cublasSetVectorAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, hostPtr, incx, 
devicePtr, incy, stream); +} + +cublasStatus_t CUBLASWINAPI cublasGetVectorAsync (int n, int elemSize, + const void *devicePtr, int incx, + void *hostPtr, int incy, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int, cudaStream_t); + static auto func_ptr = LoadSymbol("cublasGetVectorAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, devicePtr, incx, hostPtr, incy, stream); +} + +cublasStatus_t CUBLASWINAPI cublasSetMatrixAsync (int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb, cudaStream_t stream) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int, cudaStream_t); + static auto func_ptr = LoadSymbol("cublasSetMatrixAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rows, cols, elemSize, A, lda, B, ldb, stream); +} + +cublasStatus_t CUBLASWINAPI cublasGetMatrixAsync (int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb, cudaStream_t stream) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int, cudaStream_t); + static auto func_ptr = LoadSymbol("cublasGetMatrixAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rows, cols, elemSize, A, lda, B, ldb, stream); +} + +void CUBLASWINAPI cublasXerbla (const char *srName, int info) { + using FuncPtr = void (CUBLASWINAPI *)(const char *, int); + static auto func_ptr = LoadSymbol("cublasXerbla"); + if (!func_ptr) LogFatalSymbolNotFound("cublasXerbla"); + return func_ptr(srName, info); +} + +cublasStatus_t CUBLASWINAPI cublasNrm2Ex(cublasHandle_t handle, + int n, + const void *x, + cudaDataType xType, + int incx, + void *result, + cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol("cublasNrm2Ex"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, result, resultType, executionType); +} + +cublasStatus_t CUBLASWINAPI cublasSnrm2_v2(cublasHandle_t handle, + int n, + const float *x, + int incx, + float *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *); + static auto func_ptr = LoadSymbol("cublasSnrm2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDnrm2_v2(cublasHandle_t handle, + int n, + const double *x, + int incx, + double *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *); + static auto func_ptr = LoadSymbol("cublasDnrm2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasScnrm2_v2(cublasHandle_t handle, + int n, + const cuComplex *x, + int incx, + float *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, float *); + static auto func_ptr = LoadSymbol("cublasScnrm2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDznrm2_v2(cublasHandle_t handle, + int n, + const cuDoubleComplex *x, + int incx, + double *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, 
const cuDoubleComplex *, int, double *); + static auto func_ptr = LoadSymbol("cublasDznrm2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDotEx (cublasHandle_t handle, + int n, + const void *x, + cudaDataType xType, + int incx, + const void *y, + cudaDataType yType, + int incy, + void *result, + cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol("cublasDotEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType); +} + +cublasStatus_t CUBLASWINAPI cublasDotcEx (cublasHandle_t handle, + int n, + const void *x, + cudaDataType xType, + int incx, + const void *y, + cudaDataType yType, + int incy, + void *result, + cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol("cublasDotcEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType); +} + +cublasStatus_t CUBLASWINAPI cublasSdot_v2 (cublasHandle_t handle, + int n, + const float *x, + int incx, + const float *y, + int incy, + float *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, const float *, int, float *); + static auto func_ptr = LoadSymbol("cublasSdot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasDdot_v2 (cublasHandle_t handle, + int n, + const double *x, + int incx, + const double *y, + int incy, + double *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, const double *, int, double *); + static auto func_ptr = LoadSymbol("cublasDdot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasCdotu_v2 (cublasHandle_t handle, + int n, + const cuComplex *x, + int incx, + const cuComplex *y, + int incy, + cuComplex *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol("cublasCdotu_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasCdotc_v2 (cublasHandle_t handle, + int n, + const cuComplex *x, + int incx, + const cuComplex *y, + int incy, + cuComplex *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol("cublasCdotc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasZdotu_v2 (cublasHandle_t handle, + int n, + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *y, + int incy, + cuDoubleComplex *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const 
cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZdotu_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasZdotc_v2 (cublasHandle_t handle, + int n, + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *y, + int incy, + cuDoubleComplex *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZdotc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasScalEx(cublasHandle_t handle, + int n, + const void *alpha, /* host or device pointer */ + cudaDataType alphaType, + void *x, + cudaDataType xType, + int incx, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, void *, cudaDataType, int, cudaDataType); + static auto func_ptr = LoadSymbol("cublasScalEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, alphaType, x, xType, incx, executionType); +} + +cublasStatus_t CUBLASWINAPI cublasSscal_v2(cublasHandle_t handle, + int n, + const float *alpha, /* host or device pointer */ + float *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDscal_v2(cublasHandle_t handle, + int n, + const double *alpha, /* host or device pointer */ + double *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCscal_v2(cublasHandle_t handle, + int n, + const cuComplex *alpha, /* host or device pointer */ + cuComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCsscal_v2(cublasHandle_t handle, + int n, + const float *alpha, /* host or device pointer */ + cuComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZscal_v2(cublasHandle_t handle, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + cuDoubleComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZdscal_v2(cublasHandle_t handle, + int n, + const double *alpha, /* host or device pointer */ + 
cuDoubleComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZdscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasAxpyEx (cublasHandle_t handle, + int n, + const void *alpha, /* host or device pointer */ + cudaDataType alphaType, + const void *x, + cudaDataType xType, + int incx, + void *y, + cudaDataType yType, + int incy, + cudaDataType executiontype) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, const void *, cudaDataType, int, void *, cudaDataType, int, cudaDataType); + static auto func_ptr = LoadSymbol("cublasAxpyEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, alphaType, x, xType, incx, y, yType, incy, executiontype); +} + +cublasStatus_t CUBLASWINAPI cublasSaxpy_v2 (cublasHandle_t handle, + int n, + const float *alpha, /* host or device pointer */ + const float *x, + int incx, + float *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSaxpy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDaxpy_v2 (cublasHandle_t handle, + int n, + const double *alpha, /* host or device pointer */ + const double *x, + int incx, + double *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDaxpy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCaxpy_v2 (cublasHandle_t handle, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCaxpy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZaxpy_v2 (cublasHandle_t handle, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZaxpy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCopyEx(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, void *y, cudaDataType yType, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, void *, + cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasCopyEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy); +} + +cublasStatus_t CUBLASWINAPI cublasScopy_v2 (cublasHandle_t handle, + int n, + const float *x, + int incx, + float *y, + int incy) { + using FuncPtr = 
cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasScopy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDcopy_v2 (cublasHandle_t handle, + int n, + const double *x, + int incx, + double *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDcopy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCcopy_v2 (cublasHandle_t handle, + int n, + const cuComplex *x, + int incx, + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCcopy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZcopy_v2 (cublasHandle_t handle, + int n, + const cuDoubleComplex *x, + int incx, + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZcopy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSswap_v2 (cublasHandle_t handle, + int n, + float *x, + int incx, + float *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSswap_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDswap_v2 (cublasHandle_t handle, + int n, + double *x, + int incx, + double *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDswap_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCswap_v2 (cublasHandle_t handle, + int n, + cuComplex *x, + int incx, + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCswap_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZswap_v2 (cublasHandle_t handle, + int n, + cuDoubleComplex *x, + int incx, + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZswap_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSwapEx(cublasHandle_t handle, int n, void *x, + cudaDataType xType, int incx, void *y, + cudaDataType yType, int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, void *, cudaDataType, + int, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasSwapEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy); +} + 
+cublasStatus_t CUBLASWINAPI cublasIsamax_v2(cublasHandle_t handle, + int n, + const float *x, + int incx, + int *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, int *); + static auto func_ptr = LoadSymbol("cublasIsamax_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIdamax_v2(cublasHandle_t handle, + int n, + const double *x, + int incx, + int *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, int *); + static auto func_ptr = LoadSymbol("cublasIdamax_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIcamax_v2(cublasHandle_t handle, + int n, + const cuComplex *x, + int incx, + int *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cublasIcamax_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIzamax_v2(cublasHandle_t handle, + int n, + const cuDoubleComplex *x, + int incx, + int *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cublasIzamax_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIamaxEx( + cublasHandle_t handle, int n, const void *x, cudaDataType xType, int incx, + int *result /* host or device pointer */ +) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, int *); + static auto func_ptr = LoadSymbol("cublasIamaxEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIsamin_v2(cublasHandle_t handle, + int n, + const float *x, + int incx, + int *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, int *); + static auto func_ptr = LoadSymbol("cublasIsamin_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIdamin_v2(cublasHandle_t handle, + int n, + const double *x, + int incx, + int *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, int *); + static auto func_ptr = LoadSymbol("cublasIdamin_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIcamin_v2(cublasHandle_t handle, + int n, + const cuComplex *x, + int incx, + int *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cublasIcamin_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIzamin_v2(cublasHandle_t handle, + int n, + const cuDoubleComplex *x, + int incx, + int *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cublasIzamin_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + 
+cublasStatus_t CUBLASWINAPI cublasIaminEx( + cublasHandle_t handle, int n, const void *x, cudaDataType xType, int incx, + int *result /* host or device pointer */ +) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, int *); + static auto func_ptr = LoadSymbol("cublasIaminEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasAsumEx( + cublasHandle_t handle, int n, const void *x, cudaDataType xType, int incx, + void *result, cudaDataType resultType, /* host or device pointer */ + cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, void *, + cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol("cublasAsumEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, result, resultType, executiontype); +} + +cublasStatus_t CUBLASWINAPI cublasSasum_v2(cublasHandle_t handle, + int n, + const float *x, + int incx, + float *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *); + static auto func_ptr = LoadSymbol("cublasSasum_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDasum_v2(cublasHandle_t handle, + int n, + const double *x, + int incx, + double *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *); + static auto func_ptr = LoadSymbol("cublasDasum_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasScasum_v2(cublasHandle_t handle, + int n, + const cuComplex *x, + int incx, + float *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, float *); + static auto func_ptr = LoadSymbol("cublasScasum_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDzasum_v2(cublasHandle_t handle, + int n, + const cuDoubleComplex *x, + int incx, + double *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, double *); + static auto func_ptr = LoadSymbol("cublasDzasum_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasSrot_v2 (cublasHandle_t handle, + int n, + float *x, + int incx, + float *y, + int incy, + const float *c, /* host or device pointer */ + const float *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int, const float *, const float *); + static auto func_ptr = LoadSymbol("cublasSrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasDrot_v2 (cublasHandle_t handle, + int n, + double *x, + int incx, + double *y, + int incy, + const double *c, /* host or device pointer */ + const double *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int, const double *, const double *); + static auto func_ptr = LoadSymbol("cublasDrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + 
+cublasStatus_t CUBLASWINAPI cublasCrot_v2 (cublasHandle_t handle, + int n, + cuComplex *x, + int incx, + cuComplex *y, + int incy, + const float *c, /* host or device pointer */ + const cuComplex *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, const cuComplex *); + static auto func_ptr = LoadSymbol("cublasCrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasCsrot_v2(cublasHandle_t handle, + int n, + cuComplex *x, + int incx, + cuComplex *y, + int incy, + const float *c, /* host or device pointer */ + const float *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, const float *); + static auto func_ptr = LoadSymbol("cublasCsrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasZrot_v2 (cublasHandle_t handle, + int n, + cuDoubleComplex *x, + int incx, + cuDoubleComplex *y, + int incy, + const double *c, /* host or device pointer */ + const cuDoubleComplex *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, const double *, const cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasZdrot_v2(cublasHandle_t handle, + int n, + cuDoubleComplex *x, + int incx, + cuDoubleComplex *y, + int incy, + const double *c, /* host or device pointer */ + const double *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, const double *, const double *); + static auto func_ptr = LoadSymbol("cublasZdrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI +cublasRotEx(cublasHandle_t handle, int n, void *x, cudaDataType xType, int incx, + void *y, cudaDataType yType, int incy, + const void *c, /* host or device pointer */ + const void *s, cudaDataType csType, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, void *, cudaDataType, int, void *, cudaDataType, int, + const void *, const void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol("cublasRotEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, c, s, csType, + executiontype); +} + +cublasStatus_t CUBLASWINAPI cublasSrotg_v2(cublasHandle_t handle, + float *a, /* host or device pointer */ + float *b, /* host or device pointer */ + float *c, /* host or device pointer */ + float *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, float *, float *, float *, float *); + static auto func_ptr = LoadSymbol("cublasSrotg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasDrotg_v2(cublasHandle_t handle, + double *a, /* host or device pointer */ + double *b, /* host or device pointer */ + double *c, /* host or device pointer */ + double *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, double *, double *, double *, double *); + static auto func_ptr = 
LoadSymbol("cublasDrotg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasCrotg_v2(cublasHandle_t handle, + cuComplex *a, /* host or device pointer */ + cuComplex *b, /* host or device pointer */ + float *c, /* host or device pointer */ + cuComplex *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cuComplex *, cuComplex *, float *, cuComplex *); + static auto func_ptr = LoadSymbol("cublasCrotg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasZrotg_v2(cublasHandle_t handle, + cuDoubleComplex *a, /* host or device pointer */ + cuDoubleComplex *b, /* host or device pointer */ + double *c, /* host or device pointer */ + cuDoubleComplex *s) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cuDoubleComplex *, cuDoubleComplex *, double *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZrotg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasRotgEx(cublasHandle_t handle, + void *a, /* host or device pointer */ + void *b, /* host or device pointer */ + cudaDataType abType, + void *c, /* host or device pointer */ + void *s, /* host or device pointer */ + cudaDataType csType, + cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, void *, void *, + cudaDataType, void *, void *, + cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol("cublasRotgEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, abType, c, s, csType, executiontype); +} + +cublasStatus_t CUBLASWINAPI cublasSrotm_v2(cublasHandle_t handle, + int n, + float *x, + int incx, + float *y, + int incy, + const float* param) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int, const float *); + static auto func_ptr = LoadSymbol("cublasSrotm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, param); +} + +cublasStatus_t CUBLASWINAPI cublasDrotm_v2(cublasHandle_t handle, + int n, + double *x, + int incx, + double *y, + int incy, + const double* param) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int, const double *); + static auto func_ptr = LoadSymbol("cublasDrotm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, param); +} + +cublasStatus_t CUBLASWINAPI +cublasRotmEx(cublasHandle_t handle, int n, void *x, cudaDataType xType, + int incx, void *y, cudaDataType yType, int incy, + const void *param, /* host or device pointer */ + cudaDataType paramType, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, void *, cudaDataType, int, void *, cudaDataType, int, + const void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol("cublasRotmEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, param, paramType, + executiontype); +} + +cublasStatus_t CUBLASWINAPI cublasSrotmg_v2(cublasHandle_t handle, + float *d1, /* host or device pointer */ + float *d2, /* host or device pointer */ + float *x1, /* host or device pointer */ + const float *y1, /* host or device pointer */ + float *param) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI 
*)(cublasHandle_t, float *, float *, float *, const float *, float *); + static auto func_ptr = LoadSymbol("cublasSrotmg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, d1, d2, x1, y1, param); +} + +cublasStatus_t CUBLASWINAPI cublasDrotmg_v2(cublasHandle_t handle, + double *d1, /* host or device pointer */ + double *d2, /* host or device pointer */ + double *x1, /* host or device pointer */ + const double *y1, /* host or device pointer */ + double *param) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, double *, double *, double *, const double *, double *); + static auto func_ptr = LoadSymbol("cublasDrotmg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, d1, d2, x1, y1, param); +} + +cublasStatus_t CUBLASWINAPI +cublasRotmgEx(cublasHandle_t handle, void *d1, /* host or device pointer */ + cudaDataType d1Type, void *d2, /* host or device pointer */ + cudaDataType d2Type, void *x1, /* host or device pointer */ + cudaDataType x1Type, const void *y1, /* host or device pointer */ + cudaDataType y1Type, void *param, /* host or device pointer */ + cudaDataType paramType, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, void *, cudaDataType, void *, cudaDataType, void *, + cudaDataType, const void *, cudaDataType, void *, cudaDataType, + cudaDataType); + static auto func_ptr = LoadSymbol("cublasRotmgEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, d1, d1Type, d2, d2Type, x1, x1Type, y1, y1Type, param, + paramType, executiontype); +} + +cublasStatus_t CUBLASWINAPI cublasSgemv_v2 (cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *x, + int incx, + const float *beta, /* host or device pointer */ + float *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSgemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDgemv_v2 (cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *x, + int incx, + const double *beta, /* host or device pointer */ + double *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDgemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCgemv_v2 (cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *x, + int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgemv_v2"); + if 
(!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZgemv_v2 (cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSgbmv_v2 (cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int kl, + int ku, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *x, + int incx, + const float *beta, /* host or device pointer */ + float *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSgbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDgbmv_v2 (cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int kl, + int ku, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *x, + int incx, + const double *beta, /* host or device pointer */ + double *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDgbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCgbmv_v2 (cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int kl, + int ku, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *x, + int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZgbmv_v2 (cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int kl, + int ku, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const 
cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasStrmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const float *A, + int lda, + float *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStrmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtrmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const double *A, + int lda, + double *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtrmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtrmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuComplex *A, + int lda, + cuComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtrmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtrmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuDoubleComplex *A, + int lda, + cuDoubleComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtrmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStbmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const float *A, + int lda, + float *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtbmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const double *A, + int lda, + double *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, int, double *, int); + static auto func_ptr = 
LoadSymbol("cublasDtbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtbmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const cuComplex *A, + int lda, + cuComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtbmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const cuDoubleComplex *A, + int lda, + cuDoubleComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStpmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const float *AP, + float *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasStpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtpmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const double *AP, + double *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDtpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtpmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuComplex *AP, + cuComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtpmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuDoubleComplex *AP, + cuDoubleComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStrsv_v2 
(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const float *A, + int lda, + float *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStrsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtrsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const double *A, + int lda, + double *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtrsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtrsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuComplex *A, + int lda, + cuComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtrsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtrsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuDoubleComplex *A, + int lda, + cuDoubleComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtrsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStpsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const float *AP, + float *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasStpsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtpsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const double *AP, + double *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDtpsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtpsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuComplex *AP, + cuComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, 
cublasOperation_t, cublasDiagType_t, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtpsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtpsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuDoubleComplex *AP, + cuDoubleComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtpsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStbsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const float *A, + int lda, + float *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStbsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtbsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const double *A, + int lda, + double *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtbsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtbsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const cuComplex *A, + int lda, + cuComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtbsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtbsv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const cuDoubleComplex *A, + int lda, + cuDoubleComplex *x, + int incx) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtbsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasSsymv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *x, + int incx, + const float *beta, /* host or device pointer */ + float *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float 
*, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSsymv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDsymv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *x, + int incx, + const double *beta, /* host or device pointer */ + double *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDsymv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCsymv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *x, + int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsymv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZsymv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsymv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasChemv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *x, + int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasChemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZhemv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex 
*, int); + static auto func_ptr = LoadSymbol("cublasZhemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSsbmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + int k, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *x, + int incx, + const float *beta, /* host or device pointer */ + float *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSsbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDsbmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + int k, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *x, + int incx, + const double *beta, /* host or device pointer */ + double *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDsbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasChbmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *x, + int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasChbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZhbmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZhbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSspmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *alpha, /* host or device pointer */ + const float *AP, + const float *x, + int incx, + const float *beta, /* host or device pointer */ + float *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSspmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDspmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *alpha, /* host or device pointer */ + const double *AP, + const double *x, + int incx, + const double *beta, /* host or device pointer */ + double *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDspmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasChpmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *AP, + const cuComplex *x, + int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasChpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZhpmv_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *AP, + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, + int incy) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZhpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSger_v2 (cublasHandle_t handle, + int m, + int n, + const float *alpha, /* host or device pointer */ + const float *x, + int incx, + const float *y, + int incy, + float *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const float *, const float *, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSger_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasDger_v2 (cublasHandle_t handle, + int m, + int n, + const double *alpha, /* host or device pointer */ + const double *x, + int incx, + const double *y, + int incy, + double *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const double *, const double *, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDger_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasCgeru_v2 (cublasHandle_t handle, + int m, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + const cuComplex *y, + int incy, + cuComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuComplex *, const cuComplex *, 
int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgeru_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasCgerc_v2 (cublasHandle_t handle, + int m, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + const cuComplex *y, + int incy, + cuComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgerc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasZgeru_v2 (cublasHandle_t handle, + int m, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *y, + int incy, + cuDoubleComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgeru_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasZgerc_v2 (cublasHandle_t handle, + int m, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *y, + int incy, + cuDoubleComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgerc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasSsyr_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *alpha, /* host or device pointer */ + const float *x, + int incx, + float *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasDsyr_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *alpha, /* host or device pointer */ + const double *x, + int incx, + double *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasCsyr_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + cuComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsyr_v2"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasZsyr_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + cuDoubleComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsyr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasCher_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + cuComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCher_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasZher_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + cuDoubleComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZher_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasSspr_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *alpha, /* host or device pointer */ + const float *x, + int incx, + float *AP) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, float *); + static auto func_ptr = LoadSymbol("cublasSspr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, AP); +} + +cublasStatus_t CUBLASWINAPI cublasDspr_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *alpha, /* host or device pointer */ + const double *x, + int incx, + double *AP) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, double *); + static auto func_ptr = LoadSymbol("cublasDspr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, AP); +} + +cublasStatus_t CUBLASWINAPI cublasChpr_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + cuComplex *AP) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol("cublasChpr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, AP); +} + +cublasStatus_t CUBLASWINAPI cublasZhpr_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, 
cublasFillMode_t, int, const double *, const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZhpr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, AP); +} + +cublasStatus_t CUBLASWINAPI cublasSsyr2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *alpha, /* host or device pointer */ + const float *x, + int incx, + const float *y, + int incy, + float *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasDsyr2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *alpha, /* host or device pointer */ + const double *x, + int incx, + const double *y, + int incy, + double *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasCsyr2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + const cuComplex *y, + int incy, + cuComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsyr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasZsyr2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *y, + int incy, + cuDoubleComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsyr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasCher2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + const cuComplex *y, + int incy, + cuComplex *A, + int lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCher2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasZher2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *y, + int incy, + cuDoubleComplex *A, + int 
lda) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZher2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasSspr2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *alpha, /* host or device pointer */ + const float *x, + int incx, + const float *y, + int incy, + float *AP) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float *, int, float *); + static auto func_ptr = LoadSymbol("cublasSspr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); +} + +cublasStatus_t CUBLASWINAPI cublasDspr2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *alpha, /* host or device pointer */ + const double *x, + int incx, + const double *y, + int incy, + double *AP) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, double *); + static auto func_ptr = LoadSymbol("cublasDspr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); +} + +cublasStatus_t CUBLASWINAPI cublasChpr2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, + int incx, + const cuComplex *y, + int incy, + cuComplex *AP) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol("cublasChpr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); +} + +cublasStatus_t CUBLASWINAPI cublasZhpr2_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, + int incx, + const cuDoubleComplex *y, + int incy, + cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZhpr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); +} + +cublasStatus_t CUBLASWINAPI cublasSgemm_v2 (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *B, + int ldb, + const float *beta, /* host or device pointer */ + float *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSgemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDgemm_v2 (cublasHandle_t handle, + 
cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *B, + int ldb, + const double *beta, /* host or device pointer */ + double *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDgemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm_v2 (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm3m (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgemm3m"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm3mEx (cublasHandle_t handle, + cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuComplex *alpha, + const void *A, + cudaDataType Atype, + int lda, + const void *B, + cudaDataType Btype, + int ldb, + const cuComplex *beta, + void *C, + cudaDataType Ctype, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const void *, cudaDataType, int, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasCgemm3mEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZgemm_v2 (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, 
const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZgemm3m (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgemm3m"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSgemmEx (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host or device pointer */ + const void *A, + cudaDataType Atype, + int lda, + const void *B, + cudaDataType Btype, + int ldb, + const float *beta, /* host or device pointer */ + void *C, + cudaDataType Ctype, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const void *, cudaDataType, int, const void *, cudaDataType, int, const float *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasSgemmEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasGemmEx (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, /* host or device pointer */ + const void *A, + cudaDataType Atype, + int lda, + const void *B, + cudaDataType Btype, + int ldb, + const void *beta, /* host or device pointer */ + void *C, + cudaDataType Ctype, + int ldc, + cudaDataType computeType, + cublasGemmAlgo_t algo) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const void *, const void *, cudaDataType, int, const void *, cudaDataType, int, const void *, void *, cudaDataType, int, cudaDataType, cublasGemmAlgo_t); + static auto func_ptr = LoadSymbol("cublasGemmEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo); +} + +cublasStatus_t CUBLASWINAPI cublasCgemmEx (cublasHandle_t handle, + cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuComplex *alpha, + const void *A, + cudaDataType Atype, + int lda, + const void *B, + cudaDataType Btype, + int ldb, + const cuComplex *beta, + void *C, + cudaDataType Ctype, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const void *, cudaDataType, int, const void *, cudaDataType, int, const cuComplex *, void 
*, cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasCgemmEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasUint8gemmBias (cublasHandle_t handle, + cublasOperation_t transa, cublasOperation_t transb, cublasOperation_t transc, + int m, int n, int k, + const unsigned char *A, int A_bias, int lda, + const unsigned char *B, int B_bias, int ldb, + unsigned char *C, int C_bias, int ldc, + int C_mult, int C_shift) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, cublasOperation_t, int, int, int, const unsigned char *, int, int, const unsigned char *, int, int, unsigned char *, int, int, int, int); + static auto func_ptr = LoadSymbol("cublasUint8gemmBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, transc, m, n, k, A, A_bias, lda, B, B_bias, ldb, C, C_bias, ldc, C_mult, C_shift); +} + +cublasStatus_t CUBLASWINAPI cublasSsyrk_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *beta, /* host or device pointer */ + float *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyrk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDsyrk_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *beta, /* host or device pointer */ + double *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyrk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyrk_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsyrk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZsyrk_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const 
cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsyrk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyrkEx ( cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const void *A, + cudaDataType Atype, + int lda, + const cuComplex *beta, /* host or device pointer */ + void *C, + cudaDataType Ctype, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasCsyrkEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyrk3mEx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex *alpha, + const void *A, + cudaDataType Atype, + int lda, + const cuComplex *beta, + void *C, + cudaDataType Ctype, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasCsyrk3mEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCherk_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const float *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const cuComplex *, int, const float *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCherk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZherk_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZherk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCherkEx (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float *alpha, /* host or device pointer */ + const void *A, + cudaDataType Atype, + int lda, + const float *beta, /* host or device pointer */ + void *C, + cudaDataType Ctype, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const void 
*, cudaDataType, int, const float *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasCherkEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCherk3mEx (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float *alpha, + const void *A, cudaDataType Atype, + int lda, + const float *beta, + void *C, + cudaDataType Ctype, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const void *, cudaDataType, int, const float *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol("cublasCherk3mEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSsyr2k_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *B, + int ldb, + const float *beta, /* host or device pointer */ + float *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyr2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDsyr2k_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *B, + int ldb, + const double *beta, /* host or device pointer */ + double *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyr2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyr2k_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsyr2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZsyr2k_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t 
(CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsyr2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCher2k_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + const float *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const float *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCher2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZher2k_v2 (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZher2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSsyrkx (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *B, + int ldb, + const float *beta, /* host or device pointer */ + float *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyrkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDsyrkx (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *B, + int ldb, + const double *beta, /* host or device pointer */ + double *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyrkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyrkx (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex *alpha, /* host or 
device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsyrkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZsyrkx (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsyrkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCherkx (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + const float *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const float *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCherkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZherkx (cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZherkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSsymm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *B, + int ldb, + const float *beta, /* host or device pointer */ + float *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasSsymm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return 
func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDsymm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *B, + int ldb, + const double *beta, /* host or device pointer */ + double *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDsymm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsymm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsymm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZsymm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsymm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasChemm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasChemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZhemm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, 
cublasSideMode_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZhemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasStrsm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + float *B, + int ldb) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStrsm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasDtrsm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + double *B, + int ldb) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtrsm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasCtrsm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + cuComplex *B, + int ldb) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtrsm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasZtrsm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + cuDoubleComplex *B, + int ldb) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtrsm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasStrmm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *B, + int ldb, + float *C, + int ldc) { + using FuncPtr = cublasStatus_t 
(CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, const float *, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStrmm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDtrmm_v2 (cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *B, + int ldb, + double *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, const double *, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtrmm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCtrmm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *B, + int ldb, + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtrmm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZtrmm_v2(cublasHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtrmm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *const Aarray[], int lda, const float *const Barray[], int ldb, + const float *beta, /* host or device pointer */ + float *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *const[], int, const float *const[], int, + const float *, float *const[], int, int); + static auto func_ptr = LoadSymbol("cublasSgemmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, 
batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasDgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *const Aarray[], int lda, const double *const Barray[], + int ldb, const double *beta, /* host or device pointer */ + double *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *const[], int, const double *const[], int, + const double *, double *const[], int, int); + static auto func_ptr = LoadSymbol("cublasDgemmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *const Aarray[], int lda, const cuComplex *const Barray[], + int ldb, const cuComplex *beta, /* host or device pointer */ + cuComplex *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *const[], int, + const cuComplex *const[], int, const cuComplex *, cuComplex *const[], int, + int); + static auto func_ptr = LoadSymbol("cublasCgemmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm3mBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *const Aarray[], int lda, const cuComplex *const Barray[], + int ldb, const cuComplex *beta, /* host or device pointer */ + cuComplex *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *const[], int, + const cuComplex *const[], int, const cuComplex *, cuComplex *const[], int, + int); + static auto func_ptr = LoadSymbol("cublasCgemm3mBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI +cublasZgemmBatched(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *const Aarray[], int lda, + const cuDoubleComplex *const Barray[], int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *const[], int, + const cuDoubleComplex *const[], int, const cuDoubleComplex *, + cuDoubleComplex *const[], int, int); + static auto func_ptr = LoadSymbol("cublasZgemmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, 
ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasGemmBatchedEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const void *alpha, /* host or device pointer */ + const void *const Aarray[], cudaDataType Atype, int lda, + const void *const Barray[], cudaDataType Btype, int ldb, + const void *beta, /* host or device pointer */ + void *const Carray[], cudaDataType Ctype, int ldc, int batchCount, + cudaDataType computeType, cublasGemmAlgo_t algo) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const void *, const void *const[], cudaDataType, int, const void *const[], + cudaDataType, int, const void *, void *const[], cudaDataType, int, int, + cudaDataType, cublasGemmAlgo_t); + static auto func_ptr = LoadSymbol("cublasGemmBatchedEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, Atype, lda, + Barray, Btype, ldb, beta, Carray, Ctype, ldc, batchCount, + computeType, algo); +} + +cublasStatus_t CUBLASWINAPI cublasGemmStridedBatchedEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const void *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, + long long int strideA, /* purposely signed */ + const void *B, cudaDataType Btype, int ldb, long long int strideB, + const void *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc, long long int strideC, int batchCount, + cudaDataType computeType, cublasGemmAlgo_t algo) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const void *, const void *, cudaDataType, int, long long, const void *, + cudaDataType, int, long long, const void *, void *, cudaDataType, int, + long long, int, cudaDataType, cublasGemmAlgo_t); + static auto func_ptr = LoadSymbol("cublasGemmStridedBatchedEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, + strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, + batchCount, computeType, algo); +} + +cublasStatus_t CUBLASWINAPI cublasSgemmStridedBatched (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + long long int strideA, /* purposely signed */ + const float *B, + int ldb, + long long int strideB, + const float *beta, /* host or device pointer */ + float *C, + int ldc, + long long int strideC, + int batchCount) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const float *, int, long long, const float *, int, long long, const float *, float *, int, long long, int); + static auto func_ptr = LoadSymbol("cublasSgemmStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasDgemmStridedBatched (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + long long int strideA, /* purposely signed */ + const double *B, + int ldb, + long 
long int strideB, + const double *beta, /* host or device pointer */ + double *C, + int ldc, + long long int strideC, + int batchCount) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const double *, const double *, int, long long, const double *, int, long long, const double *, double *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemmStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCgemmStridedBatched (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + long long int strideA, /* purposely signed */ + const cuComplex *B, + int ldb, + long long int strideB, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc, + long long int strideC, + int batchCount) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, long long, const cuComplex *, int, long long, const cuComplex *, cuComplex *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm3mStridedBatched (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + long long int strideA, /* purposely signed */ + const cuComplex *B, + int ldb, + long long int strideB, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, + int ldc, + long long int strideC, + int batchCount) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, long long, const cuComplex *, int, long long, const cuComplex *, cuComplex *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasZgemmStridedBatched (cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + long long int strideA, /* purposely signed */ + const cuDoubleComplex *B, + int ldb, + long long int strideB, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, + int ldc, + long long int strideC, + int batchCount) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, long long, const cuDoubleComplex *, int, long long, const cuDoubleComplex *, cuDoubleComplex *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemmStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return
func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasSgeam(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + const float *alpha, /* host or device pointer */ + const float *A, + int lda, + const float *beta , /* host or device pointer */ + const float *B, + int ldb, + float *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDgeam(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + const double *alpha, /* host or device pointer */ + const double *A, + int lda, + const double *beta, /* host or device pointer */ + const double *B, + int ldb, + double *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCgeam(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, + int lda, + const cuComplex *beta, /* host or device pointer */ + const cuComplex *B, + int ldb, + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZgeam(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *beta, /* host or device pointer */ + const cuDoubleComplex *B, + int ldb, + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSgetrfBatched( + cublasHandle_t handle, int n, float *const A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, float *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol("cublasSgetrfBatched"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasDgetrfBatched( + cublasHandle_t handle, int n, double *const A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol("cublasDgetrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasCgetrfBatched( + cublasHandle_t handle, int n, cuComplex *const A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol("cublasCgetrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasZgetrfBatched( + cublasHandle_t handle, int n, cuDoubleComplex *const A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol("cublasZgetrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasSgetriBatched( + cublasHandle_t handle, int n, const float *const A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + float *const C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *const[], int, const int *, + float *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasSgetriBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasDgetriBatched( + cublasHandle_t handle, int n, const double *const A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + double *const C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *const[], int, const int *, + double *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasDgetriBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasCgetriBatched( + cublasHandle_t handle, int n, const cuComplex *const A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + cuComplex *const C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *const[], int, const int *, + cuComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasCgetriBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasZgetriBatched(cublasHandle_t handle, int n, + const cuDoubleComplex *const A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + cuDoubleComplex 
*const C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *const[], int, const int *, + cuDoubleComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasZgetriBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasSgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const float *const Aarray[], int lda, const int *devIpiv, + float *const Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const float *const[], int, + const int *, float *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasSgetrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasDgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const double *const Aarray[], int lda, const int *devIpiv, + double *const Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const double *const[], int, + const int *, double *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasDgetrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasCgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuComplex *const Aarray[], int lda, const int *devIpiv, + cuComplex *const Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuComplex *const[], + int, const int *, cuComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasCgetrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasZgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuDoubleComplex *const Aarray[], int lda, const int *devIpiv, + cuDoubleComplex *const Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, + const cuDoubleComplex *const[], int, const int *, + cuDoubleComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasZgetrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasStrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /*Host or Device Pointer*/ + const float *const A[], int lda, float *const B[], int ldb, + int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *const[], int, + float *const[], int, int); + static auto func_ptr = 
LoadSymbol("cublasStrsmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasDtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /*Host or Device Pointer*/ + const double *const A[], int lda, double *const B[], int ldb, + int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *const[], int, + double *const[], int, int); + static auto func_ptr = LoadSymbol("cublasDtrsmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /*Host or Device Pointer*/ + const cuComplex *const A[], int lda, cuComplex *const B[], int ldb, + int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *const[], + int, cuComplex *const[], int, int); + static auto func_ptr = LoadSymbol("cublasCtrsmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasZtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /*Host or Device Pointer*/ + const cuDoubleComplex *const A[], int lda, cuDoubleComplex *const B[], + int ldb, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *const[], int, cuDoubleComplex *const[], int, int); + static auto func_ptr = LoadSymbol("cublasZtrsmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasSmatinvBatched( + cublasHandle_t handle, int n, const float *const A[], /*Device pointer*/ + int lda, float *const Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const float *const[], + int, float *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasSmatinvBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasDmatinvBatched( + cublasHandle_t handle, int n, const double *const A[], /*Device pointer*/ + int lda, double *const Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const double *const[], + int, double *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasDmatinvBatched"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasCmatinvBatched( + cublasHandle_t handle, int n, const cuComplex *const A[], /*Device pointer*/ + int lda, cuComplex *const Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *const[], int, cuComplex *const[], + int, int *, int); + static auto func_ptr = LoadSymbol("cublasCmatinvBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasZmatinvBatched(cublasHandle_t handle, int n, + const cuDoubleComplex *const A[], /*Device pointer*/ + int lda, cuDoubleComplex *const Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *const[], int, + cuDoubleComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol("cublasZmatinvBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasSgeqrfBatched(cublasHandle_t handle, int m, int n, + float *const Aarray[], /*Device pointer*/ + int lda, float *const TauArray[], /*Device pointer*/ + int *info, int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, int, float *const[], + int, float *const[], int *, int); + static auto func_ptr = LoadSymbol("cublasSgeqrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasDgeqrfBatched(cublasHandle_t handle, int m, int n, + double *const Aarray[], /*Device pointer*/ + int lda, double *const TauArray[], /*Device pointer*/ + int *info, int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, int, double *const[], + int, double *const[], int *, int); + static auto func_ptr = LoadSymbol("cublasDgeqrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasCgeqrfBatched(cublasHandle_t handle, int m, int n, + cuComplex *const Aarray[], /*Device pointer*/ + int lda, cuComplex *const TauArray[], /*Device pointer*/ + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, cuComplex *const[], int, cuComplex *const[], + int *, int); + static auto func_ptr = LoadSymbol("cublasCgeqrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasZgeqrfBatched( + cublasHandle_t handle, int m, int n, + cuDoubleComplex *const Aarray[], /*Device pointer*/ + int lda, cuDoubleComplex *const TauArray[], /*Device pointer*/ + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, cuDoubleComplex *const[], int, + cuDoubleComplex *const[], int *, int); + static auto func_ptr = LoadSymbol("cublasZgeqrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasSgelsBatched(cublasHandle_t handle, 
cublasOperation_t trans, int m, int n, + int nrhs, float *const Aarray[], /*Device pointer*/ + int lda, float *const Carray[], /*Device pointer*/ + int ldc, int *info, int *devInfoArray, /*Device pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, float *const[], int, + float *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol("cublasSgelsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasDgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, double *const Aarray[], /*Device pointer*/ + int lda, double *const Carray[], /*Device pointer*/ + int ldc, int *info, int *devInfoArray, /*Device pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, double *const[], int, + double *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol("cublasDgelsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasCgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, cuComplex *const Aarray[], /*Device pointer*/ + int lda, cuComplex *const Carray[], /*Device pointer*/ + int ldc, int *info, int *devInfoArray, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, cuComplex *const[], int, + cuComplex *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol("cublasCgelsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasZgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, cuDoubleComplex *const Aarray[], /*Device pointer*/ + int lda, cuDoubleComplex *const Carray[], /*Device pointer*/ + int ldc, int *info, int *devInfoArray, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, + cuDoubleComplex *const[], int, cuDoubleComplex *const[], int, int *, + int *, int); + static auto func_ptr = LoadSymbol("cublasZgelsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasSdgmm(cublasHandle_t handle, + cublasSideMode_t mode, + int m, + int n, + const float *A, + int lda, + const float *x, + int incx, + float *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const float *, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSdgmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDdgmm(cublasHandle_t handle, + cublasSideMode_t mode, + int m, + int n, + const double *A, + int lda, + const double *x, + int incx, + double *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const double *, int, const double *, int, double *, int); + static auto func_ptr = 
LoadSymbol("cublasDdgmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCdgmm(cublasHandle_t handle, + cublasSideMode_t mode, + int m, + int n, + const cuComplex *A, + int lda, + const cuComplex *x, + int incx, + cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCdgmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZdgmm(cublasHandle_t handle, + cublasSideMode_t mode, + int m, + int n, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *x, + int incx, + cuDoubleComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZdgmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasStpttr ( cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *AP, + float *A, + int lda ) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasStpttr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, AP, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasDtpttr ( cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *AP, + double *A, + int lda ) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDtpttr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, AP, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasCtpttr ( cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex *AP, + cuComplex *A, + int lda ) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtpttr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, AP, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasZtpttr ( cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *AP, + cuDoubleComplex *A, + int lda ) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtpttr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, AP, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasStrttp ( cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float *A, + int lda, + float *AP ) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, int, float *); + static auto func_ptr = LoadSymbol("cublasStrttp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, AP); +} + +cublasStatus_t CUBLASWINAPI cublasDtrttp ( cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double *A, + int lda, + 
double *AP ) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, int, double *); + static auto func_ptr = LoadSymbol("cublasDtrttp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, AP); +} + +cublasStatus_t CUBLASWINAPI cublasCtrttp ( cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex *A, + int lda, + cuComplex *AP ) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol("cublasCtrttp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, AP); +} + +cublasStatus_t CUBLASWINAPI cublasZtrttp ( cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + cuDoubleComplex *AP ) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZtrttp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, AP); +} + +cublasStatus CUBLASWINAPI cublasInit (void) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol("cublasInit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +cublasStatus CUBLASWINAPI cublasShutdown (void) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol("cublasShutdown"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +cublasStatus CUBLASWINAPI cublasGetError (void) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol("cublasGetError"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +cublasStatus CUBLASWINAPI cublasGetVersion(int *version) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int *); + static auto func_ptr = LoadSymbol("cublasGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(version); +} + +cublasStatus CUBLASWINAPI cublasAlloc (int n, int elemSize, void **devicePtr) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, void **); + static auto func_ptr = LoadSymbol("cublasAlloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, devicePtr); +} + +cublasStatus CUBLASWINAPI cublasFree (void *devicePtr) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(void *); + static auto func_ptr = LoadSymbol("cublasFree"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devicePtr); +} + +cublasStatus CUBLASWINAPI cublasSetKernelStream (cudaStream_t stream) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol("cublasSetKernelStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +float CUBLASWINAPI cublasSnrm2 (int n, const float *x, int incx) { + using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int); + static auto func_ptr = LoadSymbol("cublasSnrm2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSnrm2"); + return func_ptr(n, x, incx); +} + +double CUBLASWINAPI cublasDnrm2 (int n, const double *x, int incx) { + using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int); + static auto func_ptr = LoadSymbol("cublasDnrm2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDnrm2"); + return func_ptr(n, x, incx); +} + +float 
CUBLASWINAPI cublasScnrm2 (int n, const cuComplex *x, int incx) { + using FuncPtr = float (CUBLASWINAPI *)(int, const cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasScnrm2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasScnrm2"); + return func_ptr(n, x, incx); +} + +double CUBLASWINAPI cublasDznrm2 (int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = double (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasDznrm2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDznrm2"); + return func_ptr(n, x, incx); +} + +float CUBLASWINAPI cublasSdot (int n, const float *x, int incx, const float *y, + int incy) { + using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int, const float *, int); + static auto func_ptr = LoadSymbol("cublasSdot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSdot"); + return func_ptr(n, x, incx, y, incy); +} + +double CUBLASWINAPI cublasDdot (int n, const double *x, int incx, const double *y, + int incy) { + using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int, const double *, int); + static auto func_ptr = LoadSymbol("cublasDdot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDdot"); + return func_ptr(n, x, incx, y, incy); +} + +cuComplex CUBLASWINAPI cublasCdotu (int n, const cuComplex *x, int incx, const cuComplex *y, + int incy) { + using FuncPtr = cuComplex (CUBLASWINAPI *)(int, const cuComplex *, int, const cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCdotu"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCdotu"); + return func_ptr(n, x, incx, y, incy); +} + +cuComplex CUBLASWINAPI cublasCdotc (int n, const cuComplex *x, int incx, const cuComplex *y, + int incy) { + using FuncPtr = cuComplex (CUBLASWINAPI *)(int, const cuComplex *, int, const cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCdotc"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCdotc"); + return func_ptr(n, x, incx, y, incy); +} + +cuDoubleComplex CUBLASWINAPI cublasZdotu (int n, const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy) { + using FuncPtr = cuDoubleComplex (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZdotu"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZdotu"); + return func_ptr(n, x, incx, y, incy); +} + +cuDoubleComplex CUBLASWINAPI cublasZdotc (int n, const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy) { + using FuncPtr = cuDoubleComplex (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZdotc"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZdotc"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasSscal (int n, float alpha, float *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasDscal (int n, double alpha, double *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasCscal (int n, cuComplex alpha, cuComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex, cuComplex *, int); + static auto func_ptr = 
LoadSymbol("cublasCscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasZscal (int n, cuDoubleComplex alpha, cuDoubleComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasCsscal (int n, float alpha, cuComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(int, float, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasZdscal (int n, double alpha, cuDoubleComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(int, double, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZdscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZdscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasSaxpy (int n, float alpha, const float *x, int incx, + float *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, float, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSaxpy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSaxpy"); + return func_ptr(n, alpha, x, incx, y, incy); +} + +void CUBLASWINAPI cublasDaxpy (int n, double alpha, const double *x, + int incx, double *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, double, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDaxpy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDaxpy"); + return func_ptr(n, alpha, x, incx, y, incy); +} + +void CUBLASWINAPI cublasCaxpy (int n, cuComplex alpha, const cuComplex *x, + int incx, cuComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCaxpy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCaxpy"); + return func_ptr(n, alpha, x, incx, y, incy); +} + +void CUBLASWINAPI cublasZaxpy (int n, cuDoubleComplex alpha, const cuDoubleComplex *x, + int incx, cuDoubleComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZaxpy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZaxpy"); + return func_ptr(n, alpha, x, incx, y, incy); +} + +void CUBLASWINAPI cublasScopy (int n, const float *x, int incx, float *y, + int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasScopy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasScopy"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasDcopy (int n, const double *x, int incx, double *y, + int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDcopy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDcopy"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasCcopy (int n, const cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCcopy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCcopy"); + return func_ptr(n, x, incx, y, incy); +} + 
+void CUBLASWINAPI cublasZcopy (int n, const cuDoubleComplex *x, int incx, cuDoubleComplex *y, + int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZcopy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZcopy"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasSswap (int n, float *x, int incx, float *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSswap"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSswap"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasDswap (int n, double *x, int incx, double *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDswap"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDswap"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasCswap (int n, cuComplex *x, int incx, cuComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCswap"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCswap"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasZswap (int n, cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZswap"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZswap"); + return func_ptr(n, x, incx, y, incy); +} + +int CUBLASWINAPI cublasIsamax (int n, const float *x, int incx) { + using FuncPtr = int (CUBLASWINAPI *)(int, const float *, int); + static auto func_ptr = LoadSymbol("cublasIsamax"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIsamax"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIdamax (int n, const double *x, int incx) { + using FuncPtr = int (CUBLASWINAPI *)(int, const double *, int); + static auto func_ptr = LoadSymbol("cublasIdamax"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIdamax"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIcamax (int n, const cuComplex *x, int incx) { + using FuncPtr = int (CUBLASWINAPI *)(int, const cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasIcamax"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIcamax"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIzamax (int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = int (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasIzamax"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIzamax"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIsamin (int n, const float *x, int incx) { + using FuncPtr = int (CUBLASWINAPI *)(int, const float *, int); + static auto func_ptr = LoadSymbol("cublasIsamin"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIsamin"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIdamin (int n, const double *x, int incx) { + using FuncPtr = int (CUBLASWINAPI *)(int, const double *, int); + static auto func_ptr = LoadSymbol("cublasIdamin"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIdamin"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIcamin (int n, const cuComplex *x, int incx) { + using FuncPtr = int (CUBLASWINAPI *)(int, const cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasIcamin"); + if 
(!func_ptr) LogFatalSymbolNotFound("cublasIcamin"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIzamin (int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = int (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasIzamin"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIzamin"); + return func_ptr(n, x, incx); +} + +float CUBLASWINAPI cublasSasum (int n, const float *x, int incx) { + using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int); + static auto func_ptr = LoadSymbol("cublasSasum"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSasum"); + return func_ptr(n, x, incx); +} + +double CUBLASWINAPI cublasDasum (int n, const double *x, int incx) { + using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int); + static auto func_ptr = LoadSymbol("cublasDasum"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDasum"); + return func_ptr(n, x, incx); +} + +float CUBLASWINAPI cublasScasum (int n, const cuComplex *x, int incx) { + using FuncPtr = float (CUBLASWINAPI *)(int, const cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasScasum"); + if (!func_ptr) LogFatalSymbolNotFound("cublasScasum"); + return func_ptr(n, x, incx); +} + +double CUBLASWINAPI cublasDzasum (int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = double (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasDzasum"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDzasum"); + return func_ptr(n, x, incx); +} + +void CUBLASWINAPI cublasSrot (int n, float *x, int incx, float *y, int incy, + float sc, float ss) { + using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int, float, float); + static auto func_ptr = LoadSymbol("cublasSrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSrot"); + return func_ptr(n, x, incx, y, incy, sc, ss); +} + +void CUBLASWINAPI cublasDrot (int n, double *x, int incx, double *y, int incy, + double sc, double ss) { + using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int, double, double); + static auto func_ptr = LoadSymbol("cublasDrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDrot"); + return func_ptr(n, x, incx, y, incy, sc, ss); +} + +void CUBLASWINAPI cublasCrot (int n, cuComplex *x, int incx, cuComplex *y, + int incy, float c, cuComplex s) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, float, cuComplex); + static auto func_ptr = LoadSymbol("cublasCrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCrot"); + return func_ptr(n, x, incx, y, incy, c, s); +} + +void CUBLASWINAPI cublasZrot (int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, double sc, + cuDoubleComplex cs) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, double, cuDoubleComplex); + static auto func_ptr = LoadSymbol("cublasZrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZrot"); + return func_ptr(n, x, incx, y, incy, sc, cs); +} + +void CUBLASWINAPI cublasCsrot (int n, cuComplex *x, int incx, cuComplex *y, + int incy, float c, float s) { + using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, float, float); + static auto func_ptr = LoadSymbol("cublasCsrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsrot"); + return func_ptr(n, x, incx, y, incy, c, s); +} + +void CUBLASWINAPI cublasZdrot (int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, double c, double s) { + using FuncPtr = void (CUBLASWINAPI 
*)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, double, double); + static auto func_ptr = LoadSymbol("cublasZdrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZdrot"); + return func_ptr(n, x, incx, y, incy, c, s); +} + +void CUBLASWINAPI cublasSrotg (float *sa, float *sb, float *sc, float *ss) { + using FuncPtr = void (CUBLASWINAPI *)(float *, float *, float *, float *); + static auto func_ptr = LoadSymbol("cublasSrotg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSrotg"); + return func_ptr(sa, sb, sc, ss); +} + +void CUBLASWINAPI cublasDrotg (double *sa, double *sb, double *sc, double *ss) { + using FuncPtr = void (CUBLASWINAPI *)(double *, double *, double *, double *); + static auto func_ptr = LoadSymbol("cublasDrotg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDrotg"); + return func_ptr(sa, sb, sc, ss); +} + +void CUBLASWINAPI cublasCrotg (cuComplex *ca, cuComplex cb, float *sc, + cuComplex *cs) { + using FuncPtr = void (CUBLASWINAPI *)(cuComplex *, cuComplex, float *, cuComplex *); + static auto func_ptr = LoadSymbol("cublasCrotg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCrotg"); + return func_ptr(ca, cb, sc, cs); +} + +void CUBLASWINAPI cublasZrotg (cuDoubleComplex *ca, cuDoubleComplex cb, double *sc, + cuDoubleComplex *cs) { + using FuncPtr = void (CUBLASWINAPI *)(cuDoubleComplex *, cuDoubleComplex, double *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZrotg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZrotg"); + return func_ptr(ca, cb, sc, cs); +} + +void CUBLASWINAPI cublasSrotm(int n, float *x, int incx, float *y, int incy, + const float* sparam) { + using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int, const float *); + static auto func_ptr = LoadSymbol("cublasSrotm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSrotm"); + return func_ptr(n, x, incx, y, incy, sparam); +} + +void CUBLASWINAPI cublasDrotm(int n, double *x, int incx, double *y, int incy, + const double* sparam) { + using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int, const double *); + static auto func_ptr = LoadSymbol("cublasDrotm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDrotm"); + return func_ptr(n, x, incx, y, incy, sparam); +} + +void CUBLASWINAPI cublasSrotmg (float *sd1, float *sd2, float *sx1, + const float *sy1, float* sparam) { + using FuncPtr = void (CUBLASWINAPI *)(float *, float *, float *, const float *, float *); + static auto func_ptr = LoadSymbol("cublasSrotmg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSrotmg"); + return func_ptr(sd1, sd2, sx1, sy1, sparam); +} + +void CUBLASWINAPI cublasDrotmg (double *sd1, double *sd2, double *sx1, + const double *sy1, double* sparam) { + using FuncPtr = void (CUBLASWINAPI *)(double *, double *, double *, const double *, double *); + static auto func_ptr = LoadSymbol("cublasDrotmg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDrotmg"); + return func_ptr(sd1, sd2, sx1, sy1, sparam); +} + +void CUBLASWINAPI cublasSgemv (char trans, int m, int n, float alpha, + const float *A, int lda, const float *x, int incx, + float beta, float *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, float, const float *, int, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSgemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSgemv"); + return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDgemv (char trans, int m, int n, double alpha, + const double *A, int lda, const 
double *x, int incx, + double beta, double *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, double, const double *, int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDgemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDgemv"); + return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasCgemv (char trans, int m, int n, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, int incx, + cuComplex beta, cuComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgemv"); + return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZgemv (char trans, int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgemv"); + return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasSgbmv (char trans, int m, int n, int kl, int ku, + float alpha, const float *A, int lda, + const float *x, int incx, float beta, float *y, + int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, float, const float *, int, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSgbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSgbmv"); + return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDgbmv (char trans, int m, int n, int kl, int ku, + double alpha, const double *A, int lda, + const double *x, int incx, double beta, double *y, + int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, double, const double *, int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDgbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDgbmv"); + return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasCgbmv (char trans, int m, int n, int kl, int ku, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *x, int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgbmv"); + return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZgbmv (char trans, int m, int n, int kl, int ku, + cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgbmv"); + return 
func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasStrmv (char uplo, char trans, char diag, int n, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStrmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStrmv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasDtrmv (char uplo, char trans, char diag, int n, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtrmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtrmv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasCtrmv (char uplo, char trans, char diag, int n, + const cuComplex *A, int lda, cuComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtrmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtrmv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasZtrmv (char uplo, char trans, char diag, int n, + const cuDoubleComplex *A, int lda, cuDoubleComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtrmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtrmv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasStbmv (char uplo, char trans, char diag, int n, int k, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStbmv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasDtbmv (char uplo, char trans, char diag, int n, int k, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtbmv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasCtbmv (char uplo, char trans, char diag, int n, int k, + const cuComplex *A, int lda, cuComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtbmv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasZtbmv (char uplo, char trans, char diag, int n, int k, + const cuDoubleComplex *A, int lda, cuDoubleComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtbmv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, const float *AP, float *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, 
int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasStpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStpmv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, const double *AP, double *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cublasDtpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtpmv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtpmv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasZtpmv(char uplo, char trans, char diag, int n, const cuDoubleComplex *AP, cuDoubleComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtpmv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, const float *A, int lda, float *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStrsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStrsv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, const double *A, int lda, double *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtrsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtrsv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasCtrsv(char uplo, char trans, char diag, int n, const cuComplex *A, int lda, cuComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtrsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtrsv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasZtrsv(char uplo, char trans, char diag, int n, const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtrsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtrsv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, const float *AP, + float *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cublasStpsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStpsv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, const double *AP, double *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, 
double *, int); + static auto func_ptr = LoadSymbol("cublasDtpsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtpsv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtpsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtpsv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasZtpsv(char uplo, char trans, char diag, int n, const cuDoubleComplex *AP, + cuDoubleComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtpsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtpsv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasStbsv(char uplo, char trans, + char diag, int n, int k, const float *A, + int lda, float *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStbsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStbsv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasDtbsv(char uplo, char trans, + char diag, int n, int k, const double *A, + int lda, double *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtbsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtbsv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasCtbsv(char uplo, char trans, + char diag, int n, int k, const cuComplex *A, + int lda, cuComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtbsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtbsv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasZtbsv(char uplo, char trans, + char diag, int n, int k, const cuDoubleComplex *A, + int lda, cuDoubleComplex *x, int incx) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtbsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtbsv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasSsymv (char uplo, int n, float alpha, const float *A, + int lda, const float *x, int incx, float beta, + float *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSsymv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsymv"); + return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDsymv (char uplo, int n, double alpha, const double *A, + int lda, const double *x, int incx, double beta, + double *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDsymv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsymv"); + return func_ptr(uplo, n, alpha, A, lda, 
x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasChemv (char uplo, int n, cuComplex alpha, const cuComplex *A, + int lda, const cuComplex *x, int incx, cuComplex beta, + cuComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasChemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChemv"); + return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZhemv (char uplo, int n, cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *x, int incx, cuDoubleComplex beta, + cuDoubleComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZhemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhemv"); + return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasSsbmv (char uplo, int n, int k, float alpha, + const float *A, int lda, const float *x, int incx, + float beta, float *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, float, const float *, int, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSsbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsbmv"); + return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDsbmv (char uplo, int n, int k, double alpha, + const double *A, int lda, const double *x, int incx, + double beta, double *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, double, const double *, int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDsbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsbmv"); + return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasChbmv (char uplo, int n, int k, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, int incx, + cuComplex beta, cuComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasChbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChbmv"); + return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZhbmv (char uplo, int n, int k, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZhbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhbmv"); + return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasSspmv(char uplo, int n, float alpha, + const float *AP, const float *x, + int incx, float beta, float *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSspmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSspmv"); + return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +void CUBLASWINAPI 
cublasDspmv(char uplo, int n, double alpha, + const double *AP, const double *x, + int incx, double beta, double *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDspmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDspmv"); + return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasChpmv(char uplo, int n, cuComplex alpha, + const cuComplex *AP, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasChpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChpmv"); + return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZhpmv(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *AP, const cuDoubleComplex *x, + int incx, cuDoubleComplex beta, cuDoubleComplex *y, int incy) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZhpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhpmv"); + return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasSger (int m, int n, float alpha, const float *x, int incx, + const float *y, int incy, float *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(int, int, float, const float *, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSger"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSger"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasDger (int m, int n, double alpha, const double *x, int incx, + const double *y, int incy, double *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(int, int, double, const double *, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDger"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDger"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasCgeru (int m, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgeru"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgeru"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasCgerc (int m, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgerc"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgerc"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasZgeru (int m, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgeru"); + 
if (!func_ptr) LogFatalSymbolNotFound("cublasZgeru"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasZgerc (int m, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgerc"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgerc"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasSsyr (char uplo, int n, float alpha, const float *x, + int incx, float *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr"); + return func_ptr(uplo, n, alpha, x, incx, A, lda); +} + +void CUBLASWINAPI cublasDsyr (char uplo, int n, double alpha, const double *x, + int incx, double *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr"); + return func_ptr(uplo, n, alpha, x, incx, A, lda); +} + +void CUBLASWINAPI cublasCher (char uplo, int n, float alpha, + const cuComplex *x, int incx, cuComplex *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCher"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCher"); + return func_ptr(uplo, n, alpha, x, incx, A, lda); +} + +void CUBLASWINAPI cublasZher (char uplo, int n, double alpha, + const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZher"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZher"); + return func_ptr(uplo, n, alpha, x, incx, A, lda); +} + +void CUBLASWINAPI cublasSspr (char uplo, int n, float alpha, const float *x, + int incx, float *AP) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, float *); + static auto func_ptr = LoadSymbol("cublasSspr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSspr"); + return func_ptr(uplo, n, alpha, x, incx, AP); +} + +void CUBLASWINAPI cublasDspr (char uplo, int n, double alpha, const double *x, + int incx, double *AP) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, double *); + static auto func_ptr = LoadSymbol("cublasDspr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDspr"); + return func_ptr(uplo, n, alpha, x, incx, AP); +} + +void CUBLASWINAPI cublasChpr (char uplo, int n, float alpha, const cuComplex *x, + int incx, cuComplex *AP) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol("cublasChpr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChpr"); + return func_ptr(uplo, n, alpha, x, incx, AP); +} + +void CUBLASWINAPI cublasZhpr (char uplo, int n, double alpha, const cuDoubleComplex *x, + int incx, cuDoubleComplex *AP) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cublasZhpr"); + if (!func_ptr) 
LogFatalSymbolNotFound("cublasZhpr"); + return func_ptr(uplo, n, alpha, x, incx, AP); +} + +void CUBLASWINAPI cublasSsyr2 (char uplo, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *A, + int lda) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasDsyr2 (char uplo, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *A, + int lda) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasCher2 (char uplo, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, cuComplex *A, + int lda) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCher2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCher2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasZher2 (char uplo, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, int incy, cuDoubleComplex *A, + int lda) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZher2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZher2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasSspr2 (char uplo, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *AP) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float *); + static auto func_ptr = LoadSymbol("cublasSspr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSspr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); +} + +void CUBLASWINAPI cublasDspr2 (char uplo, int n, double alpha, + const double *x, int incx, const double *y, + int incy, double *AP) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double *); + static auto func_ptr = LoadSymbol("cublasDspr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDspr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); +} + +void CUBLASWINAPI cublasChpr2 (char uplo, int n, cuComplex alpha, + const cuComplex *x, int incx, const cuComplex *y, + int incy, cuComplex *AP) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol("cublasChpr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChpr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); +} + +void CUBLASWINAPI cublasZhpr2 (char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *AP) { + using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = 
LoadSymbol("cublasZhpr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhpr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); +} + +void CUBLASWINAPI cublasSgemm (char transa, char transb, int m, int n, int k, + float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, + int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, float, const float *, int, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSgemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSgemm"); + return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasDgemm (char transa, char transb, int m, int n, int k, + double alpha, const double *A, int lda, + const double *B, int ldb, double beta, double *C, + int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, double, const double *, int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDgemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDgemm"); + return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasCgemm (char transa, char transb, int m, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCgemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgemm"); + return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZgemm (char transa, char transb, int m, int n, + int k, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZgemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgemm"); + return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasSsyrk (char uplo, char trans, int n, int k, float alpha, + const float *A, int lda, float beta, float *C, + int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyrk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsyrk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasDsyrk (char uplo, char trans, int n, int k, + double alpha, const double *A, int lda, + double beta, double *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyrk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsyrk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasCsyrk (char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + cuComplex beta, cuComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = 
LoadSymbol("cublasCsyrk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsyrk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasZsyrk (char uplo, char trans, int n, int k, + cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + cuDoubleComplex beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsyrk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZsyrk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasCherk (char uplo, char trans, int n, int k, + float alpha, const cuComplex *A, int lda, + float beta, cuComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const cuComplex *, int, float, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCherk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCherk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasZherk (char uplo, char trans, int n, int k, + double alpha, + const cuDoubleComplex *A, int lda, + double beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const cuDoubleComplex *, int, double, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZherk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZherk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasSsyr2k (char uplo, char trans, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSsyr2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasDsyr2k (char uplo, char trans, int n, int k, + double alpha, const double *A, int lda, + const double *B, int ldb, double beta, + double *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDsyr2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasCsyr2k (char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsyr2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsyr2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZsyr2k (char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, cuDoubleComplex beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto 
func_ptr = LoadSymbol("cublasZsyr2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZsyr2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasCher2k (char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, float beta, + cuComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, float, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCher2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCher2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZher2k (char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, double beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, double, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZher2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZher2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasSsymm (char side, char uplo, int m, int n, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol("cublasSsymm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsymm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasDsymm (char side, char uplo, int m, int n, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol("cublasDsymm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsymm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasCsymm (char side, char uplo, int m, int n, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *B, int ldb, + cuComplex beta, cuComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCsymm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsymm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZsymm (char side, char uplo, int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZsymm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZsymm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasChemm (char side, char uplo, int m, int n, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int 
ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasChemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChemm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZhemm (char side, char uplo, int m, int n, + cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, cuDoubleComplex beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZhemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhemm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasStrsm (char side, char uplo, char transa, char diag, + int m, int n, float alpha, const float *A, int lda, + float *B, int ldb) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, float, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStrsm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStrsm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasDtrsm (char side, char uplo, char transa, + char diag, int m, int n, double alpha, + const double *A, int lda, double *B, + int ldb) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, double, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cublasDtrsm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtrsm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasCtrsm (char side, char uplo, char transa, char diag, + int m, int n, cuComplex alpha, const cuComplex *A, + int lda, cuComplex *B, int ldb) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cublasCtrsm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtrsm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasZtrsm (char side, char uplo, char transa, + char diag, int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cublasZtrsm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtrsm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasStrmm (char side, char uplo, char transa, char diag, + int m, int n, float alpha, const float *A, int lda, + float *B, int ldb) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, float, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cublasStrmm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStrmm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasDtrmm (char side, char uplo, char transa, + char diag, int m, int n, double alpha, + const double *A, int lda, double *B, + int ldb) { + using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, 
                                        double, const double *, int, double *, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmm");
+  if (!func_ptr) LogFatalSymbolNotFound("cublasDtrmm");
+  return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
+}
+
+void CUBLASWINAPI cublasCtrmm (char side, char uplo, char transa, char diag,
+                               int m, int n, cuComplex alpha, const cuComplex *A,
+                               int lda, cuComplex *B, int ldb) {
+  using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, const cuComplex *, int, cuComplex *, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmm");
+  if (!func_ptr) LogFatalSymbolNotFound("cublasCtrmm");
+  return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
+}
+
+void CUBLASWINAPI cublasZtrmm (char side, char uplo, char transa,
+                               char diag, int m, int n, cuDoubleComplex alpha,
+                               const cuDoubleComplex *A, int lda, cuDoubleComplex *B,
+                               int ldb) {
+  using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmm");
+  if (!func_ptr) LogFatalSymbolNotFound("cublasZtrmm");
+  return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
+}
+
+} // extern "C"
diff --git a/tensorflow/stream_executor/cuda/cublas_stub.cc b/tensorflow/stream_executor/cuda/cublas_stub.cc
index b7f8be717f5..5c1b666bcef 100644
--- a/tensorflow/stream_executor/cuda/cublas_stub.cc
+++ b/tensorflow/stream_executor/cuda/cublas_stub.cc
@@ -57,11 +57,16 @@ cublasStatus_t GetSymbolNotFoundError() { return CUBLAS_STATUS_INTERNAL_ERROR; }
 typedef enum {} cublasMath_t;
 #endif
 
-// Parameter constness changed in cuBLAS 9.2
 #if CUDA_VERSION < 9020
 #include "tensorflow/stream_executor/cuda/cublas_9_0.inc"
-#elif CUDA_VERSION < 10010
+#elif CUDA_VERSION == 10000
 #include "tensorflow/stream_executor/cuda/cublas_10_0.inc"
-#else
+#elif CUDA_VERSION == 10010
 #include "tensorflow/stream_executor/cuda/cublas_10_1.inc"
+#elif CUDA_VERSION == 10020
+#include "tensorflow/stream_executor/cuda/cublas_10_2.inc"
+#elif CUDA_VERSION == 11000
+#include "tensorflow/stream_executor/cuda/cublas_11_0.inc"
+#else
+#error "We have no wrapper for this version."
 #endif
diff --git a/tensorflow/stream_executor/cuda/cuda_10_0.inc b/tensorflow/stream_executor/cuda/cuda_10_0.inc
index 26c272d683c..6f26cfb92d1 100644
--- a/tensorflow/stream_executor/cuda/cuda_10_0.inc
+++ b/tensorflow/stream_executor/cuda/cuda_10_0.inc
@@ -1,6 +1,7 @@
 // Auto-generated, do not edit.
 
 extern "C" {
+
 CUresult CUDAAPI cuGetErrorString(CUresult error, const char **pStr) {
   using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **);
   static auto func_ptr = LoadSymbol<FuncPtr>("cuGetErrorString");
@@ -1024,6 +1025,28 @@ CUresult CUDAAPI cuStreamAddCallback(CUstream hStream,
   return func_ptr(hStream, callback, userData, flags);
 }
 
+CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamBeginCapture");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(hStream);
+}
+
+CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUgraph *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamEndCapture");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(hStream, phGraph);
+}
+
+CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream,
+                                     CUstreamCaptureStatus *captureStatus) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUstreamCaptureStatus *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamIsCapturing");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(hStream, captureStatus);
+}
+
 CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr,
                                         size_t length, unsigned int flags) {
   using FuncPtr =
@@ -1385,6 +1408,284 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuParamSetTexRef(CUfunction hfunc,
   return func_ptr(hfunc, texunit, hTexRef);
 }
 
+CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUgraph *, unsigned int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphCreate");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(phGraph, flags);
+}
+
+CUresult CUDAAPI cuGraphAddKernelNode(
+    CUgraphNode *phGraphNode, CUgraph hGraph, CUgraphNode *dependencies,
+    size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS *nodeParams) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, CUgraphNode *,
+                                      size_t, const CUDA_KERNEL_NODE_PARAMS *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddKernelNode");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(phGraphNode, hGraph, dependencies, numDependencies,
+                  nodeParams);
+}
+
+CUresult CUDAAPI cuGraphKernelNodeGetParams(
+    CUgraphNode hNode, CUDA_KERNEL_NODE_PARAMS *nodeParams) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_KERNEL_NODE_PARAMS *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphKernelNodeGetParams");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(hNode, nodeParams);
+}
+
+CUresult CUDAAPI cuGraphKernelNodeSetParams(
+    CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS *nodeParams) {
+  using FuncPtr =
+      CUresult(CUDAAPI *)(CUgraphNode, const CUDA_KERNEL_NODE_PARAMS *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphKernelNodeSetParams");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(hNode, nodeParams);
+}
+
+CUresult CUDAAPI cuGraphAddMemcpyNode(CUgraphNode *phGraphNode, CUgraph hGraph,
+                                      CUgraphNode *dependencies,
+                                      size_t numDependencies,
+                                      const CUDA_MEMCPY3D *copyParams,
+                                      CUcontext ctx) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, CUgraphNode *,
+                                      size_t, const CUDA_MEMCPY3D *, CUcontext);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddMemcpyNode");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(phGraphNode, hGraph, dependencies, numDependencies,
+                  copyParams, ctx);
+}
+
+CUresult CUDAAPI cuGraphMemcpyNodeGetParams(CUgraphNode hNode,
+
CUDA_MEMCPY3D *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol("cuGraphMemcpyNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphMemcpyNodeSetParams(CUgraphNode hNode, + const CUDA_MEMCPY3D *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, const CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol("cuGraphMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddMemsetNode( + CUgraphNode *phGraphNode, CUgraph hGraph, CUgraphNode *dependencies, + size_t numDependencies, const CUDA_MEMSET_NODE_PARAMS *memsetParams, + CUcontext ctx) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, CUgraphNode *, size_t, + const CUDA_MEMSET_NODE_PARAMS *, CUcontext); + static auto func_ptr = LoadSymbol("cuGraphAddMemsetNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + memsetParams, ctx); +} + +CUresult CUDAAPI cuGraphMemsetNodeGetParams( + CUgraphNode hNode, CUDA_MEMSET_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_MEMSET_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphMemsetNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphMemsetNodeSetParams( + CUgraphNode hNode, const CUDA_MEMSET_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_MEMSET_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddHostNode(CUgraphNode *phGraphNode, CUgraph hGraph, + CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, CUgraphNode *, + size_t, const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphAddHostNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + nodeParams); +} + +CUresult CUDAAPI cuGraphHostNodeGetParams(CUgraphNode hNode, + CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphHostNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphHostNodeSetParams( + CUgraphNode hNode, const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, + CUgraph hGraph, + CUgraphNode *dependencies, + size_t numDependencies, + CUgraph childGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, CUgraphNode *, + size_t, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphAddChildGraphNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + childGraph); +} + +CUresult CUDAAPI cuGraphChildGraphNodeGetGraph(CUgraphNode 
hNode, + CUgraph *phGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraph *); + static auto func_ptr = LoadSymbol("cuGraphChildGraphNodeGetGraph"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, phGraph); +} + +CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + CUgraphNode *dependencies, + size_t numDependencies) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphAddEmptyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies); +} + +CUresult CUDAAPI cuGraphClone(CUgraph *phGraphClone, CUgraph originalGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph *, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphClone, originalGraph); +} + +CUresult CUDAAPI cuGraphNodeFindInClone(CUgraphNode *phNode, + CUgraphNode hOriginalNode, + CUgraph hClonedGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraphNode, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphNodeFindInClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phNode, hOriginalNode, hClonedGraph); +} + +CUresult CUDAAPI cuGraphNodeGetType(CUgraphNode hNode, CUgraphNodeType *type) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNodeType *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, type); +} + +CUresult CUDAAPI cuGraphGetNodes(CUgraph hGraph, CUgraphNode *nodes, + size_t *numNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, nodes, numNodes); +} + +CUresult CUDAAPI cuGraphGetRootNodes(CUgraph hGraph, CUgraphNode *rootNodes, + size_t *numRootNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetRootNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, rootNodes, numRootNodes); +} + +CUresult CUDAAPI cuGraphGetEdges(CUgraph hGraph, CUgraphNode *from, + CUgraphNode *to, size_t *numEdges) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetEdges"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numEdges); +} + +CUresult CUDAAPI cuGraphNodeGetDependencies(CUgraphNode hNode, + CUgraphNode *dependencies, + size_t *numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, dependencies, numDependencies); +} + +CUresult CUDAAPI cuGraphNodeGetDependentNodes(CUgraphNode hNode, + CUgraphNode *dependentNodes, + size_t *numDependentNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetDependentNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, dependentNodes, numDependentNodes); +} + +CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, CUgraphNode *from, + CUgraphNode *to, + size_t numDependencies) { + using FuncPtr = + 
CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphAddDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numDependencies); +} + +CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, CUgraphNode *from, + CUgraphNode *to, + size_t numDependencies) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphRemoveDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numDependencies); +} + +CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode); + static auto func_ptr = LoadSymbol("cuGraphDestroyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode); +} + +CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, + CUgraphNode *phErrorNode, char *logBuffer, + size_t bufferSize) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec *, CUgraph, CUgraphNode *, + char *, size_t); + static auto func_ptr = LoadSymbol("cuGraphInstantiate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphExec, hGraph, phErrorNode, logBuffer, bufferSize); +} + +CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraphExec, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUstream); + static auto func_ptr = LoadSymbol("cuGraphLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hStream); +} + +CUresult CUDAAPI cuGraphExecDestroy(CUgraphExec hGraphExec) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec); + static auto func_ptr = LoadSymbol("cuGraphExecDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec); +} + +CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph); + static auto func_ptr = LoadSymbol("cuGraphDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph); +} + CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessor( int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize) { using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction, int, size_t); diff --git a/tensorflow/stream_executor/cuda/cuda_10_1.inc b/tensorflow/stream_executor/cuda/cuda_10_1.inc new file mode 100644 index 00000000000..d35035799a7 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_10_1.inc @@ -0,0 +1,2166 @@ +// Auto-generated, do not edit. 
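The per-version .inc files exist because the driver API changes shape between releases: the CUDA 10.0 wrapper above declares cuStreamBeginCapture(CUstream), while this 10.1 file declares cuStreamBeginCapture(CUstream, CUstreamCaptureMode) and resolves it through the "cuStreamBeginCapture_v2" symbol. A stub .cc then selects the matching wrapper at compile time, as the cublas_stub.cc hunk above does for cuBLAS. A hedged sketch of the same selection for the driver API (the surrounding cuda_stub.cc and the 9.x include name are assumptions, not shown in this patch):

#if CUDA_VERSION < 10000
#include "tensorflow/stream_executor/cuda/cuda_9_0.inc"  // assumed file name
#elif CUDA_VERSION == 10000
#include "tensorflow/stream_executor/cuda/cuda_10_0.inc"
#elif CUDA_VERSION == 10010
#include "tensorflow/stream_executor/cuda/cuda_10_1.inc"
#else
#error "No driver API wrapper generated for this CUDA version."
#endif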
+ +extern "C" { + +CUresult CUDAAPI cuGetErrorString(CUresult error, const char **pStr) { + using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **); + static auto func_ptr = LoadSymbol("cuGetErrorString"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(error, pStr); +} + +CUresult CUDAAPI cuGetErrorName(CUresult error, const char **pStr) { + using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **); + static auto func_ptr = LoadSymbol("cuGetErrorName"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(error, pStr); +} + +CUresult CUDAAPI cuInit(unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int); + static auto func_ptr = LoadSymbol("cuInit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(Flags); +} + +CUresult CUDAAPI cuDriverGetVersion(int *driverVersion) { + using FuncPtr = CUresult(CUDAAPI *)(int *); + static auto func_ptr = LoadSymbol("cuDriverGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(driverVersion); +} + +CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *, int); + static auto func_ptr = LoadSymbol("cuDeviceGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device, ordinal); +} + +CUresult CUDAAPI cuDeviceGetCount(int *count) { + using FuncPtr = CUresult(CUDAAPI *)(int *); + static auto func_ptr = LoadSymbol("cuDeviceGetCount"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count); +} + +CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(char *, int, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetName"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(name, len, dev); +} + +CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUuuid *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetUuid"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(uuid, dev); +} + +CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceTotalMem_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(bytes, dev); +} + +CUresult CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUdevice_attribute, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pi, attrib, dev); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevprop *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetProperties"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(prop, dev); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceComputeCapability(int *major, + int *minor, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceComputeCapability"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(major, minor, dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxRetain"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(pctx, dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxRelease"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice, unsigned int); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxSetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, flags); +} + +CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, + int *active) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice, unsigned int *, int *); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxGetState"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, flags, active); +} + +CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxReset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev); +} + +CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, unsigned int, CUdevice); + static auto func_ptr = LoadSymbol("cuCtxCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, flags, dev); +} + +CUresult CUDAAPI cuCtxDestroy(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxPushCurrent_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *); + static auto func_ptr = LoadSymbol("cuCtxPopCurrent_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx); +} + +CUresult CUDAAPI cuCtxSetCurrent(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxSetCurrent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxGetCurrent(CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *); + static auto func_ptr = LoadSymbol("cuCtxGetCurrent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx); +} + +CUresult CUDAAPI cuCtxGetDevice(CUdevice *device) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *); + static auto func_ptr = LoadSymbol("cuCtxGetDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device); +} + +CUresult CUDAAPI cuCtxGetFlags(unsigned int *flags) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *); + static auto func_ptr = LoadSymbol("cuCtxGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags); +} + +CUresult CUDAAPI cuCtxSynchronize(void) { + using FuncPtr = CUresult(CUDAAPI *)(); + static auto func_ptr = LoadSymbol("cuCtxSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +CUresult CUDAAPI cuCtxSetLimit(CUlimit limit, size_t value) { + using FuncPtr = CUresult(CUDAAPI *)(CUlimit, size_t); + static auto func_ptr = 
LoadSymbol("cuCtxSetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(limit, value); +} + +CUresult CUDAAPI cuCtxGetLimit(size_t *pvalue, CUlimit limit) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUlimit); + static auto func_ptr = LoadSymbol("cuCtxGetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pvalue, limit); +} + +CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunc_cache *); + static auto func_ptr = LoadSymbol("cuCtxGetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pconfig); +} + +CUresult CUDAAPI cuCtxSetCacheConfig(CUfunc_cache config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunc_cache); + static auto func_ptr = LoadSymbol("cuCtxSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +CUresult CUDAAPI cuCtxGetSharedMemConfig(CUsharedconfig *pConfig) { + using FuncPtr = CUresult(CUDAAPI *)(CUsharedconfig *); + static auto func_ptr = LoadSymbol("cuCtxGetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pConfig); +} + +CUresult CUDAAPI cuCtxSetSharedMemConfig(CUsharedconfig config) { + using FuncPtr = CUresult(CUDAAPI *)(CUsharedconfig); + static auto func_ptr = LoadSymbol("cuCtxSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +CUresult CUDAAPI cuCtxGetApiVersion(CUcontext ctx, unsigned int *version) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext, unsigned int *); + static auto func_ptr = LoadSymbol("cuCtxGetApiVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx, version); +} + +CUresult CUDAAPI cuCtxGetStreamPriorityRange(int *leastPriority, + int *greatestPriority) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *); + static auto func_ptr = LoadSymbol("cuCtxGetStreamPriorityRange"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(leastPriority, greatestPriority); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxAttach(CUcontext *pctx, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, unsigned int); + static auto func_ptr = LoadSymbol("cuCtxAttach"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxDetach(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxDetach"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuModuleLoad(CUmodule *module, const char *fname) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const char *); + static auto func_ptr = LoadSymbol("cuModuleLoad"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, fname); +} + +CUresult CUDAAPI cuModuleLoadData(CUmodule *module, const void *image) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *); + static auto func_ptr = LoadSymbol("cuModuleLoadData"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, image); +} + +CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module, const void *image, + unsigned int numOptions, + CUjit_option *options, + void **optionValues) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *, unsigned int, + CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuModuleLoadDataEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, image, 
numOptions, options, optionValues); +} + +CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *); + static auto func_ptr = LoadSymbol("cuModuleLoadFatBinary"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, fatCubin); +} + +CUresult CUDAAPI cuModuleUnload(CUmodule hmod) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule); + static auto func_ptr = LoadSymbol("cuModuleUnload"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hmod); +} + +CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetFunction"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, hmod, name); +} + +CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, size_t *bytes, + CUmodule hmod, const char *name) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetGlobal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytes, hmod, name); +} + +CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetTexRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexRef, hmod, name); +} + +CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfref *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetSurfRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfRef, hmod, name); +} + +CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, + void **optionValues, CUlinkState *stateOut) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUjit_option *, void **, CUlinkState *); + static auto func_ptr = LoadSymbol("cuLinkCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numOptions, options, optionValues, stateOut); +} + +CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, + void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, + void **optionValues) { + using FuncPtr = + CUresult(CUDAAPI *)(CUlinkState, CUjitInputType, void *, size_t, + const char *, unsigned int, CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuLinkAddData_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, type, data, size, name, numOptions, options, + optionValues); +} + +CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, + const char *path, unsigned int numOptions, + CUjit_option *options, void **optionValues) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState, CUjitInputType, const char *, + unsigned int, CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuLinkAddFile_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDAAPI cuLinkComplete(CUlinkState state, void **cubinOut, + size_t *sizeOut) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState, void **, size_t *); + static auto func_ptr = LoadSymbol("cuLinkComplete"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(state, cubinOut, sizeOut); +} + +CUresult CUDAAPI cuLinkDestroy(CUlinkState state) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState); + static auto func_ptr = LoadSymbol("cuLinkDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state); +} + +CUresult CUDAAPI cuMemGetInfo(size_t *free, size_t *total) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, size_t *); + static auto func_ptr = LoadSymbol("cuMemGetInfo_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(free, total); +} + +CUresult CUDAAPI cuMemAlloc(CUdeviceptr *dptr, size_t bytesize) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t); + static auto func_ptr = LoadSymbol("cuMemAlloc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytesize); +} + +CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, size_t *pPitch, + size_t WidthInBytes, size_t Height, + unsigned int ElementSizeBytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, size_t, size_t, + unsigned int); + static auto func_ptr = LoadSymbol("cuMemAllocPitch_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, pPitch, WidthInBytes, Height, ElementSizeBytes); +} + +CUresult CUDAAPI cuMemFree(CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr); + static auto func_ptr = LoadSymbol("cuMemFree_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr); +} + +CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr *pbase, size_t *psize, + CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuMemGetAddressRange_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pbase, psize, dptr); +} + +CUresult CUDAAPI cuMemAllocHost(void **pp, size_t bytesize) { + using FuncPtr = CUresult(CUDAAPI *)(void **, size_t); + static auto func_ptr = LoadSymbol("cuMemAllocHost_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pp, bytesize); +} + +CUresult CUDAAPI cuMemFreeHost(void *p) { + using FuncPtr = CUresult(CUDAAPI *)(void *); + static auto func_ptr = LoadSymbol("cuMemFreeHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +CUresult CUDAAPI cuMemHostAlloc(void **pp, size_t bytesize, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(void **, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuMemHostAlloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pp, bytesize, Flags); +} + +CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, void *, unsigned int); + static auto func_ptr = LoadSymbol("cuMemHostGetDevicePointer_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, p, Flags); +} + +CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *, void *); + static auto func_ptr = LoadSymbol("cuMemHostGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, p); +} + +CUresult CUDAAPI cuMemAllocManaged(CUdeviceptr *dptr, size_t bytesize, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuMemAllocManaged"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytesize, 
flags); +} + +CUresult CUDAAPI cuDeviceGetByPCIBusId(CUdevice *dev, const char *pciBusId) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *, const char *); + static auto func_ptr = LoadSymbol("cuDeviceGetByPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, pciBusId); +} + +CUresult CUDAAPI cuDeviceGetPCIBusId(char *pciBusId, int len, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(char *, int, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pciBusId, len, dev); +} + +CUresult CUDAAPI cuIpcGetEventHandle(CUipcEventHandle *pHandle, CUevent event) { + using FuncPtr = CUresult(CUDAAPI *)(CUipcEventHandle *, CUevent); + static auto func_ptr = LoadSymbol("cuIpcGetEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, event); +} + +CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, + CUipcEventHandle handle) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent *, CUipcEventHandle); + static auto func_ptr = LoadSymbol("cuIpcOpenEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phEvent, handle); +} + +CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUipcMemHandle *, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuIpcGetMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, dptr); +} + +CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, + unsigned int Flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, CUipcMemHandle, unsigned int); + static auto func_ptr = LoadSymbol("cuIpcOpenMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, handle, Flags); +} + +CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr); + static auto func_ptr = LoadSymbol("cuIpcCloseMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr); +} + +CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(void *, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuMemHostRegister_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p, bytesize, Flags); +} + +CUresult CUDAAPI cuMemHostUnregister(void *p) { + using FuncPtr = CUresult(CUDAAPI *)(void *); + static auto func_ptr = LoadSymbol("cuMemHostUnregister"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, ByteCount); +} + +CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUcontext, CUdeviceptr, + CUcontext, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstContext, srcDevice, srcContext, ByteCount); +} + +CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, const void *, size_t); + static auto 
func_ptr = LoadSymbol("cuMemcpyHtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcHost, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyDtoH_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyDtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, size_t dstOffset, + CUdeviceptr srcDevice, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyDtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr dstDevice, CUarray srcArray, + size_t srcOffset, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUarray, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyAtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, const void *, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyHtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcHost, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, size_t srcOffset, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUarray, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyAtoH_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, size_t dstOffset, + CUarray srcArray, size_t srcOffset, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, CUarray, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyAtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *); + static auto func_ptr = LoadSymbol("cuMemcpy2D_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *); + static auto func_ptr = LoadSymbol("cuMemcpy2DUnaligned_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol("cuMemcpy3D_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D_PEER *); + static 
auto func_ptr = LoadSymbol("cuMemcpy3DPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUcontext, CUdeviceptr, + CUcontext, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstContext, srcDevice, srcContext, ByteCount, + hStream); +} + +CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, const void *, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyHtoDAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcHost, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyDtoHAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcDevice, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyDtoDAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcDevice, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray, size_t, const void *, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyHtoAAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcHost, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, + size_t srcOffset, size_t ByteCount, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(void *, CUarray, size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyAtoHAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcArray, srcOffset, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D *pCopy, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpy2DAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D *, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpy3DAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemcpy3DPeerAsync(const 
CUDA_MEMCPY3D_PEER *pCopy, + CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D_PEER *, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpy3DPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemsetD8(CUdeviceptr dstDevice, unsigned char uc, size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned char, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD8_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, uc, N); +} + +CUresult CUDAAPI cuMemsetD16(CUdeviceptr dstDevice, unsigned short us, + size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned short, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD16_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, us, N); +} + +CUresult CUDAAPI cuMemsetD32(CUdeviceptr dstDevice, unsigned int ui, size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned int, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD32_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, ui, N); +} + +CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned char, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D8_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, uc, Width, Height); +} + +CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned short, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D16_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, us, Width, Height); +} + +CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned int, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D32_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, ui, Width, Height); +} + +CUresult CUDAAPI cuMemsetD8Async(CUdeviceptr dstDevice, unsigned char uc, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned char, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD8Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, uc, N, hStream); +} + +CUresult CUDAAPI cuMemsetD16Async(CUdeviceptr dstDevice, unsigned short us, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned short, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD16Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, us, N, hStream); +} + +CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned int, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD32Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, ui, N, hStream); +} + +CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, + size_t Height, CUstream hStream) { 
+ using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned char, + size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D8Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, uc, Width, Height, hStream); +} + +CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned short, + size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D16Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, us, Width, Height, hStream); +} + +CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned int, size_t, + size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D32Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, ui, Width, Height, hStream); +} + +CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, + const CUDA_ARRAY_DESCRIPTOR *pAllocateArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, const CUDA_ARRAY_DESCRIPTOR *); + static auto func_ptr = LoadSymbol("cuArrayCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pAllocateArray); +} + +CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, + CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_ARRAY_DESCRIPTOR *, CUarray); + static auto func_ptr = LoadSymbol("cuArrayGetDescriptor_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArrayDescriptor, hArray); +} + +CUresult CUDAAPI cuArrayDestroy(CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray); + static auto func_ptr = LoadSymbol("cuArrayDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hArray); +} + +CUresult CUDAAPI cuArray3DCreate( + CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR *pAllocateArray) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray *, const CUDA_ARRAY3D_DESCRIPTOR *); + static auto func_ptr = LoadSymbol("cuArray3DCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pAllocateArray); +} + +CUresult CUDAAPI cuArray3DGetDescriptor( + CUDA_ARRAY3D_DESCRIPTOR *pArrayDescriptor, CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_ARRAY3D_DESCRIPTOR *, CUarray); + static auto func_ptr = LoadSymbol("cuArray3DGetDescriptor_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArrayDescriptor, hArray); +} + +CUresult CUDAAPI +cuMipmappedArrayCreate(CUmipmappedArray *pHandle, + const CUDA_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc, + unsigned int numMipmapLevels) { + using FuncPtr = CUresult(CUDAAPI *)( + CUmipmappedArray *, const CUDA_ARRAY3D_DESCRIPTOR *, unsigned int); + static auto func_ptr = LoadSymbol("cuMipmappedArrayCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pMipmappedArrayDesc, numMipmapLevels); +} + +CUresult CUDAAPI cuMipmappedArrayGetLevel(CUarray *pLevelArray, + CUmipmappedArray hMipmappedArray, + unsigned int level) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray *, CUmipmappedArray, unsigned int); + static auto func_ptr = LoadSymbol("cuMipmappedArrayGetLevel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return 
func_ptr(pLevelArray, hMipmappedArray, level); +} + +CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray); + static auto func_ptr = LoadSymbol("cuMipmappedArrayDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hMipmappedArray); +} + +CUresult CUDAAPI cuPointerGetAttribute(void *data, + CUpointer_attribute attribute, + CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUpointer_attribute, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, attribute, ptr); +} + +CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, + CUdevice dstDevice, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, CUdevice, CUstream); + static auto func_ptr = LoadSymbol("cuMemPrefetchAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, dstDevice, hStream); +} + +CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, + CUmem_advise advice, CUdevice device) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, CUmem_advise, CUdevice); + static auto func_ptr = LoadSymbol("cuMemAdvise"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, advice, device); +} + +CUresult CUDAAPI cuMemRangeGetAttribute(void *data, size_t dataSize, + CUmem_range_attribute attribute, + CUdeviceptr devPtr, size_t count) { + using FuncPtr = CUresult(CUDAAPI *)(void *, size_t, CUmem_range_attribute, + CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemRangeGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSize, attribute, devPtr, count); +} + +CUresult CUDAAPI cuMemRangeGetAttributes(void **data, size_t *dataSizes, + CUmem_range_attribute *attributes, + size_t numAttributes, + CUdeviceptr devPtr, size_t count) { + using FuncPtr = CUresult(CUDAAPI *)( + void **, size_t *, CUmem_range_attribute *, size_t, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemRangeGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSizes, attributes, numAttributes, devPtr, count); +} + +CUresult CUDAAPI cuPointerSetAttribute(const void *value, + CUpointer_attribute attribute, + CUdeviceptr ptr) { + using FuncPtr = + CUresult(CUDAAPI *)(const void *, CUpointer_attribute, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attribute, ptr); +} + +CUresult CUDAAPI cuPointerGetAttributes(unsigned int numAttributes, + CUpointer_attribute *attributes, + void **data, CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int, CUpointer_attribute *, + void **, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numAttributes, attributes, data, ptr); +} + +CUresult CUDAAPI cuStreamCreate(CUstream *phStream, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phStream, Flags); +} + +CUresult CUDAAPI cuStreamCreateWithPriority(CUstream *phStream, + unsigned int flags, int priority) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream *, unsigned int, int); + static auto 
func_ptr = LoadSymbol("cuStreamCreateWithPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phStream, flags, priority); +} + +CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, int *); + static auto func_ptr = LoadSymbol("cuStreamGetPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, priority); +} + +CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, unsigned int *); + static auto func_ptr = LoadSymbol("cuStreamGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, flags); +} + +CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUcontext *); + static auto func_ptr = LoadSymbol("cuStreamGetCtx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, pctx); +} + +CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUevent, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitEvent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, hEvent, Flags); +} + +CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, + CUstreamCallback callback, void *userData, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamCallback, void *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamAddCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, callback, userData, flags); +} + +CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream, + CUstreamCaptureMode mode) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUstreamCaptureMode); + static auto func_ptr = LoadSymbol("cuStreamBeginCapture_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, mode); +} + +CUresult CUDAAPI cuThreadExchangeStreamCaptureMode(CUstreamCaptureMode *mode) { + using FuncPtr = CUresult(CUDAAPI *)(CUstreamCaptureMode *); + static auto func_ptr = + LoadSymbol("cuThreadExchangeStreamCaptureMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mode); +} + +CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUgraph *); + static auto func_ptr = LoadSymbol("cuStreamEndCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, phGraph); +} + +CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, + CUstreamCaptureStatus *captureStatus) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUstreamCaptureStatus *); + static auto func_ptr = LoadSymbol("cuStreamIsCapturing"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, captureStatus); +} + +CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, + CUstreamCaptureStatus *captureStatus, + cuuint64_t *id) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamCaptureStatus *, cuuint64_t *); + static auto func_ptr = LoadSymbol("cuStreamGetCaptureInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, captureStatus, id); +} + +CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, + size_t length, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, size_t, unsigned int); + static auto func_ptr = 
LoadSymbol("cuStreamAttachMemAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, dptr, length, flags); +} + +CUresult CUDAAPI cuStreamQuery(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamSynchronize(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamDestroy(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuEventCreate(CUevent *phEvent, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent *, unsigned int); + static auto func_ptr = LoadSymbol("cuEventCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phEvent, Flags); +} + +CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent, CUstream); + static auto func_ptr = LoadSymbol("cuEventRecord"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent, hStream); +} + +CUresult CUDAAPI cuEventQuery(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventSynchronize(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventDestroy(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, + CUevent hEnd) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUevent, CUevent); + static auto func_ptr = LoadSymbol("cuEventElapsedTime"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pMilliseconds, hStart, hEnd); +} + +CUresult CUDAAPI +cuImportExternalMemory(CUexternalMemory *extMem_out, + const CUDA_EXTERNAL_MEMORY_HANDLE_DESC *memHandleDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalMemory *, + const CUDA_EXTERNAL_MEMORY_HANDLE_DESC *); + static auto func_ptr = LoadSymbol("cuImportExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem_out, memHandleDesc); +} + +CUresult CUDAAPI cuExternalMemoryGetMappedBuffer( + CUdeviceptr *devPtr, CUexternalMemory extMem, + const CUDA_EXTERNAL_MEMORY_BUFFER_DESC *bufferDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, CUexternalMemory, + const CUDA_EXTERNAL_MEMORY_BUFFER_DESC *); + static auto func_ptr = LoadSymbol("cuExternalMemoryGetMappedBuffer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, extMem, bufferDesc); +} + +CUresult CUDAAPI cuExternalMemoryGetMappedMipmappedArray( + CUmipmappedArray *mipmap, CUexternalMemory extMem, + const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *mipmapDesc) { + using FuncPtr = + CUresult(CUDAAPI *)(CUmipmappedArray *, 
CUexternalMemory, + const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *); + static auto func_ptr = + LoadSymbol("cuExternalMemoryGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmap, extMem, mipmapDesc); +} + +CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalMemory); + static auto func_ptr = LoadSymbol("cuDestroyExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem); +} + +CUresult CUDAAPI cuImportExternalSemaphore( + CUexternalSemaphore *extSem_out, + const CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC *semHandleDesc) { + using FuncPtr = CUresult(CUDAAPI *)( + CUexternalSemaphore *, const CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC *); + static auto func_ptr = LoadSymbol("cuImportExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem_out, semHandleDesc); +} + +CUresult CUDAAPI cuSignalExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream) { + using FuncPtr = CUresult(CUDAAPI *)( + const CUexternalSemaphore *, + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *, unsigned int, CUstream); + static auto func_ptr = LoadSymbol("cuSignalExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +CUresult CUDAAPI cuWaitExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream) { + using FuncPtr = CUresult(CUDAAPI *)( + const CUexternalSemaphore *, const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *, + unsigned int, CUstream); + static auto func_ptr = LoadSymbol("cuWaitExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalSemaphore); + static auto func_ptr = LoadSymbol("cuDestroyExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem); +} + +CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint32_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitValue32"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint64_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitValue64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWriteValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint32_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWriteValue32"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI 
*)(CUstream, CUdeviceptr, cuuint64_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWriteValue64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, + CUstreamBatchMemOpParams *paramArray, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, unsigned int, + CUstreamBatchMemOpParams *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamBatchMemOp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, count, paramArray, flags); +} + +CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, + CUfunction hfunc) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction_attribute, CUfunction); + static auto func_ptr = LoadSymbol("cuFuncGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pi, attrib, hfunc); +} + +CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, + CUfunction_attribute attrib, int value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUfunction_attribute, int); + static auto func_ptr = LoadSymbol("cuFuncSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, attrib, value); +} + +CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUfunc_cache); + static auto func_ptr = LoadSymbol("cuFuncSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, config); +} + +CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, + CUsharedconfig config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUsharedconfig); + static auto func_ptr = LoadSymbol("cuFuncSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, config); +} + +CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams, void **extra) { + using FuncPtr = CUresult(CUDAAPI *)( + CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **, void **); + static auto func_ptr = LoadSymbol("cuLaunchKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, + blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDAAPI cuLaunchCooperativeKernel( + CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams) { + using FuncPtr = CUresult(CUDAAPI *)( + CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **); + static auto func_ptr = LoadSymbol("cuLaunchCooperativeKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, + blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice( + CUDA_LAUNCH_PARAMS *launchParamsList, unsigned int numDevices, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUDA_LAUNCH_PARAMS *, unsigned int, unsigned int); + static auto 
func_ptr = + LoadSymbol("cuLaunchCooperativeKernelMultiDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(launchParamsList, numDevices, flags); +} + +CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, + void *userData) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUhostFn, void *); + static auto func_ptr = LoadSymbol("cuLaunchHostFunc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, fn, userData); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetBlockShape(CUfunction hfunc, int x, + int y, int z) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int, int); + static auto func_ptr = LoadSymbol("cuFuncSetBlockShape"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, x, y, z); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetSharedSize(CUfunction hfunc, + unsigned int bytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, unsigned int); + static auto func_ptr = LoadSymbol("cuFuncSetSharedSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, bytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetSize(CUfunction hfunc, + unsigned int numbytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, unsigned int); + static auto func_ptr = LoadSymbol("cuParamSetSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, numbytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSeti(CUfunction hfunc, int offset, + unsigned int value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, unsigned int); + static auto func_ptr = LoadSymbol("cuParamSeti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, value); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetf(CUfunction hfunc, int offset, + float value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, float); + static auto func_ptr = LoadSymbol("cuParamSetf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, value); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetv(CUfunction hfunc, int offset, + void *ptr, + unsigned int numbytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, void *, unsigned int); + static auto func_ptr = LoadSymbol("cuParamSetv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, ptr, numbytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunch(CUfunction f) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction); + static auto func_ptr = LoadSymbol("cuLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGrid(CUfunction f, int grid_width, + int grid_height) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int); + static auto func_ptr = LoadSymbol("cuLaunchGrid"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, grid_width, grid_height); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGridAsync(CUfunction f, + int grid_width, + int grid_height, + CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int, CUstream); + static auto func_ptr = LoadSymbol("cuLaunchGridAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, grid_width, grid_height, hStream); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetTexRef(CUfunction hfunc, + int texunit, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, CUtexref); + static auto func_ptr = LoadSymbol("cuParamSetTexRef"); + if 
(!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, texunit, hTexRef); +} + +CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph *, unsigned int); + static auto func_ptr = LoadSymbol("cuGraphCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraph, flags); +} + +CUresult CUDAAPI cuGraphAddKernelNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphAddKernelNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + nodeParams); +} + +CUresult CUDAAPI cuGraphKernelNodeGetParams( + CUgraphNode hNode, CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphKernelNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphKernelNodeSetParams( + CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddMemcpyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_MEMCPY3D *copyParams, + CUcontext ctx) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_MEMCPY3D *, CUcontext); + static auto func_ptr = LoadSymbol("cuGraphAddMemcpyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + copyParams, ctx); +} + +CUresult CUDAAPI cuGraphMemcpyNodeGetParams(CUgraphNode hNode, + CUDA_MEMCPY3D *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol("cuGraphMemcpyNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphMemcpyNodeSetParams(CUgraphNode hNode, + const CUDA_MEMCPY3D *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, const CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol("cuGraphMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddMemsetNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_MEMSET_NODE_PARAMS *memsetParams, + CUcontext ctx) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_MEMSET_NODE_PARAMS *, CUcontext); + static auto func_ptr = LoadSymbol("cuGraphAddMemsetNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + memsetParams, ctx); +} + +CUresult CUDAAPI cuGraphMemsetNodeGetParams( + CUgraphNode hNode, CUDA_MEMSET_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_MEMSET_NODE_PARAMS *); + static auto func_ptr = 
LoadSymbol("cuGraphMemsetNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphMemsetNodeSetParams( + CUgraphNode hNode, const CUDA_MEMSET_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_MEMSET_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddHostNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphAddHostNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + nodeParams); +} + +CUresult CUDAAPI cuGraphHostNodeGetParams(CUgraphNode hNode, + CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphHostNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphHostNodeSetParams( + CUgraphNode hNode, const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, + CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + CUgraph childGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, + const CUgraphNode *, size_t, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphAddChildGraphNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + childGraph); +} + +CUresult CUDAAPI cuGraphChildGraphNodeGetGraph(CUgraphNode hNode, + CUgraph *phGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraph *); + static auto func_ptr = LoadSymbol("cuGraphChildGraphNodeGetGraph"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, phGraph); +} + +CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphAddEmptyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies); +} + +CUresult CUDAAPI cuGraphClone(CUgraph *phGraphClone, CUgraph originalGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph *, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphClone, originalGraph); +} + +CUresult CUDAAPI cuGraphNodeFindInClone(CUgraphNode *phNode, + CUgraphNode hOriginalNode, + CUgraph hClonedGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraphNode, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphNodeFindInClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phNode, hOriginalNode, 
hClonedGraph); +} + +CUresult CUDAAPI cuGraphNodeGetType(CUgraphNode hNode, CUgraphNodeType *type) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNodeType *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, type); +} + +CUresult CUDAAPI cuGraphGetNodes(CUgraph hGraph, CUgraphNode *nodes, + size_t *numNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, nodes, numNodes); +} + +CUresult CUDAAPI cuGraphGetRootNodes(CUgraph hGraph, CUgraphNode *rootNodes, + size_t *numRootNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetRootNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, rootNodes, numRootNodes); +} + +CUresult CUDAAPI cuGraphGetEdges(CUgraph hGraph, CUgraphNode *from, + CUgraphNode *to, size_t *numEdges) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetEdges"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numEdges); +} + +CUresult CUDAAPI cuGraphNodeGetDependencies(CUgraphNode hNode, + CUgraphNode *dependencies, + size_t *numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, dependencies, numDependencies); +} + +CUresult CUDAAPI cuGraphNodeGetDependentNodes(CUgraphNode hNode, + CUgraphNode *dependentNodes, + size_t *numDependentNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetDependentNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, dependentNodes, numDependentNodes); +} + +CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, const CUgraphNode *from, + const CUgraphNode *to, + size_t numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, const CUgraphNode *, + const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphAddDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numDependencies); +} + +CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, + const CUgraphNode *from, + const CUgraphNode *to, + size_t numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, const CUgraphNode *, + const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphRemoveDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numDependencies); +} + +CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode); + static auto func_ptr = LoadSymbol("cuGraphDestroyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode); +} + +CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, + CUgraphNode *phErrorNode, char *logBuffer, + size_t bufferSize) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec *, CUgraph, CUgraphNode *, + char *, size_t); + static auto func_ptr = LoadSymbol("cuGraphInstantiate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return 
func_ptr(phGraphExec, hGraph, phErrorNode, logBuffer, bufferSize); +} + +CUresult CUDAAPI +cuGraphExecKernelNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraphNode, + const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphExecKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraphExec, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUstream); + static auto func_ptr = LoadSymbol("cuGraphLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hStream); +} + +CUresult CUDAAPI cuGraphExecDestroy(CUgraphExec hGraphExec) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec); + static auto func_ptr = LoadSymbol("cuGraphExecDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec); +} + +CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph); + static auto func_ptr = LoadSymbol("cuGraphDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph); +} + +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessor( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction, int, size_t); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxActiveBlocksPerMultiprocessor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize); +} + +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(int *, CUfunction, int, size_t, unsigned int); + static auto func_ptr = LoadSymbol( + "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize, flags); +} + +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSize( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, + int blockSizeLimit) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *, CUfunction, + CUoccupancyB2DSize, size_t, int); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxPotentialBlockSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(minGridSize, blockSize, func, blockSizeToDynamicSMemSize, + dynamicSMemSize, blockSizeLimit); +} + +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, + int blockSizeLimit, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)( + int *, int *, CUfunction, CUoccupancyB2DSize, size_t, int, unsigned int); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxPotentialBlockSizeWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(minGridSize, blockSize, func, blockSizeToDynamicSMemSize, + dynamicSMemSize, blockSizeLimit, flags); +} + +CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUarray, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetArray"); + if (!func_ptr) 
return GetSymbolNotFoundError(); + return func_ptr(hTexRef, hArray, Flags); +} + +CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, + CUmipmappedArray hMipmappedArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUmipmappedArray, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, hMipmappedArray, Flags); +} + +CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexref hTexRef, + CUdeviceptr dptr, size_t bytes) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUtexref, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuTexRefSetAddress_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ByteOffset, hTexRef, dptr, bytes); +} + +CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, + const CUDA_ARRAY_DESCRIPTOR *desc, + CUdeviceptr dptr, size_t Pitch) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, const CUDA_ARRAY_DESCRIPTOR *, + CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuTexRefSetAddress2D_v3"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, desc, dptr, Pitch); +} + +CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_format fmt, + int NumPackedComponents) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUarray_format, int); + static auto func_ptr = LoadSymbol("cuTexRefSetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fmt, NumPackedComponents); +} + +CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int dim, + CUaddress_mode am) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, int, CUaddress_mode); + static auto func_ptr = LoadSymbol("cuTexRefSetAddressMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, dim, am); +} + +CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfilter_mode fm) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUfilter_mode); + static auto func_ptr = LoadSymbol("cuTexRefSetFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fm); +} + +CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, + CUfilter_mode fm) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUfilter_mode); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmapFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fm); +} + +CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, float bias) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmapLevelBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, bias); +} + +CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, + float minMipmapLevelClamp, + float maxMipmapLevelClamp) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float, float); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmapLevelClamp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, minMipmapLevelClamp, maxMipmapLevelClamp); +} + +CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, + unsigned int maxAniso) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetMaxAnisotropy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, maxAniso); +} + +CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor) { + using FuncPtr = 
CUresult(CUDAAPI *)(CUtexref, float *); + static auto func_ptr = LoadSymbol("cuTexRefSetBorderColor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, pBorderColor); +} + +CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, Flags); +} + +CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetAddress_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phArray, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmappedArray(CUmipmappedArray *phMipmappedArray, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phMipmappedArray, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, CUtexref hTexRef, + int dim) { + using FuncPtr = CUresult(CUDAAPI *)(CUaddress_mode *, CUtexref, int); + static auto func_ptr = LoadSymbol("cuTexRefGetAddressMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pam, hTexRef, dim); +} + +CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfilter_mode *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pfm, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, int *pNumChannels, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray_format *, int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFormat, pNumChannels, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfilter_mode *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pfm, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapLevelBias(float *pbias, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapLevelBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pbias, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, + float *pmaxMipmapLevelClamp, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapLevelClamp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pminMipmapLevelClamp, pmaxMipmapLevelClamp, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMaxAnisotropy"); + if 
(!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pmaxAniso, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetBorderColor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pBorderColor, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetFlags(unsigned int *pFlags, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, hTexRef); +} + +CUresult CUDAAPI cuTexRefCreate(CUtexref *pTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref *); + static auto func_ptr = LoadSymbol("cuTexRefCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexRef); +} + +CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef); +} + +CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, CUarray hArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfref, CUarray, unsigned int); + static auto func_ptr = LoadSymbol("cuSurfRefSetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hSurfRef, hArray, Flags); +} + +CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, CUsurfref hSurfRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUsurfref); + static auto func_ptr = LoadSymbol("cuSurfRefGetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phArray, hSurfRef); +} + +CUresult CUDAAPI +cuTexObjectCreate(CUtexObject *pTexObject, const CUDA_RESOURCE_DESC *pResDesc, + const CUDA_TEXTURE_DESC *pTexDesc, + const CUDA_RESOURCE_VIEW_DESC *pResViewDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexObject *, const CUDA_RESOURCE_DESC *, + const CUDA_TEXTURE_DESC *, + const CUDA_RESOURCE_VIEW_DESC *); + static auto func_ptr = LoadSymbol("cuTexObjectCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexObject, pResDesc, pTexDesc, pResViewDesc); +} + +CUresult CUDAAPI cuTexObjectDestroy(CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texObject); +} + +CUresult CUDAAPI cuTexObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, texObject); +} + +CUresult CUDAAPI cuTexObjectGetTextureDesc(CUDA_TEXTURE_DESC *pTexDesc, + CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_TEXTURE_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetTextureDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexDesc, texObject); +} + +CUresult CUDAAPI cuTexObjectGetResourceViewDesc( + CUDA_RESOURCE_VIEW_DESC *pResViewDesc, CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_VIEW_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetResourceViewDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return 
func_ptr(pResViewDesc, texObject); +} + +CUresult CUDAAPI cuSurfObjectCreate(CUsurfObject *pSurfObject, + const CUDA_RESOURCE_DESC *pResDesc) { + using FuncPtr = + CUresult(CUDAAPI *)(CUsurfObject *, const CUDA_RESOURCE_DESC *); + static auto func_ptr = LoadSymbol("cuSurfObjectCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfObject, pResDesc); +} + +CUresult CUDAAPI cuSurfObjectDestroy(CUsurfObject surfObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfObject); + static auto func_ptr = LoadSymbol("cuSurfObjectDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfObject); +} + +CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUsurfObject surfObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_DESC *, CUsurfObject); + static auto func_ptr = LoadSymbol("cuSurfObjectGetResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, surfObject); +} + +CUresult CUDAAPI cuDeviceCanAccessPeer(int *canAccessPeer, CUdevice dev, + CUdevice peerDev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUdevice, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceCanAccessPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(canAccessPeer, dev, peerDev); +} + +CUresult CUDAAPI cuCtxEnablePeerAccess(CUcontext peerContext, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext, unsigned int); + static auto func_ptr = LoadSymbol("cuCtxEnablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerContext, Flags); +} + +CUresult CUDAAPI cuCtxDisablePeerAccess(CUcontext peerContext) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxDisablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerContext); +} + +CUresult CUDAAPI cuDeviceGetP2PAttribute(int *value, + CUdevice_P2PAttribute attrib, + CUdevice srcDevice, + CUdevice dstDevice) { + using FuncPtr = + CUresult(CUDAAPI *)(int *, CUdevice_P2PAttribute, CUdevice, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetP2PAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attrib, srcDevice, dstDevice); +} + +CUresult CUDAAPI cuGraphicsUnregisterResource(CUgraphicsResource resource) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphicsResource); + static auto func_ptr = LoadSymbol("cuGraphicsUnregisterResource"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource); +} + +CUresult CUDAAPI cuGraphicsSubResourceGetMappedArray( + CUarray *pArray, CUgraphicsResource resource, unsigned int arrayIndex, + unsigned int mipLevel) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUgraphicsResource, + unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol("cuGraphicsSubResourceGetMappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArray, resource, arrayIndex, mipLevel); +} + +CUresult CUDAAPI cuGraphicsResourceGetMappedMipmappedArray( + CUmipmappedArray *pMipmappedArray, CUgraphicsResource resource) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray *, CUgraphicsResource); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pMipmappedArray, resource); +} + +CUresult CUDAAPI cuGraphicsResourceGetMappedPointer( + CUdeviceptr *pDevPtr, size_t *pSize, CUgraphicsResource resource) { + using 
FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUgraphicsResource); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceGetMappedPointer_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pDevPtr, pSize, resource); +} + +CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphicsResource, unsigned int); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceSetMapFlags_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource, flags); +} + +CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUgraphicsResource *, CUstream); + static auto func_ptr = LoadSymbol("cuGraphicsMapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, hStream); +} + +CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUgraphicsResource *, CUstream); + static auto func_ptr = LoadSymbol("cuGraphicsUnmapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, hStream); +} + +CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, + const CUuuid *pExportTableId) { + using FuncPtr = CUresult(CUDAAPI *)(const void **, const CUuuid *); + static auto func_ptr = LoadSymbol("cuGetExportTable"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ppExportTable, pExportTableId); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cuda_10_2.inc b/tensorflow/stream_executor/cuda/cuda_10_2.inc new file mode 100644 index 00000000000..f37fc9d888d --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_10_2.inc @@ -0,0 +1,2328 @@ +// Auto-generated, do not edit. 
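For context on the generated wrappers in these `.inc` files: every stub follows the same shape, resolve the driver entry point once through the templated `LoadSymbol<FuncPtr>` helper, cache it in a function-local static, and forward the call (returning a "symbol not found" error when the driver library or symbol is absent). Below is a minimal standalone sketch of that pattern for POSIX systems; the `GetDsoHandle`, `LoadSymbol`, and `GetSymbolNotFoundError` helpers here are illustrative stand-ins, not the actual stream_executor implementation, and the wrapper name is hypothetical.

// Illustrative sketch only, not part of the patch.
#include <dlfcn.h>

// Hypothetical stand-in: open the driver library once and reuse the handle.
static void* GetDsoHandle() {
  static void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
  return handle;
}

// Resolve a symbol from the cached handle, or nullptr if unavailable.
template <typename T>
static T LoadSymbol(const char* symbol_name) {
  void* handle = GetDsoHandle();
  return handle ? reinterpret_cast<T>(dlsym(handle, symbol_name)) : nullptr;
}

// Stand-in error value; the generated code returns a CUresult error instead.
static int GetSymbolNotFoundError() { return -1; }

// A wrapper in the same shape as the generated stubs: resolve once, forward.
extern "C" int cuDriverGetVersionStub(int* driver_version) {
  using FuncPtr = int (*)(int*);
  static auto func_ptr = LoadSymbol<FuncPtr>("cuDriverGetVersion");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(driver_version);
}

Because the lookup result is stored in a function-local static, the dlsym cost is paid once per symbol, and builds that link against a system without the CUDA driver still load and fail gracefully at call time rather than at link time.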
+ +extern "C" { + +CUresult CUDAAPI cuGetErrorString(CUresult error, const char **pStr) { + using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **); + static auto func_ptr = LoadSymbol("cuGetErrorString"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(error, pStr); +} + +CUresult CUDAAPI cuGetErrorName(CUresult error, const char **pStr) { + using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **); + static auto func_ptr = LoadSymbol("cuGetErrorName"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(error, pStr); +} + +CUresult CUDAAPI cuInit(unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int); + static auto func_ptr = LoadSymbol("cuInit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(Flags); +} + +CUresult CUDAAPI cuDriverGetVersion(int *driverVersion) { + using FuncPtr = CUresult(CUDAAPI *)(int *); + static auto func_ptr = LoadSymbol("cuDriverGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(driverVersion); +} + +CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *, int); + static auto func_ptr = LoadSymbol("cuDeviceGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device, ordinal); +} + +CUresult CUDAAPI cuDeviceGetCount(int *count) { + using FuncPtr = CUresult(CUDAAPI *)(int *); + static auto func_ptr = LoadSymbol("cuDeviceGetCount"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count); +} + +CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(char *, int, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetName"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(name, len, dev); +} + +CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUuuid *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetUuid"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(uuid, dev); +} + +CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceTotalMem_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(bytes, dev); +} + +CUresult CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUdevice_attribute, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pi, attrib, dev); +} + +CUresult CUDAAPI cuDeviceGetNvSciSyncAttributes(void *nvSciSyncAttrList, + CUdevice dev, int flags) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUdevice, int); + static auto func_ptr = LoadSymbol("cuDeviceGetNvSciSyncAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(nvSciSyncAttrList, dev, flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevprop *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetProperties"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(prop, dev); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceComputeCapability(int *major, + int *minor, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *, CUdevice); + static auto func_ptr = 
LoadSymbol("cuDeviceComputeCapability"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(major, minor, dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxRetain"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxRelease"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice, unsigned int); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxSetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, flags); +} + +CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, + int *active) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice, unsigned int *, int *); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxGetState"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, flags, active); +} + +CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxReset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev); +} + +CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, unsigned int, CUdevice); + static auto func_ptr = LoadSymbol("cuCtxCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, flags, dev); +} + +CUresult CUDAAPI cuCtxDestroy(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxPushCurrent_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *); + static auto func_ptr = LoadSymbol("cuCtxPopCurrent_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx); +} + +CUresult CUDAAPI cuCtxSetCurrent(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxSetCurrent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxGetCurrent(CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *); + static auto func_ptr = LoadSymbol("cuCtxGetCurrent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx); +} + +CUresult CUDAAPI cuCtxGetDevice(CUdevice *device) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *); + static auto func_ptr = LoadSymbol("cuCtxGetDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device); +} + +CUresult CUDAAPI cuCtxGetFlags(unsigned int *flags) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *); + static auto func_ptr = LoadSymbol("cuCtxGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags); +} + +CUresult 
CUDAAPI cuCtxSynchronize(void) { + using FuncPtr = CUresult(CUDAAPI *)(); + static auto func_ptr = LoadSymbol("cuCtxSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +CUresult CUDAAPI cuCtxSetLimit(CUlimit limit, size_t value) { + using FuncPtr = CUresult(CUDAAPI *)(CUlimit, size_t); + static auto func_ptr = LoadSymbol("cuCtxSetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(limit, value); +} + +CUresult CUDAAPI cuCtxGetLimit(size_t *pvalue, CUlimit limit) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUlimit); + static auto func_ptr = LoadSymbol("cuCtxGetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pvalue, limit); +} + +CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunc_cache *); + static auto func_ptr = LoadSymbol("cuCtxGetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pconfig); +} + +CUresult CUDAAPI cuCtxSetCacheConfig(CUfunc_cache config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunc_cache); + static auto func_ptr = LoadSymbol("cuCtxSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +CUresult CUDAAPI cuCtxGetSharedMemConfig(CUsharedconfig *pConfig) { + using FuncPtr = CUresult(CUDAAPI *)(CUsharedconfig *); + static auto func_ptr = LoadSymbol("cuCtxGetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pConfig); +} + +CUresult CUDAAPI cuCtxSetSharedMemConfig(CUsharedconfig config) { + using FuncPtr = CUresult(CUDAAPI *)(CUsharedconfig); + static auto func_ptr = LoadSymbol("cuCtxSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +CUresult CUDAAPI cuCtxGetApiVersion(CUcontext ctx, unsigned int *version) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext, unsigned int *); + static auto func_ptr = LoadSymbol("cuCtxGetApiVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx, version); +} + +CUresult CUDAAPI cuCtxGetStreamPriorityRange(int *leastPriority, + int *greatestPriority) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *); + static auto func_ptr = LoadSymbol("cuCtxGetStreamPriorityRange"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(leastPriority, greatestPriority); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxAttach(CUcontext *pctx, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, unsigned int); + static auto func_ptr = LoadSymbol("cuCtxAttach"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxDetach(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxDetach"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuModuleLoad(CUmodule *module, const char *fname) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const char *); + static auto func_ptr = LoadSymbol("cuModuleLoad"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, fname); +} + +CUresult CUDAAPI cuModuleLoadData(CUmodule *module, const void *image) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *); + static auto func_ptr = LoadSymbol("cuModuleLoadData"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, image); +} + +CUresult CUDAAPI cuModuleLoadDataEx(CUmodule 
*module, const void *image, + unsigned int numOptions, + CUjit_option *options, + void **optionValues) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *, unsigned int, + CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuModuleLoadDataEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, image, numOptions, options, optionValues); +} + +CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *); + static auto func_ptr = LoadSymbol("cuModuleLoadFatBinary"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, fatCubin); +} + +CUresult CUDAAPI cuModuleUnload(CUmodule hmod) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule); + static auto func_ptr = LoadSymbol("cuModuleUnload"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hmod); +} + +CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetFunction"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, hmod, name); +} + +CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, size_t *bytes, + CUmodule hmod, const char *name) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetGlobal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytes, hmod, name); +} + +CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetTexRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexRef, hmod, name); +} + +CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfref *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetSurfRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfRef, hmod, name); +} + +CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, + void **optionValues, CUlinkState *stateOut) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUjit_option *, void **, CUlinkState *); + static auto func_ptr = LoadSymbol("cuLinkCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numOptions, options, optionValues, stateOut); +} + +CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, + void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, + void **optionValues) { + using FuncPtr = + CUresult(CUDAAPI *)(CUlinkState, CUjitInputType, void *, size_t, + const char *, unsigned int, CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuLinkAddData_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, type, data, size, name, numOptions, options, + optionValues); +} + +CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, + const char *path, unsigned int numOptions, + CUjit_option *options, void **optionValues) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState, CUjitInputType, const char *, + unsigned int, CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuLinkAddFile_v2"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDAAPI cuLinkComplete(CUlinkState state, void **cubinOut, + size_t *sizeOut) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState, void **, size_t *); + static auto func_ptr = LoadSymbol("cuLinkComplete"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, cubinOut, sizeOut); +} + +CUresult CUDAAPI cuLinkDestroy(CUlinkState state) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState); + static auto func_ptr = LoadSymbol("cuLinkDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state); +} + +CUresult CUDAAPI cuMemGetInfo(size_t *free, size_t *total) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, size_t *); + static auto func_ptr = LoadSymbol("cuMemGetInfo_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(free, total); +} + +CUresult CUDAAPI cuMemAlloc(CUdeviceptr *dptr, size_t bytesize) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t); + static auto func_ptr = LoadSymbol("cuMemAlloc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytesize); +} + +CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, size_t *pPitch, + size_t WidthInBytes, size_t Height, + unsigned int ElementSizeBytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, size_t, size_t, + unsigned int); + static auto func_ptr = LoadSymbol("cuMemAllocPitch_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, pPitch, WidthInBytes, Height, ElementSizeBytes); +} + +CUresult CUDAAPI cuMemFree(CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr); + static auto func_ptr = LoadSymbol("cuMemFree_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr); +} + +CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr *pbase, size_t *psize, + CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuMemGetAddressRange_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pbase, psize, dptr); +} + +CUresult CUDAAPI cuMemAllocHost(void **pp, size_t bytesize) { + using FuncPtr = CUresult(CUDAAPI *)(void **, size_t); + static auto func_ptr = LoadSymbol("cuMemAllocHost_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pp, bytesize); +} + +CUresult CUDAAPI cuMemFreeHost(void *p) { + using FuncPtr = CUresult(CUDAAPI *)(void *); + static auto func_ptr = LoadSymbol("cuMemFreeHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +CUresult CUDAAPI cuMemHostAlloc(void **pp, size_t bytesize, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(void **, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuMemHostAlloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pp, bytesize, Flags); +} + +CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, void *, unsigned int); + static auto func_ptr = LoadSymbol("cuMemHostGetDevicePointer_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, p, Flags); +} + +CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *, void *); + static auto func_ptr = LoadSymbol("cuMemHostGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
return func_ptr(pFlags, p); +} + +CUresult CUDAAPI cuMemAllocManaged(CUdeviceptr *dptr, size_t bytesize, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuMemAllocManaged"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytesize, flags); +} + +CUresult CUDAAPI cuDeviceGetByPCIBusId(CUdevice *dev, const char *pciBusId) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *, const char *); + static auto func_ptr = LoadSymbol("cuDeviceGetByPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, pciBusId); +} + +CUresult CUDAAPI cuDeviceGetPCIBusId(char *pciBusId, int len, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(char *, int, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pciBusId, len, dev); +} + +CUresult CUDAAPI cuIpcGetEventHandle(CUipcEventHandle *pHandle, CUevent event) { + using FuncPtr = CUresult(CUDAAPI *)(CUipcEventHandle *, CUevent); + static auto func_ptr = LoadSymbol("cuIpcGetEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, event); +} + +CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, + CUipcEventHandle handle) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent *, CUipcEventHandle); + static auto func_ptr = LoadSymbol("cuIpcOpenEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phEvent, handle); +} + +CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUipcMemHandle *, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuIpcGetMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, dptr); +} + +CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, + unsigned int Flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, CUipcMemHandle, unsigned int); + static auto func_ptr = LoadSymbol("cuIpcOpenMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, handle, Flags); +} + +CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr); + static auto func_ptr = LoadSymbol("cuIpcCloseMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr); +} + +CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(void *, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuMemHostRegister_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p, bytesize, Flags); +} + +CUresult CUDAAPI cuMemHostUnregister(void *p) { + using FuncPtr = CUresult(CUDAAPI *)(void *); + static auto func_ptr = LoadSymbol("cuMemHostUnregister"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, ByteCount); +} + +CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUcontext, CUdeviceptr, + CUcontext, size_t); + static auto func_ptr 
+      = LoadSymbol<FuncPtr>("cuMemcpyPeer");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstDevice, dstContext, srcDevice, srcContext, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr dstDevice, const void *srcHost,
+                              size_t ByteCount) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, const void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyHtoD_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstDevice, srcHost, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr srcDevice,
+                              size_t ByteCount) {
+  using FuncPtr = CUresult(CUDAAPI *)(void *, CUdeviceptr, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoH_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstHost, srcDevice, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr dstDevice, CUdeviceptr srcDevice,
+                              size_t ByteCount) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoD_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstDevice, srcDevice, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, size_t dstOffset,
+                              CUdeviceptr srcDevice, size_t ByteCount) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, CUdeviceptr, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoA_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstArray, dstOffset, srcDevice, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr dstDevice, CUarray srcArray,
+                              size_t srcOffset, size_t ByteCount) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUarray, size_t, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAtoD_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstDevice, srcArray, srcOffset, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, size_t dstOffset,
+                              const void *srcHost, size_t ByteCount) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, const void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyHtoA_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstArray, dstOffset, srcHost, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, size_t srcOffset,
+                              size_t ByteCount) {
+  using FuncPtr = CUresult(CUDAAPI *)(void *, CUarray, size_t, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAtoH_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstHost, srcArray, srcOffset, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, size_t dstOffset,
+                              CUarray srcArray, size_t srcOffset,
+                              size_t ByteCount) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, CUarray, size_t, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAtoA_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstArray, dstOffset, srcArray, srcOffset, ByteCount);
+}
+
+CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D *pCopy) {
+  using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy2D_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(pCopy);
+}
+
+CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D *pCopy) {
+  using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy2DUnaligned_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(pCopy);
+}
+
+CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D
+                            *pCopy) {
+  using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy3D_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(pCopy);
+}
+
+CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy) {
+  using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D_PEER *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy3DPeer");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(pCopy);
+}
+
+CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src,
+                               size_t ByteCount, CUstream hStream) {
+  using FuncPtr =
+      CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAsync");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dst, src, ByteCount, hStream);
+}
+
+CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext,
+                                   CUdeviceptr srcDevice, CUcontext srcContext,
+                                   size_t ByteCount, CUstream hStream) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUcontext, CUdeviceptr,
+                                      CUcontext, size_t, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyPeerAsync");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstDevice, dstContext, srcDevice, srcContext, ByteCount,
+                  hStream);
+}
+
+CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr dstDevice, const void *srcHost,
+                                   size_t ByteCount, CUstream hStream) {
+  using FuncPtr =
+      CUresult(CUDAAPI *)(CUdeviceptr, const void *, size_t, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyHtoDAsync_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstDevice, srcHost, ByteCount, hStream);
+}
+
+CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr srcDevice,
+                                   size_t ByteCount, CUstream hStream) {
+  using FuncPtr = CUresult(CUDAAPI *)(void *, CUdeviceptr, size_t, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoHAsync_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstHost, srcDevice, ByteCount, hStream);
+}
+
+CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr dstDevice, CUdeviceptr srcDevice,
+                                   size_t ByteCount, CUstream hStream) {
+  using FuncPtr =
+      CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoDAsync_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstDevice, srcDevice, ByteCount, hStream);
+}
+
+CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, size_t dstOffset,
+                                   const void *srcHost, size_t ByteCount,
+                                   CUstream hStream) {
+  using FuncPtr =
+      CUresult(CUDAAPI *)(CUarray, size_t, const void *, size_t, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyHtoAAsync_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstArray, dstOffset, srcHost, ByteCount, hStream);
+}
+
+CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray,
+                                   size_t srcOffset, size_t ByteCount,
+                                   CUstream hStream) {
+  using FuncPtr =
+      CUresult(CUDAAPI *)(void *, CUarray, size_t, size_t, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAtoHAsync_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dstHost, srcArray, srcOffset, ByteCount, hStream);
+}
+
+CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D *pCopy, CUstream hStream) {
+  using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy2DAsync_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(pCopy, hStream);
+}
+
+CUresult
CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D *, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpy3DAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemcpy3DPeerAsync(const CUDA_MEMCPY3D_PEER *pCopy, + CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D_PEER *, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpy3DPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemsetD8(CUdeviceptr dstDevice, unsigned char uc, size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned char, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD8_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, uc, N); +} + +CUresult CUDAAPI cuMemsetD16(CUdeviceptr dstDevice, unsigned short us, + size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned short, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD16_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, us, N); +} + +CUresult CUDAAPI cuMemsetD32(CUdeviceptr dstDevice, unsigned int ui, size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned int, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD32_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, ui, N); +} + +CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned char, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D8_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, uc, Width, Height); +} + +CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned short, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D16_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, us, Width, Height); +} + +CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned int, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D32_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, ui, Width, Height); +} + +CUresult CUDAAPI cuMemsetD8Async(CUdeviceptr dstDevice, unsigned char uc, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned char, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD8Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, uc, N, hStream); +} + +CUresult CUDAAPI cuMemsetD16Async(CUdeviceptr dstDevice, unsigned short us, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned short, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD16Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, us, N, hStream); +} + +CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, 
unsigned int, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD32Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, ui, N, hStream); +} + +CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned char, + size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D8Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, uc, Width, Height, hStream); +} + +CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned short, + size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D16Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, us, Width, Height, hStream); +} + +CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned int, size_t, + size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D32Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, ui, Width, Height, hStream); +} + +CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, + const CUDA_ARRAY_DESCRIPTOR *pAllocateArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, const CUDA_ARRAY_DESCRIPTOR *); + static auto func_ptr = LoadSymbol("cuArrayCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pAllocateArray); +} + +CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, + CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_ARRAY_DESCRIPTOR *, CUarray); + static auto func_ptr = LoadSymbol("cuArrayGetDescriptor_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArrayDescriptor, hArray); +} + +CUresult CUDAAPI cuArrayDestroy(CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray); + static auto func_ptr = LoadSymbol("cuArrayDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hArray); +} + +CUresult CUDAAPI cuArray3DCreate( + CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR *pAllocateArray) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray *, const CUDA_ARRAY3D_DESCRIPTOR *); + static auto func_ptr = LoadSymbol("cuArray3DCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pAllocateArray); +} + +CUresult CUDAAPI cuArray3DGetDescriptor( + CUDA_ARRAY3D_DESCRIPTOR *pArrayDescriptor, CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_ARRAY3D_DESCRIPTOR *, CUarray); + static auto func_ptr = LoadSymbol("cuArray3DGetDescriptor_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArrayDescriptor, hArray); +} + +CUresult CUDAAPI +cuMipmappedArrayCreate(CUmipmappedArray *pHandle, + const CUDA_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc, + unsigned int numMipmapLevels) { + using FuncPtr = CUresult(CUDAAPI *)( + CUmipmappedArray *, const CUDA_ARRAY3D_DESCRIPTOR *, unsigned int); + static auto func_ptr = LoadSymbol("cuMipmappedArrayCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pMipmappedArrayDesc, numMipmapLevels); +} + +CUresult 
CUDAAPI cuMipmappedArrayGetLevel(CUarray *pLevelArray, + CUmipmappedArray hMipmappedArray, + unsigned int level) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray *, CUmipmappedArray, unsigned int); + static auto func_ptr = LoadSymbol("cuMipmappedArrayGetLevel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pLevelArray, hMipmappedArray, level); +} + +CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray); + static auto func_ptr = LoadSymbol("cuMipmappedArrayDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hMipmappedArray); +} + +CUresult CUDAAPI cuMemAddressReserve(CUdeviceptr *ptr, size_t size, + size_t alignment, CUdeviceptr addr, + unsigned long long flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t, size_t, + CUdeviceptr, unsigned long long); + static auto func_ptr = LoadSymbol("cuMemAddressReserve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size, alignment, addr, flags); +} + +CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemAddressFree"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size); +} + +CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUmemGenericAllocationHandle *, size_t, + const CUmemAllocationProp *, unsigned long long); + static auto func_ptr = LoadSymbol("cuMemCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, size, prop, flags); +} + +CUresult CUDAAPI cuMemRelease(CUmemGenericAllocationHandle handle) { + using FuncPtr = CUresult(CUDAAPI *)(CUmemGenericAllocationHandle); + static auto func_ptr = LoadSymbol("cuMemRelease"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +CUresult CUDAAPI cuMemMap(CUdeviceptr ptr, size_t size, size_t offset, + CUmemGenericAllocationHandle handle, + unsigned long long flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, size_t, + CUmemGenericAllocationHandle, unsigned long long); + static auto func_ptr = LoadSymbol("cuMemMap"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size, offset, handle, flags); +} + +CUresult CUDAAPI cuMemUnmap(CUdeviceptr ptr, size_t size) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemUnmap"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size); +} + +CUresult CUDAAPI cuMemSetAccess(CUdeviceptr ptr, size_t size, + const CUmemAccessDesc *desc, size_t count) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, const CUmemAccessDesc *, size_t); + static auto func_ptr = LoadSymbol("cuMemSetAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size, desc, count); +} + +CUresult CUDAAPI cuMemGetAccess(unsigned long long *flags, + const CUmemLocation *location, + CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned long long *, + const CUmemLocation *, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuMemGetAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags, location, ptr); +} + +CUresult CUDAAPI cuMemExportToShareableHandle( + void *shareableHandle, CUmemGenericAllocationHandle 
handle, + CUmemAllocationHandleType handleType, unsigned long long flags) { + using FuncPtr = + CUresult(CUDAAPI *)(void *, CUmemGenericAllocationHandle, + CUmemAllocationHandleType, unsigned long long); + static auto func_ptr = LoadSymbol("cuMemExportToShareableHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(shareableHandle, handle, handleType, flags); +} + +CUresult CUDAAPI cuMemImportFromShareableHandle( + CUmemGenericAllocationHandle *handle, void *osHandle, + CUmemAllocationHandleType shHandleType) { + using FuncPtr = CUresult(CUDAAPI *)(CUmemGenericAllocationHandle *, void *, + CUmemAllocationHandleType); + static auto func_ptr = LoadSymbol("cuMemImportFromShareableHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, osHandle, shHandleType); +} + +CUresult CUDAAPI cuMemGetAllocationGranularity( + size_t *granularity, const CUmemAllocationProp *prop, + CUmemAllocationGranularity_flags option) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, const CUmemAllocationProp *, + CUmemAllocationGranularity_flags); + static auto func_ptr = LoadSymbol("cuMemGetAllocationGranularity"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(granularity, prop, option); +} + +CUresult CUDAAPI cuMemGetAllocationPropertiesFromHandle( + CUmemAllocationProp *prop, CUmemGenericAllocationHandle handle) { + using FuncPtr = + CUresult(CUDAAPI *)(CUmemAllocationProp *, CUmemGenericAllocationHandle); + static auto func_ptr = + LoadSymbol("cuMemGetAllocationPropertiesFromHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(prop, handle); +} + +CUresult CUDAAPI cuPointerGetAttribute(void *data, + CUpointer_attribute attribute, + CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUpointer_attribute, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, attribute, ptr); +} + +CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, + CUdevice dstDevice, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, CUdevice, CUstream); + static auto func_ptr = LoadSymbol("cuMemPrefetchAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, dstDevice, hStream); +} + +CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, + CUmem_advise advice, CUdevice device) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, CUmem_advise, CUdevice); + static auto func_ptr = LoadSymbol("cuMemAdvise"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, advice, device); +} + +CUresult CUDAAPI cuMemRangeGetAttribute(void *data, size_t dataSize, + CUmem_range_attribute attribute, + CUdeviceptr devPtr, size_t count) { + using FuncPtr = CUresult(CUDAAPI *)(void *, size_t, CUmem_range_attribute, + CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemRangeGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSize, attribute, devPtr, count); +} + +CUresult CUDAAPI cuMemRangeGetAttributes(void **data, size_t *dataSizes, + CUmem_range_attribute *attributes, + size_t numAttributes, + CUdeviceptr devPtr, size_t count) { + using FuncPtr = CUresult(CUDAAPI *)( + void **, size_t *, CUmem_range_attribute *, size_t, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemRangeGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, 
dataSizes, attributes, numAttributes, devPtr, count); +} + +CUresult CUDAAPI cuPointerSetAttribute(const void *value, + CUpointer_attribute attribute, + CUdeviceptr ptr) { + using FuncPtr = + CUresult(CUDAAPI *)(const void *, CUpointer_attribute, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attribute, ptr); +} + +CUresult CUDAAPI cuPointerGetAttributes(unsigned int numAttributes, + CUpointer_attribute *attributes, + void **data, CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int, CUpointer_attribute *, + void **, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numAttributes, attributes, data, ptr); +} + +CUresult CUDAAPI cuStreamCreate(CUstream *phStream, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phStream, Flags); +} + +CUresult CUDAAPI cuStreamCreateWithPriority(CUstream *phStream, + unsigned int flags, int priority) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream *, unsigned int, int); + static auto func_ptr = LoadSymbol("cuStreamCreateWithPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phStream, flags, priority); +} + +CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, int *); + static auto func_ptr = LoadSymbol("cuStreamGetPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, priority); +} + +CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, unsigned int *); + static auto func_ptr = LoadSymbol("cuStreamGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, flags); +} + +CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUcontext *); + static auto func_ptr = LoadSymbol("cuStreamGetCtx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, pctx); +} + +CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUevent, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitEvent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, hEvent, Flags); +} + +CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, + CUstreamCallback callback, void *userData, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamCallback, void *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamAddCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, callback, userData, flags); +} + +CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream, + CUstreamCaptureMode mode) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUstreamCaptureMode); + static auto func_ptr = LoadSymbol("cuStreamBeginCapture_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, mode); +} + +CUresult CUDAAPI cuThreadExchangeStreamCaptureMode(CUstreamCaptureMode *mode) { + using FuncPtr = CUresult(CUDAAPI *)(CUstreamCaptureMode *); + static auto func_ptr = + LoadSymbol("cuThreadExchangeStreamCaptureMode"); + if 
(!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mode); +} + +CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUgraph *); + static auto func_ptr = LoadSymbol("cuStreamEndCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, phGraph); +} + +CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, + CUstreamCaptureStatus *captureStatus) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUstreamCaptureStatus *); + static auto func_ptr = LoadSymbol("cuStreamIsCapturing"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, captureStatus); +} + +CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, + CUstreamCaptureStatus *captureStatus, + cuuint64_t *id) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamCaptureStatus *, cuuint64_t *); + static auto func_ptr = LoadSymbol("cuStreamGetCaptureInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, captureStatus, id); +} + +CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, + size_t length, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamAttachMemAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, dptr, length, flags); +} + +CUresult CUDAAPI cuStreamQuery(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamSynchronize(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamDestroy(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuEventCreate(CUevent *phEvent, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent *, unsigned int); + static auto func_ptr = LoadSymbol("cuEventCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phEvent, Flags); +} + +CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent, CUstream); + static auto func_ptr = LoadSymbol("cuEventRecord"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent, hStream); +} + +CUresult CUDAAPI cuEventQuery(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventSynchronize(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventDestroy(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, + CUevent hEnd) { + using FuncPtr = 
CUresult(CUDAAPI *)(float *, CUevent, CUevent); + static auto func_ptr = LoadSymbol("cuEventElapsedTime"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pMilliseconds, hStart, hEnd); +} + +CUresult CUDAAPI +cuImportExternalMemory(CUexternalMemory *extMem_out, + const CUDA_EXTERNAL_MEMORY_HANDLE_DESC *memHandleDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalMemory *, + const CUDA_EXTERNAL_MEMORY_HANDLE_DESC *); + static auto func_ptr = LoadSymbol("cuImportExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem_out, memHandleDesc); +} + +CUresult CUDAAPI cuExternalMemoryGetMappedBuffer( + CUdeviceptr *devPtr, CUexternalMemory extMem, + const CUDA_EXTERNAL_MEMORY_BUFFER_DESC *bufferDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, CUexternalMemory, + const CUDA_EXTERNAL_MEMORY_BUFFER_DESC *); + static auto func_ptr = LoadSymbol("cuExternalMemoryGetMappedBuffer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, extMem, bufferDesc); +} + +CUresult CUDAAPI cuExternalMemoryGetMappedMipmappedArray( + CUmipmappedArray *mipmap, CUexternalMemory extMem, + const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *mipmapDesc) { + using FuncPtr = + CUresult(CUDAAPI *)(CUmipmappedArray *, CUexternalMemory, + const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *); + static auto func_ptr = + LoadSymbol("cuExternalMemoryGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmap, extMem, mipmapDesc); +} + +CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalMemory); + static auto func_ptr = LoadSymbol("cuDestroyExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem); +} + +CUresult CUDAAPI cuImportExternalSemaphore( + CUexternalSemaphore *extSem_out, + const CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC *semHandleDesc) { + using FuncPtr = CUresult(CUDAAPI *)( + CUexternalSemaphore *, const CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC *); + static auto func_ptr = LoadSymbol("cuImportExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem_out, semHandleDesc); +} + +CUresult CUDAAPI cuSignalExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream) { + using FuncPtr = CUresult(CUDAAPI *)( + const CUexternalSemaphore *, + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *, unsigned int, CUstream); + static auto func_ptr = LoadSymbol("cuSignalExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +CUresult CUDAAPI cuWaitExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream) { + using FuncPtr = CUresult(CUDAAPI *)( + const CUexternalSemaphore *, const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *, + unsigned int, CUstream); + static auto func_ptr = LoadSymbol("cuWaitExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalSemaphore); + static auto func_ptr = LoadSymbol("cuDestroyExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
return func_ptr(extSem); +} + +CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint32_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitValue32"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint64_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitValue64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWriteValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint32_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWriteValue32"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint64_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWriteValue64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, + CUstreamBatchMemOpParams *paramArray, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, unsigned int, + CUstreamBatchMemOpParams *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamBatchMemOp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, count, paramArray, flags); +} + +CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, + CUfunction hfunc) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction_attribute, CUfunction); + static auto func_ptr = LoadSymbol("cuFuncGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pi, attrib, hfunc); +} + +CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, + CUfunction_attribute attrib, int value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUfunction_attribute, int); + static auto func_ptr = LoadSymbol("cuFuncSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, attrib, value); +} + +CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUfunc_cache); + static auto func_ptr = LoadSymbol("cuFuncSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, config); +} + +CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, + CUsharedconfig config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUsharedconfig); + static auto func_ptr = LoadSymbol("cuFuncSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, config); +} + +CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams, void **extra) { + using FuncPtr = CUresult(CUDAAPI *)( + CUfunction, unsigned int, unsigned int, unsigned 
int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **, void **); + static auto func_ptr = LoadSymbol("cuLaunchKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, + blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDAAPI cuLaunchCooperativeKernel( + CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams) { + using FuncPtr = CUresult(CUDAAPI *)( + CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **); + static auto func_ptr = LoadSymbol("cuLaunchCooperativeKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, + blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice( + CUDA_LAUNCH_PARAMS *launchParamsList, unsigned int numDevices, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUDA_LAUNCH_PARAMS *, unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol("cuLaunchCooperativeKernelMultiDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(launchParamsList, numDevices, flags); +} + +CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, + void *userData) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUhostFn, void *); + static auto func_ptr = LoadSymbol("cuLaunchHostFunc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, fn, userData); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetBlockShape(CUfunction hfunc, int x, + int y, int z) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int, int); + static auto func_ptr = LoadSymbol("cuFuncSetBlockShape"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, x, y, z); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetSharedSize(CUfunction hfunc, + unsigned int bytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, unsigned int); + static auto func_ptr = LoadSymbol("cuFuncSetSharedSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, bytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetSize(CUfunction hfunc, + unsigned int numbytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, unsigned int); + static auto func_ptr = LoadSymbol("cuParamSetSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, numbytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSeti(CUfunction hfunc, int offset, + unsigned int value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, unsigned int); + static auto func_ptr = LoadSymbol("cuParamSeti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, value); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetf(CUfunction hfunc, int offset, + float value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, float); + static auto func_ptr = LoadSymbol("cuParamSetf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, value); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetv(CUfunction hfunc, int offset, + void *ptr, + unsigned int numbytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, void *, unsigned int); + static auto func_ptr = 
LoadSymbol("cuParamSetv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, ptr, numbytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunch(CUfunction f) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction); + static auto func_ptr = LoadSymbol("cuLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGrid(CUfunction f, int grid_width, + int grid_height) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int); + static auto func_ptr = LoadSymbol("cuLaunchGrid"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, grid_width, grid_height); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGridAsync(CUfunction f, + int grid_width, + int grid_height, + CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int, CUstream); + static auto func_ptr = LoadSymbol("cuLaunchGridAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, grid_width, grid_height, hStream); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetTexRef(CUfunction hfunc, + int texunit, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, CUtexref); + static auto func_ptr = LoadSymbol("cuParamSetTexRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, texunit, hTexRef); +} + +CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph *, unsigned int); + static auto func_ptr = LoadSymbol("cuGraphCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraph, flags); +} + +CUresult CUDAAPI cuGraphAddKernelNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphAddKernelNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + nodeParams); +} + +CUresult CUDAAPI cuGraphKernelNodeGetParams( + CUgraphNode hNode, CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphKernelNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphKernelNodeSetParams( + CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddMemcpyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_MEMCPY3D *copyParams, + CUcontext ctx) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_MEMCPY3D *, CUcontext); + static auto func_ptr = LoadSymbol("cuGraphAddMemcpyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + copyParams, ctx); +} + +CUresult CUDAAPI cuGraphMemcpyNodeGetParams(CUgraphNode hNode, + CUDA_MEMCPY3D *nodeParams) { + using FuncPtr = CUresult(CUDAAPI 
*)(CUgraphNode, CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol("cuGraphMemcpyNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphMemcpyNodeSetParams(CUgraphNode hNode, + const CUDA_MEMCPY3D *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, const CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol("cuGraphMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddMemsetNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_MEMSET_NODE_PARAMS *memsetParams, + CUcontext ctx) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_MEMSET_NODE_PARAMS *, CUcontext); + static auto func_ptr = LoadSymbol("cuGraphAddMemsetNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + memsetParams, ctx); +} + +CUresult CUDAAPI cuGraphMemsetNodeGetParams( + CUgraphNode hNode, CUDA_MEMSET_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_MEMSET_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphMemsetNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphMemsetNodeSetParams( + CUgraphNode hNode, const CUDA_MEMSET_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_MEMSET_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddHostNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphAddHostNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + nodeParams); +} + +CUresult CUDAAPI cuGraphHostNodeGetParams(CUgraphNode hNode, + CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphHostNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphHostNodeSetParams( + CUgraphNode hNode, const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, + CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + CUgraph childGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, + const CUgraphNode *, size_t, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphAddChildGraphNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + childGraph); +} + +CUresult CUDAAPI cuGraphChildGraphNodeGetGraph(CUgraphNode hNode, + CUgraph *phGraph) 
{ + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraph *); + static auto func_ptr = LoadSymbol("cuGraphChildGraphNodeGetGraph"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, phGraph); +} + +CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphAddEmptyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies); +} + +CUresult CUDAAPI cuGraphClone(CUgraph *phGraphClone, CUgraph originalGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph *, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphClone, originalGraph); +} + +CUresult CUDAAPI cuGraphNodeFindInClone(CUgraphNode *phNode, + CUgraphNode hOriginalNode, + CUgraph hClonedGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraphNode, CUgraph); + static auto func_ptr = LoadSymbol("cuGraphNodeFindInClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phNode, hOriginalNode, hClonedGraph); +} + +CUresult CUDAAPI cuGraphNodeGetType(CUgraphNode hNode, CUgraphNodeType *type) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNodeType *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, type); +} + +CUresult CUDAAPI cuGraphGetNodes(CUgraph hGraph, CUgraphNode *nodes, + size_t *numNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, nodes, numNodes); +} + +CUresult CUDAAPI cuGraphGetRootNodes(CUgraph hGraph, CUgraphNode *rootNodes, + size_t *numRootNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetRootNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, rootNodes, numRootNodes); +} + +CUresult CUDAAPI cuGraphGetEdges(CUgraph hGraph, CUgraphNode *from, + CUgraphNode *to, size_t *numEdges) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphGetEdges"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numEdges); +} + +CUresult CUDAAPI cuGraphNodeGetDependencies(CUgraphNode hNode, + CUgraphNode *dependencies, + size_t *numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, dependencies, numDependencies); +} + +CUresult CUDAAPI cuGraphNodeGetDependentNodes(CUgraphNode hNode, + CUgraphNode *dependentNodes, + size_t *numDependentNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol("cuGraphNodeGetDependentNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, dependentNodes, numDependentNodes); +} + +CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, const CUgraphNode *from, + const CUgraphNode *to, + size_t numDependencies) { + using FuncPtr = 
CUresult(CUDAAPI *)(CUgraph, const CUgraphNode *, + const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphAddDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numDependencies); +} + +CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, + const CUgraphNode *from, + const CUgraphNode *to, + size_t numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, const CUgraphNode *, + const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol("cuGraphRemoveDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numDependencies); +} + +CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode); + static auto func_ptr = LoadSymbol("cuGraphDestroyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode); +} + +CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, + CUgraphNode *phErrorNode, char *logBuffer, + size_t bufferSize) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec *, CUgraph, CUgraphNode *, + char *, size_t); + static auto func_ptr = LoadSymbol("cuGraphInstantiate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphExec, hGraph, phErrorNode, logBuffer, bufferSize); +} + +CUresult CUDAAPI +cuGraphExecKernelNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraphNode, + const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphExecKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphExecMemcpyNodeSetParams(CUgraphExec hGraphExec, + CUgraphNode hNode, + const CUDA_MEMCPY3D *copyParams, + CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraphNode, + const CUDA_MEMCPY3D *, CUcontext); + static auto func_ptr = LoadSymbol("cuGraphExecMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, copyParams, ctx); +} + +CUresult CUDAAPI cuGraphExecMemsetNodeSetParams( + CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_MEMSET_NODE_PARAMS *memsetParams, CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)( + CUgraphExec, CUgraphNode, const CUDA_MEMSET_NODE_PARAMS *, CUcontext); + static auto func_ptr = LoadSymbol("cuGraphExecMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, memsetParams, ctx); +} + +CUresult CUDAAPI +cuGraphExecHostNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraphNode, + const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol("cuGraphExecHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraphExec, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUstream); + static auto func_ptr = LoadSymbol("cuGraphLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hStream); +} + +CUresult CUDAAPI cuGraphExecDestroy(CUgraphExec hGraphExec) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec); + static auto func_ptr = LoadSymbol("cuGraphExecDestroy"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(hGraphExec); +} + +CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph); + static auto func_ptr = LoadSymbol("cuGraphDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph); +} + +CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, + CUgraphNode *hErrorNode_out, + CUgraphExecUpdateResult *updateResult_out) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraph, CUgraphNode *, + CUgraphExecUpdateResult *); + static auto func_ptr = LoadSymbol("cuGraphExecUpdate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hGraph, hErrorNode_out, updateResult_out); +} + +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessor( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction, int, size_t); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxActiveBlocksPerMultiprocessor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize); +} + +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(int *, CUfunction, int, size_t, unsigned int); + static auto func_ptr = LoadSymbol( + "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize, flags); +} + +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSize( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, + int blockSizeLimit) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *, CUfunction, + CUoccupancyB2DSize, size_t, int); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxPotentialBlockSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(minGridSize, blockSize, func, blockSizeToDynamicSMemSize, + dynamicSMemSize, blockSizeLimit); +} + +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, + int blockSizeLimit, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)( + int *, int *, CUfunction, CUoccupancyB2DSize, size_t, int, unsigned int); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxPotentialBlockSizeWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(minGridSize, blockSize, func, blockSizeToDynamicSMemSize, + dynamicSMemSize, blockSizeLimit, flags); +} + +CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUarray, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, hArray, Flags); +} + +CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, + CUmipmappedArray hMipmappedArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUmipmappedArray, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, hMipmappedArray, Flags); +} + +CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexref 
hTexRef, + CUdeviceptr dptr, size_t bytes) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUtexref, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuTexRefSetAddress_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ByteOffset, hTexRef, dptr, bytes); +} + +CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, + const CUDA_ARRAY_DESCRIPTOR *desc, + CUdeviceptr dptr, size_t Pitch) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, const CUDA_ARRAY_DESCRIPTOR *, + CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuTexRefSetAddress2D_v3"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, desc, dptr, Pitch); +} + +CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_format fmt, + int NumPackedComponents) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUarray_format, int); + static auto func_ptr = LoadSymbol("cuTexRefSetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fmt, NumPackedComponents); +} + +CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int dim, + CUaddress_mode am) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, int, CUaddress_mode); + static auto func_ptr = LoadSymbol("cuTexRefSetAddressMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, dim, am); +} + +CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfilter_mode fm) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUfilter_mode); + static auto func_ptr = LoadSymbol("cuTexRefSetFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fm); +} + +CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, + CUfilter_mode fm) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUfilter_mode); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmapFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fm); +} + +CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, float bias) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmapLevelBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, bias); +} + +CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, + float minMipmapLevelClamp, + float maxMipmapLevelClamp) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float, float); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmapLevelClamp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, minMipmapLevelClamp, maxMipmapLevelClamp); +} + +CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, + unsigned int maxAniso) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetMaxAnisotropy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, maxAniso); +} + +CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float *); + static auto func_ptr = LoadSymbol("cuTexRefSetBorderColor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, pBorderColor); +} + +CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, Flags); +} + +CUresult CUDAAPI 
cuTexRefGetAddress(CUdeviceptr *pdptr, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetAddress_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phArray, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmappedArray(CUmipmappedArray *phMipmappedArray, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phMipmappedArray, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, CUtexref hTexRef, + int dim) { + using FuncPtr = CUresult(CUDAAPI *)(CUaddress_mode *, CUtexref, int); + static auto func_ptr = LoadSymbol("cuTexRefGetAddressMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pam, hTexRef, dim); +} + +CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfilter_mode *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pfm, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, int *pNumChannels, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray_format *, int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFormat, pNumChannels, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfilter_mode *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pfm, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapLevelBias(float *pbias, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapLevelBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pbias, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, + float *pmaxMipmapLevelClamp, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapLevelClamp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pminMipmapLevelClamp, pmaxMipmapLevelClamp, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMaxAnisotropy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pmaxAniso, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetBorderColor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pBorderColor, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetFlags(unsigned int *pFlags, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI 
*)(unsigned int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, hTexRef); +} + +CUresult CUDAAPI cuTexRefCreate(CUtexref *pTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref *); + static auto func_ptr = LoadSymbol("cuTexRefCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexRef); +} + +CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef); +} + +CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, CUarray hArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfref, CUarray, unsigned int); + static auto func_ptr = LoadSymbol("cuSurfRefSetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hSurfRef, hArray, Flags); +} + +CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, CUsurfref hSurfRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUsurfref); + static auto func_ptr = LoadSymbol("cuSurfRefGetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phArray, hSurfRef); +} + +CUresult CUDAAPI +cuTexObjectCreate(CUtexObject *pTexObject, const CUDA_RESOURCE_DESC *pResDesc, + const CUDA_TEXTURE_DESC *pTexDesc, + const CUDA_RESOURCE_VIEW_DESC *pResViewDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexObject *, const CUDA_RESOURCE_DESC *, + const CUDA_TEXTURE_DESC *, + const CUDA_RESOURCE_VIEW_DESC *); + static auto func_ptr = LoadSymbol("cuTexObjectCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexObject, pResDesc, pTexDesc, pResViewDesc); +} + +CUresult CUDAAPI cuTexObjectDestroy(CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texObject); +} + +CUresult CUDAAPI cuTexObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, texObject); +} + +CUresult CUDAAPI cuTexObjectGetTextureDesc(CUDA_TEXTURE_DESC *pTexDesc, + CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_TEXTURE_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetTextureDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexDesc, texObject); +} + +CUresult CUDAAPI cuTexObjectGetResourceViewDesc( + CUDA_RESOURCE_VIEW_DESC *pResViewDesc, CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_VIEW_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetResourceViewDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResViewDesc, texObject); +} + +CUresult CUDAAPI cuSurfObjectCreate(CUsurfObject *pSurfObject, + const CUDA_RESOURCE_DESC *pResDesc) { + using FuncPtr = + CUresult(CUDAAPI *)(CUsurfObject *, const CUDA_RESOURCE_DESC *); + static auto func_ptr = LoadSymbol("cuSurfObjectCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfObject, pResDesc); +} + +CUresult CUDAAPI cuSurfObjectDestroy(CUsurfObject surfObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfObject); + static 
auto func_ptr = LoadSymbol("cuSurfObjectDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfObject); +} + +CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUsurfObject surfObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_DESC *, CUsurfObject); + static auto func_ptr = LoadSymbol("cuSurfObjectGetResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, surfObject); +} + +CUresult CUDAAPI cuDeviceCanAccessPeer(int *canAccessPeer, CUdevice dev, + CUdevice peerDev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUdevice, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceCanAccessPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(canAccessPeer, dev, peerDev); +} + +CUresult CUDAAPI cuCtxEnablePeerAccess(CUcontext peerContext, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext, unsigned int); + static auto func_ptr = LoadSymbol("cuCtxEnablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerContext, Flags); +} + +CUresult CUDAAPI cuCtxDisablePeerAccess(CUcontext peerContext) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxDisablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerContext); +} + +CUresult CUDAAPI cuDeviceGetP2PAttribute(int *value, + CUdevice_P2PAttribute attrib, + CUdevice srcDevice, + CUdevice dstDevice) { + using FuncPtr = + CUresult(CUDAAPI *)(int *, CUdevice_P2PAttribute, CUdevice, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetP2PAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attrib, srcDevice, dstDevice); +} + +CUresult CUDAAPI cuGraphicsUnregisterResource(CUgraphicsResource resource) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphicsResource); + static auto func_ptr = LoadSymbol("cuGraphicsUnregisterResource"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource); +} + +CUresult CUDAAPI cuGraphicsSubResourceGetMappedArray( + CUarray *pArray, CUgraphicsResource resource, unsigned int arrayIndex, + unsigned int mipLevel) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUgraphicsResource, + unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol("cuGraphicsSubResourceGetMappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArray, resource, arrayIndex, mipLevel); +} + +CUresult CUDAAPI cuGraphicsResourceGetMappedMipmappedArray( + CUmipmappedArray *pMipmappedArray, CUgraphicsResource resource) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray *, CUgraphicsResource); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pMipmappedArray, resource); +} + +CUresult CUDAAPI cuGraphicsResourceGetMappedPointer( + CUdeviceptr *pDevPtr, size_t *pSize, CUgraphicsResource resource) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUgraphicsResource); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceGetMappedPointer_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pDevPtr, pSize, resource); +} + +CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphicsResource, unsigned int); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceSetMapFlags_v2"); + if 
(!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(resource, flags);
+}
+
+CUresult CUDAAPI cuGraphicsMapResources(unsigned int count,
+                                        CUgraphicsResource *resources,
+                                        CUstream hStream) {
+  using FuncPtr =
+      CUresult(CUDAAPI *)(unsigned int, CUgraphicsResource *, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphicsMapResources");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(count, resources, hStream);
+}
+
+CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count,
+                                          CUgraphicsResource *resources,
+                                          CUstream hStream) {
+  using FuncPtr =
+      CUresult(CUDAAPI *)(unsigned int, CUgraphicsResource *, CUstream);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphicsUnmapResources");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(count, resources, hStream);
+}
+
+CUresult CUDAAPI cuGetExportTable(const void **ppExportTable,
+                                  const CUuuid *pExportTableId) {
+  using FuncPtr = CUresult(CUDAAPI *)(const void **, const CUuuid *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGetExportTable");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(ppExportTable, pExportTableId);
+}
+
+}  // extern "C"
diff --git a/tensorflow/stream_executor/cuda/cuda_9_0.inc b/tensorflow/stream_executor/cuda/cuda_9_0.inc
new file mode 100644
index 00000000000..14d32ca1c2c
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_9_0.inc
@@ -0,0 +1,1718 @@
+// Auto-generated, do not edit.
+
+extern "C" {
+
+CUresult CUDAAPI cuGetErrorString(CUresult error, const char **pStr) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGetErrorString");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(error, pStr);
+}
+
+CUresult CUDAAPI cuGetErrorName(CUresult error, const char **pStr) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuGetErrorName");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(error, pStr);
+}
+
+CUresult CUDAAPI cuInit(unsigned int Flags) {
+  using FuncPtr = CUresult(CUDAAPI *)(unsigned int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuInit");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(Flags);
+}
+
+CUresult CUDAAPI cuDriverGetVersion(int *driverVersion) {
+  using FuncPtr = CUresult(CUDAAPI *)(int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuDriverGetVersion");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(driverVersion);
+}
+
+CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal) {
+  using FuncPtr = CUresult(CUDAAPI *)(CUdevice *, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGet");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(device, ordinal);
+}
+
+CUresult CUDAAPI cuDeviceGetCount(int *count) {
+  using FuncPtr = CUresult(CUDAAPI *)(int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(count);
+}
+
+CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev) {
+  using FuncPtr = CUresult(CUDAAPI *)(char *, int, CUdevice);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetName");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(name, len, dev);
+}
+
+CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev) {
+  using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUdevice);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceTotalMem_v2");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(bytes, dev);
+}
+
+CUresult
CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUdevice_attribute, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pi, attrib, dev); +} + +CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevprop *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetProperties"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(prop, dev); +} + +CUresult CUDAAPI cuDeviceComputeCapability(int *major, int *minor, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceComputeCapability"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(major, minor, dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxRetain"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxRelease"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice, unsigned int); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxSetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, flags); +} + +CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, + int *active) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice, unsigned int *, int *); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxGetState"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, flags, active); +} + +CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice); + static auto func_ptr = LoadSymbol("cuDevicePrimaryCtxReset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev); +} + +CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, unsigned int, CUdevice); + static auto func_ptr = LoadSymbol("cuCtxCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, flags, dev); +} + +CUresult CUDAAPI cuCtxDestroy(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxPushCurrent_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *); + static auto func_ptr = LoadSymbol("cuCtxPopCurrent_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx); +} + +CUresult CUDAAPI cuCtxSetCurrent(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxSetCurrent"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxGetCurrent(CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *); + static auto func_ptr = LoadSymbol("cuCtxGetCurrent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx); +} + +CUresult CUDAAPI cuCtxGetDevice(CUdevice *device) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *); + static auto func_ptr = LoadSymbol("cuCtxGetDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device); +} + +CUresult CUDAAPI cuCtxGetFlags(unsigned int *flags) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *); + static auto func_ptr = LoadSymbol("cuCtxGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags); +} + +CUresult CUDAAPI cuCtxSynchronize(void) { + using FuncPtr = CUresult(CUDAAPI *)(); + static auto func_ptr = LoadSymbol("cuCtxSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +CUresult CUDAAPI cuCtxSetLimit(CUlimit limit, size_t value) { + using FuncPtr = CUresult(CUDAAPI *)(CUlimit, size_t); + static auto func_ptr = LoadSymbol("cuCtxSetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(limit, value); +} + +CUresult CUDAAPI cuCtxGetLimit(size_t *pvalue, CUlimit limit) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUlimit); + static auto func_ptr = LoadSymbol("cuCtxGetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pvalue, limit); +} + +CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunc_cache *); + static auto func_ptr = LoadSymbol("cuCtxGetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pconfig); +} + +CUresult CUDAAPI cuCtxSetCacheConfig(CUfunc_cache config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunc_cache); + static auto func_ptr = LoadSymbol("cuCtxSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +CUresult CUDAAPI cuCtxGetSharedMemConfig(CUsharedconfig *pConfig) { + using FuncPtr = CUresult(CUDAAPI *)(CUsharedconfig *); + static auto func_ptr = LoadSymbol("cuCtxGetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pConfig); +} + +CUresult CUDAAPI cuCtxSetSharedMemConfig(CUsharedconfig config) { + using FuncPtr = CUresult(CUDAAPI *)(CUsharedconfig); + static auto func_ptr = LoadSymbol("cuCtxSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +CUresult CUDAAPI cuCtxGetApiVersion(CUcontext ctx, unsigned int *version) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext, unsigned int *); + static auto func_ptr = LoadSymbol("cuCtxGetApiVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx, version); +} + +CUresult CUDAAPI cuCtxGetStreamPriorityRange(int *leastPriority, + int *greatestPriority) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *); + static auto func_ptr = LoadSymbol("cuCtxGetStreamPriorityRange"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(leastPriority, greatestPriority); +} + +CUresult CUDAAPI cuCtxAttach(CUcontext *pctx, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, unsigned int); + static auto func_ptr = LoadSymbol("cuCtxAttach"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, flags); +} + +CUresult CUDAAPI cuCtxDetach(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto 
func_ptr = LoadSymbol("cuCtxDetach"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuModuleLoad(CUmodule *module, const char *fname) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const char *); + static auto func_ptr = LoadSymbol("cuModuleLoad"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, fname); +} + +CUresult CUDAAPI cuModuleLoadData(CUmodule *module, const void *image) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *); + static auto func_ptr = LoadSymbol("cuModuleLoadData"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, image); +} + +CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module, const void *image, + unsigned int numOptions, + CUjit_option *options, + void **optionValues) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *, unsigned int, + CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuModuleLoadDataEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, image, numOptions, options, optionValues); +} + +CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *); + static auto func_ptr = LoadSymbol("cuModuleLoadFatBinary"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, fatCubin); +} + +CUresult CUDAAPI cuModuleUnload(CUmodule hmod) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule); + static auto func_ptr = LoadSymbol("cuModuleUnload"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hmod); +} + +CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetFunction"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, hmod, name); +} + +CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, size_t *bytes, + CUmodule hmod, const char *name) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetGlobal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytes, hmod, name); +} + +CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetTexRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexRef, hmod, name); +} + +CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfref *, CUmodule, const char *); + static auto func_ptr = LoadSymbol("cuModuleGetSurfRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfRef, hmod, name); +} + +CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, + void **optionValues, CUlinkState *stateOut) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUjit_option *, void **, CUlinkState *); + static auto func_ptr = LoadSymbol("cuLinkCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numOptions, options, optionValues, stateOut); +} + +CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, + void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, + void 
**optionValues) { + using FuncPtr = + CUresult(CUDAAPI *)(CUlinkState, CUjitInputType, void *, size_t, + const char *, unsigned int, CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuLinkAddData_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, type, data, size, name, numOptions, options, + optionValues); +} + +CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, + const char *path, unsigned int numOptions, + CUjit_option *options, void **optionValues) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState, CUjitInputType, const char *, + unsigned int, CUjit_option *, void **); + static auto func_ptr = LoadSymbol("cuLinkAddFile_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDAAPI cuLinkComplete(CUlinkState state, void **cubinOut, + size_t *sizeOut) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState, void **, size_t *); + static auto func_ptr = LoadSymbol("cuLinkComplete"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, cubinOut, sizeOut); +} + +CUresult CUDAAPI cuLinkDestroy(CUlinkState state) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState); + static auto func_ptr = LoadSymbol("cuLinkDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state); +} + +CUresult CUDAAPI cuMemGetInfo(size_t *free, size_t *total) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, size_t *); + static auto func_ptr = LoadSymbol("cuMemGetInfo_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(free, total); +} + +CUresult CUDAAPI cuMemAlloc(CUdeviceptr *dptr, size_t bytesize) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t); + static auto func_ptr = LoadSymbol("cuMemAlloc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytesize); +} + +CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, size_t *pPitch, + size_t WidthInBytes, size_t Height, + unsigned int ElementSizeBytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, size_t, size_t, + unsigned int); + static auto func_ptr = LoadSymbol("cuMemAllocPitch_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, pPitch, WidthInBytes, Height, ElementSizeBytes); +} + +CUresult CUDAAPI cuMemFree(CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr); + static auto func_ptr = LoadSymbol("cuMemFree_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr); +} + +CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr *pbase, size_t *psize, + CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuMemGetAddressRange_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pbase, psize, dptr); +} + +CUresult CUDAAPI cuMemAllocHost(void **pp, size_t bytesize) { + using FuncPtr = CUresult(CUDAAPI *)(void **, size_t); + static auto func_ptr = LoadSymbol("cuMemAllocHost_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pp, bytesize); +} + +CUresult CUDAAPI cuMemFreeHost(void *p) { + using FuncPtr = CUresult(CUDAAPI *)(void *); + static auto func_ptr = LoadSymbol("cuMemFreeHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +CUresult CUDAAPI cuMemHostAlloc(void **pp, size_t bytesize, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(void **, size_t, unsigned int); 
+ static auto func_ptr = LoadSymbol("cuMemHostAlloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pp, bytesize, Flags); +} + +CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, void *, unsigned int); + static auto func_ptr = LoadSymbol("cuMemHostGetDevicePointer_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, p, Flags); +} + +CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *, void *); + static auto func_ptr = LoadSymbol("cuMemHostGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, p); +} + +CUresult CUDAAPI cuMemAllocManaged(CUdeviceptr *dptr, size_t bytesize, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuMemAllocManaged"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytesize, flags); +} + +CUresult CUDAAPI cuDeviceGetByPCIBusId(CUdevice *dev, const char *pciBusId) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *, const char *); + static auto func_ptr = LoadSymbol("cuDeviceGetByPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, pciBusId); +} + +CUresult CUDAAPI cuDeviceGetPCIBusId(char *pciBusId, int len, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(char *, int, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pciBusId, len, dev); +} + +CUresult CUDAAPI cuIpcGetEventHandle(CUipcEventHandle *pHandle, CUevent event) { + using FuncPtr = CUresult(CUDAAPI *)(CUipcEventHandle *, CUevent); + static auto func_ptr = LoadSymbol("cuIpcGetEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, event); +} + +CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, + CUipcEventHandle handle) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent *, CUipcEventHandle); + static auto func_ptr = LoadSymbol("cuIpcOpenEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phEvent, handle); +} + +CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUipcMemHandle *, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuIpcGetMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, dptr); +} + +CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, + unsigned int Flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, CUipcMemHandle, unsigned int); + static auto func_ptr = LoadSymbol("cuIpcOpenMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, handle, Flags); +} + +CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr); + static auto func_ptr = LoadSymbol("cuIpcCloseMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr); +} + +CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(void *, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuMemHostRegister_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p, bytesize, Flags); +} + +CUresult CUDAAPI cuMemHostUnregister(void *p) { + using FuncPtr = 
CUresult(CUDAAPI *)(void *); + static auto func_ptr = LoadSymbol("cuMemHostUnregister"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, ByteCount); +} + +CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUcontext, CUdeviceptr, + CUcontext, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstContext, srcDevice, srcContext, ByteCount); +} + +CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, const void *, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyHtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcHost, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyDtoH_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyDtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, size_t dstOffset, + CUdeviceptr srcDevice, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyDtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr dstDevice, CUarray srcArray, + size_t srcOffset, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUarray, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyAtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, const void *, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyHtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcHost, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, size_t srcOffset, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUarray, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemcpyAtoH_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, size_t dstOffset, + CUarray srcArray, size_t srcOffset, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, CUarray, size_t, size_t); + static auto func_ptr = 
LoadSymbol("cuMemcpyAtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *); + static auto func_ptr = LoadSymbol("cuMemcpy2D_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *); + static auto func_ptr = LoadSymbol("cuMemcpy2DUnaligned_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol("cuMemcpy3D_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D_PEER *); + static auto func_ptr = LoadSymbol("cuMemcpy3DPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUcontext, CUdeviceptr, + CUcontext, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstContext, srcDevice, srcContext, ByteCount, + hStream); +} + +CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, const void *, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyHtoDAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcHost, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyDtoHAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcDevice, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyDtoDAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcDevice, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray, size_t, const void *, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyHtoAAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcHost, 
ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, + size_t srcOffset, size_t ByteCount, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(void *, CUarray, size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpyAtoHAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcArray, srcOffset, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D *pCopy, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpy2DAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D *, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpy3DAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemcpy3DPeerAsync(const CUDA_MEMCPY3D_PEER *pCopy, + CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D_PEER *, CUstream); + static auto func_ptr = LoadSymbol("cuMemcpy3DPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemsetD8(CUdeviceptr dstDevice, unsigned char uc, size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned char, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD8_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, uc, N); +} + +CUresult CUDAAPI cuMemsetD16(CUdeviceptr dstDevice, unsigned short us, + size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned short, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD16_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, us, N); +} + +CUresult CUDAAPI cuMemsetD32(CUdeviceptr dstDevice, unsigned int ui, size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned int, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD32_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, ui, N); +} + +CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned char, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D8_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, uc, Width, Height); +} + +CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned short, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D16_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, us, Width, Height); +} + +CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned int, size_t, size_t); + static auto func_ptr = LoadSymbol("cuMemsetD2D32_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, ui, Width, Height); +} + +CUresult CUDAAPI cuMemsetD8Async(CUdeviceptr dstDevice, unsigned char uc, + size_t N, CUstream hStream) { + using FuncPtr = + 
CUresult(CUDAAPI *)(CUdeviceptr, unsigned char, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD8Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, uc, N, hStream); +} + +CUresult CUDAAPI cuMemsetD16Async(CUdeviceptr dstDevice, unsigned short us, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned short, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD16Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, us, N, hStream); +} + +CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned int, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD32Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, ui, N, hStream); +} + +CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned char, + size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D8Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, uc, Width, Height, hStream); +} + +CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned short, + size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D16Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, us, Width, Height, hStream); +} + +CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned int, size_t, + size_t, CUstream); + static auto func_ptr = LoadSymbol("cuMemsetD2D32Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, ui, Width, Height, hStream); +} + +CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, + const CUDA_ARRAY_DESCRIPTOR *pAllocateArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, const CUDA_ARRAY_DESCRIPTOR *); + static auto func_ptr = LoadSymbol("cuArrayCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pAllocateArray); +} + +CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, + CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_ARRAY_DESCRIPTOR *, CUarray); + static auto func_ptr = LoadSymbol("cuArrayGetDescriptor_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArrayDescriptor, hArray); +} + +CUresult CUDAAPI cuArrayDestroy(CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray); + static auto func_ptr = LoadSymbol("cuArrayDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hArray); +} + +CUresult CUDAAPI cuArray3DCreate( + CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR *pAllocateArray) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray *, const CUDA_ARRAY3D_DESCRIPTOR *); + static auto func_ptr = LoadSymbol("cuArray3DCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pAllocateArray); +} + +CUresult CUDAAPI cuArray3DGetDescriptor( + CUDA_ARRAY3D_DESCRIPTOR 
*pArrayDescriptor, CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_ARRAY3D_DESCRIPTOR *, CUarray); + static auto func_ptr = LoadSymbol("cuArray3DGetDescriptor_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArrayDescriptor, hArray); +} + +CUresult CUDAAPI +cuMipmappedArrayCreate(CUmipmappedArray *pHandle, + const CUDA_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc, + unsigned int numMipmapLevels) { + using FuncPtr = CUresult(CUDAAPI *)( + CUmipmappedArray *, const CUDA_ARRAY3D_DESCRIPTOR *, unsigned int); + static auto func_ptr = LoadSymbol("cuMipmappedArrayCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pMipmappedArrayDesc, numMipmapLevels); +} + +CUresult CUDAAPI cuMipmappedArrayGetLevel(CUarray *pLevelArray, + CUmipmappedArray hMipmappedArray, + unsigned int level) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray *, CUmipmappedArray, unsigned int); + static auto func_ptr = LoadSymbol("cuMipmappedArrayGetLevel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pLevelArray, hMipmappedArray, level); +} + +CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray); + static auto func_ptr = LoadSymbol("cuMipmappedArrayDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hMipmappedArray); +} + +CUresult CUDAAPI cuPointerGetAttribute(void *data, + CUpointer_attribute attribute, + CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUpointer_attribute, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, attribute, ptr); +} + +CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, + CUdevice dstDevice, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, CUdevice, CUstream); + static auto func_ptr = LoadSymbol("cuMemPrefetchAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, dstDevice, hStream); +} + +CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, + CUmem_advise advice, CUdevice device) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, CUmem_advise, CUdevice); + static auto func_ptr = LoadSymbol("cuMemAdvise"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, advice, device); +} + +CUresult CUDAAPI cuMemRangeGetAttribute(void *data, size_t dataSize, + CUmem_range_attribute attribute, + CUdeviceptr devPtr, size_t count) { + using FuncPtr = CUresult(CUDAAPI *)(void *, size_t, CUmem_range_attribute, + CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemRangeGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSize, attribute, devPtr, count); +} + +CUresult CUDAAPI cuMemRangeGetAttributes(void **data, size_t *dataSizes, + CUmem_range_attribute *attributes, + size_t numAttributes, + CUdeviceptr devPtr, size_t count) { + using FuncPtr = CUresult(CUDAAPI *)( + void **, size_t *, CUmem_range_attribute *, size_t, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuMemRangeGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSizes, attributes, numAttributes, devPtr, count); +} + +CUresult CUDAAPI cuPointerSetAttribute(const void *value, + CUpointer_attribute attribute, + CUdeviceptr ptr) { + using FuncPtr = + CUresult(CUDAAPI *)(const void *, CUpointer_attribute, 
CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attribute, ptr); +} + +CUresult CUDAAPI cuPointerGetAttributes(unsigned int numAttributes, + CUpointer_attribute *attributes, + void **data, CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int, CUpointer_attribute *, + void **, CUdeviceptr); + static auto func_ptr = LoadSymbol("cuPointerGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numAttributes, attributes, data, ptr); +} + +CUresult CUDAAPI cuStreamCreate(CUstream *phStream, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phStream, Flags); +} + +CUresult CUDAAPI cuStreamCreateWithPriority(CUstream *phStream, + unsigned int flags, int priority) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream *, unsigned int, int); + static auto func_ptr = LoadSymbol("cuStreamCreateWithPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phStream, flags, priority); +} + +CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, int *); + static auto func_ptr = LoadSymbol("cuStreamGetPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, priority); +} + +CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, unsigned int *); + static auto func_ptr = LoadSymbol("cuStreamGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, flags); +} + +CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUevent, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitEvent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, hEvent, Flags); +} + +CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, + CUstreamCallback callback, void *userData, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamCallback, void *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamAddCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, callback, userData, flags); +} + +CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, + size_t length, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamAttachMemAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, dptr, length, flags); +} + +CUresult CUDAAPI cuStreamQuery(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamSynchronize(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamDestroy(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol("cuStreamDestroy_v2"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuEventCreate(CUevent *phEvent, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent *, unsigned int); + static auto func_ptr = LoadSymbol("cuEventCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phEvent, Flags); +} + +CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent, CUstream); + static auto func_ptr = LoadSymbol("cuEventRecord"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent, hStream); +} + +CUresult CUDAAPI cuEventQuery(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventSynchronize(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventDestroy(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol("cuEventDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, + CUevent hEnd) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUevent, CUevent); + static auto func_ptr = LoadSymbol("cuEventElapsedTime"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pMilliseconds, hStart, hEnd); +} + +CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint32_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitValue32"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint64_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWaitValue64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWriteValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint32_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWriteValue32"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint64_t, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamWriteValue64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, + CUstreamBatchMemOpParams *paramArray, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, unsigned int, + CUstreamBatchMemOpParams *, unsigned int); + static auto func_ptr = LoadSymbol("cuStreamBatchMemOp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, count, paramArray, flags); +} + +CUresult CUDAAPI 
cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, + CUfunction hfunc) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction_attribute, CUfunction); + static auto func_ptr = LoadSymbol("cuFuncGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pi, attrib, hfunc); +} + +CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, + CUfunction_attribute attrib, int value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUfunction_attribute, int); + static auto func_ptr = LoadSymbol("cuFuncSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, attrib, value); +} + +CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUfunc_cache); + static auto func_ptr = LoadSymbol("cuFuncSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, config); +} + +CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, + CUsharedconfig config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUsharedconfig); + static auto func_ptr = LoadSymbol("cuFuncSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, config); +} + +CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams, void **extra) { + using FuncPtr = CUresult(CUDAAPI *)( + CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **, void **); + static auto func_ptr = LoadSymbol("cuLaunchKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, + blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDAAPI cuLaunchCooperativeKernel( + CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams) { + using FuncPtr = CUresult(CUDAAPI *)( + CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **); + static auto func_ptr = LoadSymbol("cuLaunchCooperativeKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, + blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice( + CUDA_LAUNCH_PARAMS *launchParamsList, unsigned int numDevices, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUDA_LAUNCH_PARAMS *, unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol("cuLaunchCooperativeKernelMultiDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(launchParamsList, numDevices, flags); +} + +CUresult CUDAAPI cuFuncSetBlockShape(CUfunction hfunc, int x, int y, int z) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int, int); + static auto func_ptr = LoadSymbol("cuFuncSetBlockShape"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, x, y, z); +} + +CUresult CUDAAPI cuFuncSetSharedSize(CUfunction hfunc, unsigned int bytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, unsigned int); + static auto func_ptr = 
LoadSymbol("cuFuncSetSharedSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, bytes); +} + +CUresult CUDAAPI cuParamSetSize(CUfunction hfunc, unsigned int numbytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, unsigned int); + static auto func_ptr = LoadSymbol("cuParamSetSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, numbytes); +} + +CUresult CUDAAPI cuParamSeti(CUfunction hfunc, int offset, unsigned int value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, unsigned int); + static auto func_ptr = LoadSymbol("cuParamSeti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, value); +} + +CUresult CUDAAPI cuParamSetf(CUfunction hfunc, int offset, float value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, float); + static auto func_ptr = LoadSymbol("cuParamSetf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, value); +} + +CUresult CUDAAPI cuParamSetv(CUfunction hfunc, int offset, void *ptr, + unsigned int numbytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, void *, unsigned int); + static auto func_ptr = LoadSymbol("cuParamSetv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, ptr, numbytes); +} + +CUresult CUDAAPI cuLaunch(CUfunction f) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction); + static auto func_ptr = LoadSymbol("cuLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f); +} + +CUresult CUDAAPI cuLaunchGrid(CUfunction f, int grid_width, int grid_height) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int); + static auto func_ptr = LoadSymbol("cuLaunchGrid"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, grid_width, grid_height); +} + +CUresult CUDAAPI cuLaunchGridAsync(CUfunction f, int grid_width, + int grid_height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int, CUstream); + static auto func_ptr = LoadSymbol("cuLaunchGridAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, grid_width, grid_height, hStream); +} + +CUresult CUDAAPI cuParamSetTexRef(CUfunction hfunc, int texunit, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, CUtexref); + static auto func_ptr = LoadSymbol("cuParamSetTexRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, texunit, hTexRef); +} + +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessor( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction, int, size_t); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxActiveBlocksPerMultiprocessor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize); +} + +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(int *, CUfunction, int, size_t, unsigned int); + static auto func_ptr = LoadSymbol( + "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize, flags); +} + +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSize( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t 
dynamicSMemSize, + int blockSizeLimit) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *, CUfunction, + CUoccupancyB2DSize, size_t, int); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxPotentialBlockSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(minGridSize, blockSize, func, blockSizeToDynamicSMemSize, + dynamicSMemSize, blockSizeLimit); +} + +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, + int blockSizeLimit, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)( + int *, int *, CUfunction, CUoccupancyB2DSize, size_t, int, unsigned int); + static auto func_ptr = + LoadSymbol("cuOccupancyMaxPotentialBlockSizeWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(minGridSize, blockSize, func, blockSizeToDynamicSMemSize, + dynamicSMemSize, blockSizeLimit, flags); +} + +CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUarray, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, hArray, Flags); +} + +CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, + CUmipmappedArray hMipmappedArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUmipmappedArray, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, hMipmappedArray, Flags); +} + +CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexref hTexRef, + CUdeviceptr dptr, size_t bytes) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUtexref, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuTexRefSetAddress_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ByteOffset, hTexRef, dptr, bytes); +} + +CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, + const CUDA_ARRAY_DESCRIPTOR *desc, + CUdeviceptr dptr, size_t Pitch) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, const CUDA_ARRAY_DESCRIPTOR *, + CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol("cuTexRefSetAddress2D_v3"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, desc, dptr, Pitch); +} + +CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_format fmt, + int NumPackedComponents) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUarray_format, int); + static auto func_ptr = LoadSymbol("cuTexRefSetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fmt, NumPackedComponents); +} + +CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int dim, + CUaddress_mode am) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, int, CUaddress_mode); + static auto func_ptr = LoadSymbol("cuTexRefSetAddressMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, dim, am); +} + +CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfilter_mode fm) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUfilter_mode); + static auto func_ptr = LoadSymbol("cuTexRefSetFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fm); +} + +CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, + CUfilter_mode fm) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUfilter_mode); + static auto 
func_ptr = LoadSymbol("cuTexRefSetMipmapFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fm); +} + +CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, float bias) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmapLevelBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, bias); +} + +CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, + float minMipmapLevelClamp, + float maxMipmapLevelClamp) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float, float); + static auto func_ptr = LoadSymbol("cuTexRefSetMipmapLevelClamp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, minMipmapLevelClamp, maxMipmapLevelClamp); +} + +CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, + unsigned int maxAniso) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetMaxAnisotropy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, maxAniso); +} + +CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float *); + static auto func_ptr = LoadSymbol("cuTexRefSetBorderColor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, pBorderColor); +} + +CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, unsigned int); + static auto func_ptr = LoadSymbol("cuTexRefSetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, Flags); +} + +CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetAddress_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phArray, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmappedArray(CUmipmappedArray *phMipmappedArray, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phMipmappedArray, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, CUtexref hTexRef, + int dim) { + using FuncPtr = CUresult(CUDAAPI *)(CUaddress_mode *, CUtexref, int); + static auto func_ptr = LoadSymbol("cuTexRefGetAddressMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pam, hTexRef, dim); +} + +CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfilter_mode *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pfm, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, int *pNumChannels, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray_format *, int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFormat, 
pNumChannels, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfilter_mode *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pfm, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapLevelBias(float *pbias, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapLevelBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pbias, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, + float *pmaxMipmapLevelClamp, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMipmapLevelClamp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pminMipmapLevelClamp, pmaxMipmapLevelClamp, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetMaxAnisotropy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pmaxAniso, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetBorderColor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pBorderColor, hTexRef); +} + +CUresult CUDAAPI cuTexRefGetFlags(unsigned int *pFlags, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *, CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, hTexRef); +} + +CUresult CUDAAPI cuTexRefCreate(CUtexref *pTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref *); + static auto func_ptr = LoadSymbol("cuTexRefCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexRef); +} + +CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref); + static auto func_ptr = LoadSymbol("cuTexRefDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef); +} + +CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, CUarray hArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfref, CUarray, unsigned int); + static auto func_ptr = LoadSymbol("cuSurfRefSetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hSurfRef, hArray, Flags); +} + +CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, CUsurfref hSurfRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUsurfref); + static auto func_ptr = LoadSymbol("cuSurfRefGetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phArray, hSurfRef); +} + +CUresult CUDAAPI +cuTexObjectCreate(CUtexObject *pTexObject, const CUDA_RESOURCE_DESC *pResDesc, + const CUDA_TEXTURE_DESC *pTexDesc, + const CUDA_RESOURCE_VIEW_DESC *pResViewDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexObject *, const CUDA_RESOURCE_DESC *, + const CUDA_TEXTURE_DESC *, + const CUDA_RESOURCE_VIEW_DESC *); + static auto func_ptr = LoadSymbol("cuTexObjectCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexObject, pResDesc, pTexDesc, pResViewDesc); +} + +CUresult CUDAAPI 
cuTexObjectDestroy(CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texObject); +} + +CUresult CUDAAPI cuTexObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, texObject); +} + +CUresult CUDAAPI cuTexObjectGetTextureDesc(CUDA_TEXTURE_DESC *pTexDesc, + CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_TEXTURE_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetTextureDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexDesc, texObject); +} + +CUresult CUDAAPI cuTexObjectGetResourceViewDesc( + CUDA_RESOURCE_VIEW_DESC *pResViewDesc, CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_VIEW_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol("cuTexObjectGetResourceViewDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResViewDesc, texObject); +} + +CUresult CUDAAPI cuSurfObjectCreate(CUsurfObject *pSurfObject, + const CUDA_RESOURCE_DESC *pResDesc) { + using FuncPtr = + CUresult(CUDAAPI *)(CUsurfObject *, const CUDA_RESOURCE_DESC *); + static auto func_ptr = LoadSymbol("cuSurfObjectCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfObject, pResDesc); +} + +CUresult CUDAAPI cuSurfObjectDestroy(CUsurfObject surfObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfObject); + static auto func_ptr = LoadSymbol("cuSurfObjectDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfObject); +} + +CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUsurfObject surfObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_DESC *, CUsurfObject); + static auto func_ptr = LoadSymbol("cuSurfObjectGetResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, surfObject); +} + +CUresult CUDAAPI cuDeviceCanAccessPeer(int *canAccessPeer, CUdevice dev, + CUdevice peerDev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUdevice, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceCanAccessPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(canAccessPeer, dev, peerDev); +} + +CUresult CUDAAPI cuCtxEnablePeerAccess(CUcontext peerContext, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext, unsigned int); + static auto func_ptr = LoadSymbol("cuCtxEnablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerContext, Flags); +} + +CUresult CUDAAPI cuCtxDisablePeerAccess(CUcontext peerContext) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol("cuCtxDisablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerContext); +} + +CUresult CUDAAPI cuDeviceGetP2PAttribute(int *value, + CUdevice_P2PAttribute attrib, + CUdevice srcDevice, + CUdevice dstDevice) { + using FuncPtr = + CUresult(CUDAAPI *)(int *, CUdevice_P2PAttribute, CUdevice, CUdevice); + static auto func_ptr = LoadSymbol("cuDeviceGetP2PAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attrib, srcDevice, dstDevice); +} + +CUresult CUDAAPI 
cuGraphicsUnregisterResource(CUgraphicsResource resource) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphicsResource); + static auto func_ptr = LoadSymbol("cuGraphicsUnregisterResource"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource); +} + +CUresult CUDAAPI cuGraphicsSubResourceGetMappedArray( + CUarray *pArray, CUgraphicsResource resource, unsigned int arrayIndex, + unsigned int mipLevel) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUgraphicsResource, + unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol("cuGraphicsSubResourceGetMappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArray, resource, arrayIndex, mipLevel); +} + +CUresult CUDAAPI cuGraphicsResourceGetMappedMipmappedArray( + CUmipmappedArray *pMipmappedArray, CUgraphicsResource resource) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray *, CUgraphicsResource); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pMipmappedArray, resource); +} + +CUresult CUDAAPI cuGraphicsResourceGetMappedPointer( + CUdeviceptr *pDevPtr, size_t *pSize, CUgraphicsResource resource) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUgraphicsResource); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceGetMappedPointer_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pDevPtr, pSize, resource); +} + +CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphicsResource, unsigned int); + static auto func_ptr = + LoadSymbol("cuGraphicsResourceSetMapFlags_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource, flags); +} + +CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUgraphicsResource *, CUstream); + static auto func_ptr = LoadSymbol("cuGraphicsMapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, hStream); +} + +CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUgraphicsResource *, CUstream); + static auto func_ptr = LoadSymbol("cuGraphicsUnmapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, hStream); +} + +CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, + const CUuuid *pExportTableId) { + using FuncPtr = CUresult(CUDAAPI *)(const void **, const CUuuid *); + static auto func_ptr = LoadSymbol("cuGetExportTable"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ppExportTable, pExportTableId); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index aceec6211a7..94ddaec03ac 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -430,6 +430,14 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, return ret == CUBLAS_STATUS_SUCCESS; } +// cublas_func may be overloaded, so we need to figure out which one we really +// need to call based on the args. One way to do it is to wrap it in lambda. +#define AS_LAMBDA(func) \ + [](auto &&... 
args) -> decltype( \ + func(std::forward(args)...)) { \ + return func(std::forward(args)...); \ + } + bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { @@ -1953,8 +1961,9 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( // essentially reinterpet_cast to __half, which is safe because Eigen::half // inherits from __half. bool result = DoBlasInternalFailureOK( - cublasGemmEx, stream, /* pointer_mode_host = */ !alpha.is_pointer(), - CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + AS_LAMBDA(cublasGemmEx), stream, + /* pointer_mode_host = */ !alpha.is_pointer(), CUDABlasTranspose(transa), + CUDABlasTranspose(transb), m, n, k, alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value(), GpuMemory(a), cuda_in_type, lda, GpuMemory(b), cuda_in_type, ldb, beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value(), @@ -2227,7 +2236,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal( reinterpret_cast(const_cast(GpuMemory(c))); bool ok; ok = DoBlasInternalImpl( - cublasGemmBatchedEx, stream, true /* = pointer_mode_host */, + AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */, true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda, b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc, @@ -2375,12 +2384,12 @@ bool CUDABlas::DoBlasGemmStridedBatched( cublasGemmAlgo_t algo = (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); bool ok = DoBlasInternalImpl( - cublasGemmStridedBatchedEx, stream, true /* = pointer_mode_host */, - true /* = err_on_failure */, use_tensor_ops, - CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, - GpuMemory(a), CUDA_R_16F, lda, stride_a, GpuMemory(b), CUDA_R_16F, - ldb, stride_b, &beta, GpuMemoryMutable(c), CUDA_R_16F, ldc, stride_c, - batch_count, CUDA_R_32F, algo); + AS_LAMBDA(cublasGemmStridedBatchedEx), stream, + true /* = pointer_mode_host */, true /* = err_on_failure */, + use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb), + m, n, k, &alpha, GpuMemory(a), CUDA_R_16F, lda, stride_a, + GpuMemory(b), CUDA_R_16F, ldb, stride_b, &beta, GpuMemoryMutable(c), + CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo); if (ok) { return true; } diff --git a/tensorflow/stream_executor/cuda/cuda_runtime_10_0.inc b/tensorflow/stream_executor/cuda/cuda_runtime_10_0.inc index 9b912330512..89a072dde12 100644 --- a/tensorflow/stream_executor/cuda/cuda_runtime_10_0.inc +++ b/tensorflow/stream_executor/cuda/cuda_runtime_10_0.inc @@ -383,6 +383,22 @@ cudaStreamAttachMemAsync(cudaStream_t stream, void *devPtr, return func_ptr(stream, devPtr, length, flags); } +extern __host__ cudaError_t CUDARTAPI +cudaStreamBeginCapture(cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol("cudaStreamBeginCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamEndCapture(cudaStream_t stream, cudaGraph_t *pGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, cudaGraph_t *); + static auto func_ptr = LoadSymbol("cudaStreamEndCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, pGraph); +} + extern __host__ cudaError_t CUDARTAPI cudaStreamIsCapturing( cudaStream_t stream, enum cudaStreamCaptureStatus *pCaptureStatus) { using FuncPtr = @@ -1508,6 +1524,306 
@@ cudaRuntimeGetVersion(int *runtimeVersion) { return func_ptr(runtimeVersion); } +extern __host__ cudaError_t CUDARTAPI cudaGraphCreate(cudaGraph_t *pGraph, + unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t *, unsigned int); + static auto func_ptr = LoadSymbol("cudaGraphCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraph, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddKernelNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + cudaGraphNode_t *, size_t, + const struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphAddKernelNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphKernelNodeGetParams( + cudaGraphNode_t node, struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphKernelNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphKernelNodeSetParams( + cudaGraphNode_t node, const struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddMemcpyNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaMemcpy3DParms *pCopyParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + cudaGraphNode_t *, size_t, + const struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol("cudaGraphAddMemcpyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pCopyParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemcpyNodeGetParams( + cudaGraphNode_t node, struct cudaMemcpy3DParms *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol("cudaGraphMemcpyNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemcpyNodeSetParams( + cudaGraphNode_t node, const struct cudaMemcpy3DParms *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol("cudaGraphMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddMemsetNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaMemsetParams *pMemsetParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + cudaGraphNode_t *, size_t, + const struct cudaMemsetParams *); + static auto func_ptr = LoadSymbol("cudaGraphAddMemsetNode"); + if (!func_ptr) 
return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pMemsetParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemsetNodeGetParams( + cudaGraphNode_t node, struct cudaMemsetParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaMemsetParams *); + static auto func_ptr = LoadSymbol("cudaGraphMemsetNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemsetNodeSetParams( + cudaGraphNode_t node, const struct cudaMemsetParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaMemsetParams *); + static auto func_ptr = LoadSymbol("cudaGraphMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddHostNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + cudaGraphNode_t *, size_t, + const struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphAddHostNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphHostNodeGetParams( + cudaGraphNode_t node, struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphHostNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphHostNodeSetParams( + cudaGraphNode_t node, const struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddChildGraphNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + cudaGraphNode_t *pDependencies, + size_t numDependencies, cudaGraph_t childGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaGraphNode_t *, cudaGraph_t, cudaGraphNode_t *, size_t, cudaGraph_t); + static auto func_ptr = LoadSymbol("cudaGraphAddChildGraphNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + childGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphChildGraphNodeGetGraph(cudaGraphNode_t node, cudaGraph_t *pGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraph_t *); + static auto func_ptr = LoadSymbol("cudaGraphChildGraphNodeGetGraph"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddEmptyNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + cudaGraphNode_t *pDependencies, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + cudaGraphNode_t *, size_t); + static auto func_ptr = LoadSymbol("cudaGraphAddEmptyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, 
pDependencies, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphClone(cudaGraph_t *pGraphClone, cudaGraph_t originalGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t *, cudaGraph_t); + static auto func_ptr = LoadSymbol("cudaGraphClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphClone, originalGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphNodeFindInClone(cudaGraphNode_t *pNode, cudaGraphNode_t originalNode, + cudaGraph_t clonedGraph) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraphNode_t, cudaGraph_t); + static auto func_ptr = LoadSymbol("cudaGraphNodeFindInClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pNode, originalNode, clonedGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphNodeGetType(cudaGraphNode_t node, enum cudaGraphNodeType *pType) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, enum cudaGraphNodeType *); + static auto func_ptr = LoadSymbol("cudaGraphNodeGetType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pType); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetNodes(cudaGraph_t graph, + cudaGraphNode_t *nodes, + size_t *numNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphGetNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, nodes, numNodes); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetRootNodes( + cudaGraph_t graph, cudaGraphNode_t *pRootNodes, size_t *pNumRootNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphGetRootNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, pRootNodes, pNumRootNodes); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetEdges(cudaGraph_t graph, + cudaGraphNode_t *from, + cudaGraphNode_t *to, + size_t *numEdges) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, + cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphGetEdges"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numEdges); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphNodeGetDependencies( + cudaGraphNode_t node, cudaGraphNode_t *pDependencies, + size_t *pNumDependencies) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphNodeGetDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pDependencies, pNumDependencies); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphNodeGetDependentNodes( + cudaGraphNode_t node, cudaGraphNode_t *pDependentNodes, + size_t *pNumDependentNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphNodeGetDependentNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pDependentNodes, pNumDependentNodes); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddDependencies(cudaGraph_t graph, cudaGraphNode_t *from, + cudaGraphNode_t *to, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, + cudaGraphNode_t *, size_t); + static auto func_ptr = LoadSymbol("cudaGraphAddDependencies"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphRemoveDependencies(cudaGraph_t graph, cudaGraphNode_t *from, + cudaGraphNode_t *to, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, + cudaGraphNode_t *, size_t); + static auto func_ptr = LoadSymbol("cudaGraphRemoveDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphDestroyNode(cudaGraphNode_t node) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t); + static auto func_ptr = LoadSymbol("cudaGraphDestroyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphInstantiate( + cudaGraphExec_t *pGraphExec, cudaGraph_t graph, cudaGraphNode_t *pErrorNode, + char *pLogBuffer, size_t bufferSize) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t *, cudaGraph_t, + cudaGraphNode_t *, char *, size_t); + static auto func_ptr = LoadSymbol("cudaGraphInstantiate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphExec, graph, pErrorNode, pLogBuffer, bufferSize); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphLaunch(cudaGraphExec_t graphExec, + cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaGraphLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graphExec, stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphExecDestroy(cudaGraphExec_t graphExec) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t); + static auto func_ptr = LoadSymbol("cudaGraphExecDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graphExec); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphDestroy(cudaGraph_t graph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t); + static auto func_ptr = LoadSymbol("cudaGraphDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph); +} + extern __host__ cudaError_t CUDARTAPI cudaGetExportTable( const void **ppExportTable, const cudaUUID_t *pExportTableId) { using FuncPtr = cudaError_t(CUDARTAPI *)(const void **, const cudaUUID_t *); @@ -1515,4 +1831,5 @@ extern __host__ cudaError_t CUDARTAPI cudaGetExportTable( if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ppExportTable, pExportTableId); } + } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cuda_runtime_10_2.inc b/tensorflow/stream_executor/cuda/cuda_runtime_10_2.inc new file mode 100644 index 00000000000..b7ecc3b7c8d --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_runtime_10_2.inc @@ -0,0 +1,1896 @@ +// Auto-generated, do not edit. 
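
The cuda_runtime_10_2.inc stubs added below follow the same lazy symbol-binding pattern as the cuda_10_2.inc and cuda_runtime_10_0.inc wrappers above: each stub resolves its CUDA entry point once through LoadSymbol, caches the resulting function pointer in a function-local static, and returns GetSymbolNotFoundError when the library or symbol cannot be found. The following is a minimal, self-contained sketch of that pattern only; the helper names (LoadSymbolOrNull, the hard-coded libcudart.so path, the stand-in error code) are hypothetical and do not reflect the real stream_executor implementation.

#include <dlfcn.h>   // dlopen/dlsym; link with -ldl
#include <cstdio>

// Stand-ins so the sketch builds without the CUDA headers.
using cudaError_t = int;
constexpr cudaError_t kSymbolNotFound = 1;  // arbitrary stand-in error code

// Hypothetical stand-in for LoadSymbol<>: open the runtime library once and
// resolve the requested symbol, or return nullptr on failure.
template <typename FuncPtr>
FuncPtr LoadSymbolOrNull(const char *symbol_name) {
  static void *handle = dlopen("libcudart.so", RTLD_LAZY);
  if (!handle) return nullptr;
  return reinterpret_cast<FuncPtr>(dlsym(handle, symbol_name));
}

cudaError_t GetSymbolNotFoundErrorSketch() { return kSymbolNotFound; }

// Shape of one generated stub: look the symbol up on first call, cache the
// pointer in a function-local static, and forward every later call to it.
extern "C" cudaError_t cudaDeviceResetSketch(void) {
  using FuncPtr = cudaError_t (*)(void);
  static auto func_ptr = LoadSymbolOrNull<FuncPtr>("cudaDeviceReset");
  if (!func_ptr) return GetSymbolNotFoundErrorSketch();
  return func_ptr();
}

int main() {
  // Prints the stand-in error code when libcudart.so (or the symbol) is
  // unavailable; otherwise prints whatever the real cudaDeviceReset returns.
  std::printf("cudaDeviceReset -> %d\n", cudaDeviceResetSketch());
  return 0;
}
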
+ +extern "C" { + +extern __host__ cudaError_t CUDARTAPI cudaDeviceReset(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol("cudaDeviceReset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceSynchronize(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol("cudaDeviceSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ cudaError_t CUDARTAPI cudaDeviceSetLimit(enum cudaLimit limit, + size_t value) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaLimit, size_t); + static auto func_ptr = LoadSymbol("cudaDeviceSetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(limit, value); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetLimit(size_t *pValue, enum cudaLimit limit) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, enum cudaLimit); + static auto func_ptr = LoadSymbol("cudaDeviceGetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pValue, limit); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetCacheConfig(enum cudaFuncCache *pCacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaFuncCache *); + static auto func_ptr = LoadSymbol("cudaDeviceGetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCacheConfig); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetStreamPriorityRange(int *leastPriority, int *greatestPriority) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, int *); + static auto func_ptr = + LoadSymbol("cudaDeviceGetStreamPriorityRange"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(leastPriority, greatestPriority); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceSetCacheConfig(enum cudaFuncCache cacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaFuncCache); + static auto func_ptr = LoadSymbol("cudaDeviceSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(cacheConfig); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetSharedMemConfig(enum cudaSharedMemConfig *pConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaSharedMemConfig *); + static auto func_ptr = LoadSymbol("cudaDeviceGetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pConfig); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceSetSharedMemConfig(enum cudaSharedMemConfig config) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaSharedMemConfig); + static auto func_ptr = LoadSymbol("cudaDeviceSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceGetByPCIBusId(int *device, const char *pciBusId) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, const char *); + static auto func_ptr = LoadSymbol("cudaDeviceGetByPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device, pciBusId); +} + +extern __host__ cudaError_t CUDARTAPI cudaDeviceGetPCIBusId(char *pciBusId, + int len, + int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(char *, int, int); + static auto func_ptr = LoadSymbol("cudaDeviceGetPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pciBusId, len, device); +} + +extern __host__ 
cudaError_t CUDARTAPI +cudaIpcGetEventHandle(cudaIpcEventHandle_t *handle, cudaEvent_t event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaIpcEventHandle_t *, cudaEvent_t); + static auto func_ptr = LoadSymbol("cudaIpcGetEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, event); +} + +extern __host__ cudaError_t CUDARTAPI +cudaIpcOpenEventHandle(cudaEvent_t *event, cudaIpcEventHandle_t handle) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t *, cudaIpcEventHandle_t); + static auto func_ptr = LoadSymbol("cudaIpcOpenEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event, handle); +} + +extern __host__ cudaError_t CUDARTAPI +cudaIpcGetMemHandle(cudaIpcMemHandle_t *handle, void *devPtr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaIpcMemHandle_t *, void *); + static auto func_ptr = LoadSymbol("cudaIpcGetMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, devPtr); +} + +extern __host__ cudaError_t CUDARTAPI cudaIpcOpenMemHandle( + void **devPtr, cudaIpcMemHandle_t handle, unsigned int flags) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void **, cudaIpcMemHandle_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaIpcOpenMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, handle, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaIpcCloseMemHandle(void *devPtr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *); + static auto func_ptr = LoadSymbol("cudaIpcCloseMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaThreadExit(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol("cudaThreadExit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadSynchronize(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol("cudaThreadSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadSetLimit(enum cudaLimit limit, size_t value) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaLimit, size_t); + static auto func_ptr = LoadSymbol("cudaThreadSetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(limit, value); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadGetLimit(size_t *pValue, enum cudaLimit limit) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, enum cudaLimit); + static auto func_ptr = LoadSymbol("cudaThreadGetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pValue, limit); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadGetCacheConfig(enum cudaFuncCache *pCacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaFuncCache *); + static auto func_ptr = LoadSymbol("cudaThreadGetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCacheConfig); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadSetCacheConfig(enum cudaFuncCache cacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaFuncCache); + static auto func_ptr = LoadSymbol("cudaThreadSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(cacheConfig); +} + +extern __host__ __cudart_builtin__ 
cudaError_t CUDARTAPI +cudaGetLastError(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol("cudaGetLastError"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaPeekAtLastError(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol("cudaPeekAtLastError"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ __cudart_builtin__ const char *CUDARTAPI +cudaGetErrorName(cudaError_t error) { + using FuncPtr = const char *(CUDARTAPI *)(cudaError_t); + static auto func_ptr = LoadSymbol("cudaGetErrorName"); + if (!func_ptr) return "cudaGetErrorName symbol not found."; + return func_ptr(error); +} + +extern __host__ __cudart_builtin__ const char *CUDARTAPI +cudaGetErrorString(cudaError_t error) { + using FuncPtr = const char *(CUDARTAPI *)(cudaError_t); + static auto func_ptr = LoadSymbol("cudaGetErrorString"); + if (!func_ptr) return "cudaGetErrorString symbol not found."; + return func_ptr(error); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaGetDeviceCount(int *count) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *); + static auto func_ptr = LoadSymbol("cudaGetDeviceCount"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaGetDeviceProperties(struct cudaDeviceProp *prop, int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaDeviceProp *, int); + static auto func_ptr = LoadSymbol("cudaGetDeviceProperties"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(prop, device); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetAttribute(int *value, enum cudaDeviceAttr attr, int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, enum cudaDeviceAttr, int); + static auto func_ptr = LoadSymbol("cudaDeviceGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attr, device); +} + +extern __host__ cudaError_t CUDARTAPI cudaDeviceGetNvSciSyncAttributes( + void *nvSciSyncAttrList, int device, int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, int, int); + static auto func_ptr = + LoadSymbol("cudaDeviceGetNvSciSyncAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(nvSciSyncAttrList, device, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetP2PAttribute(int *value, enum cudaDeviceP2PAttr attr, + int srcDevice, int dstDevice) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int *, enum cudaDeviceP2PAttr, int, int); + static auto func_ptr = LoadSymbol("cudaDeviceGetP2PAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attr, srcDevice, dstDevice); +} + +extern __host__ cudaError_t CUDARTAPI +cudaChooseDevice(int *device, const struct cudaDeviceProp *prop) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int *, const struct cudaDeviceProp *); + static auto func_ptr = LoadSymbol("cudaChooseDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device, prop); +} + +extern __host__ cudaError_t CUDARTAPI cudaSetDevice(int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int); + static auto func_ptr = LoadSymbol("cudaSetDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI 
+cudaGetDevice(int *device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *); + static auto func_ptr = LoadSymbol("cudaGetDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device); +} + +extern __host__ cudaError_t CUDARTAPI cudaSetValidDevices(int *device_arr, + int len) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, int); + static auto func_ptr = LoadSymbol("cudaSetValidDevices"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device_arr, len); +} + +extern __host__ cudaError_t CUDARTAPI cudaSetDeviceFlags(unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(unsigned int); + static auto func_ptr = LoadSymbol("cudaSetDeviceFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetDeviceFlags(unsigned int *flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(unsigned int *); + static auto func_ptr = LoadSymbol("cudaGetDeviceFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaStreamCreate(cudaStream_t *pStream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t *); + static auto func_ptr = LoadSymbol("cudaStreamCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pStream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamCreateWithFlags(cudaStream_t *pStream, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t *, unsigned int); + static auto func_ptr = LoadSymbol("cudaStreamCreateWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pStream, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamCreateWithPriority(cudaStream_t *pStream, unsigned int flags, + int priority) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t *, unsigned int, int); + static auto func_ptr = LoadSymbol("cudaStreamCreateWithPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pStream, flags, priority); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamGetPriority(cudaStream_t hStream, int *priority) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, int *); + static auto func_ptr = LoadSymbol("cudaStreamGetPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, priority); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamGetFlags(cudaStream_t hStream, unsigned int *flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, unsigned int *); + static auto func_ptr = LoadSymbol("cudaStreamGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamDestroy(cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol("cudaStreamDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaStreamWaitEvent( + cudaStream_t stream, cudaEvent_t event, unsigned int flags) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaStream_t, cudaEvent_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaStreamWaitEvent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, event, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamAddCallback(cudaStream_t 
stream, cudaStreamCallback_t callback, + void *userData, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, cudaStreamCallback_t, + void *, unsigned int); + static auto func_ptr = LoadSymbol("cudaStreamAddCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, callback, userData, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamSynchronize(cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol("cudaStreamSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaStreamQuery(cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol("cudaStreamQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamAttachMemAsync(cudaStream_t stream, void *devPtr, + size_t length __dv(0), + unsigned int flags __dv(cudaMemAttachSingle)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaStream_t, void *, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaStreamAttachMemAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, devPtr, length, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamBeginCapture(cudaStream_t stream, enum cudaStreamCaptureMode mode) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaStream_t, enum cudaStreamCaptureMode); + static auto func_ptr = LoadSymbol("cudaStreamBeginCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, mode); +} + +extern __host__ cudaError_t CUDARTAPI +cudaThreadExchangeStreamCaptureMode(enum cudaStreamCaptureMode *mode) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaStreamCaptureMode *); + static auto func_ptr = + LoadSymbol("cudaThreadExchangeStreamCaptureMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mode); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamEndCapture(cudaStream_t stream, cudaGraph_t *pGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, cudaGraph_t *); + static auto func_ptr = LoadSymbol("cudaStreamEndCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, pGraph); +} + +extern __host__ cudaError_t CUDARTAPI cudaStreamIsCapturing( + cudaStream_t stream, enum cudaStreamCaptureStatus *pCaptureStatus) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaStream_t, enum cudaStreamCaptureStatus *); + static auto func_ptr = LoadSymbol("cudaStreamIsCapturing"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, pCaptureStatus); +} + +extern __host__ cudaError_t CUDARTAPI cudaStreamGetCaptureInfo( + cudaStream_t stream, enum cudaStreamCaptureStatus *pCaptureStatus, + unsigned long long *pId) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaStream_t, enum cudaStreamCaptureStatus *, unsigned long long *); + static auto func_ptr = LoadSymbol("cudaStreamGetCaptureInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, pCaptureStatus, pId); +} + +extern __host__ cudaError_t CUDARTAPI cudaEventCreate(cudaEvent_t *event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t *); + static auto func_ptr = LoadSymbol("cudaEventCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event); +} + +extern __host__ __cudart_builtin__ cudaError_t 
CUDARTAPI +cudaEventCreateWithFlags(cudaEvent_t *event, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t *, unsigned int); + static auto func_ptr = LoadSymbol("cudaEventCreateWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaEventRecord(cudaEvent_t event, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaEventRecord"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaEventQuery(cudaEvent_t event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t); + static auto func_ptr = LoadSymbol("cudaEventQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event); +} + +extern __host__ cudaError_t CUDARTAPI cudaEventSynchronize(cudaEvent_t event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t); + static auto func_ptr = LoadSymbol("cudaEventSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaEventDestroy(cudaEvent_t event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t); + static auto func_ptr = LoadSymbol("cudaEventDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event); +} + +extern __host__ cudaError_t CUDARTAPI cudaEventElapsedTime(float *ms, + cudaEvent_t start, + cudaEvent_t end) { + using FuncPtr = cudaError_t(CUDARTAPI *)(float *, cudaEvent_t, cudaEvent_t); + static auto func_ptr = LoadSymbol("cudaEventElapsedTime"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ms, start, end); +} + +extern __host__ cudaError_t CUDARTAPI cudaImportExternalMemory( + cudaExternalMemory_t *extMem_out, + const struct cudaExternalMemoryHandleDesc *memHandleDesc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaExternalMemory_t *, const struct cudaExternalMemoryHandleDesc *); + static auto func_ptr = LoadSymbol("cudaImportExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem_out, memHandleDesc); +} + +extern __host__ cudaError_t CUDARTAPI cudaExternalMemoryGetMappedBuffer( + void **devPtr, cudaExternalMemory_t extMem, + const struct cudaExternalMemoryBufferDesc *bufferDesc) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void **, cudaExternalMemory_t, + const struct cudaExternalMemoryBufferDesc *); + static auto func_ptr = + LoadSymbol("cudaExternalMemoryGetMappedBuffer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, extMem, bufferDesc); +} + +extern __host__ cudaError_t CUDARTAPI cudaExternalMemoryGetMappedMipmappedArray( + cudaMipmappedArray_t *mipmap, cudaExternalMemory_t extMem, + const struct cudaExternalMemoryMipmappedArrayDesc *mipmapDesc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaMipmappedArray_t *, cudaExternalMemory_t, + const struct cudaExternalMemoryMipmappedArrayDesc *); + static auto func_ptr = + LoadSymbol("cudaExternalMemoryGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmap, extMem, mipmapDesc); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDestroyExternalMemory(cudaExternalMemory_t extMem) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaExternalMemory_t); + static auto func_ptr = LoadSymbol("cudaDestroyExternalMemory"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(extMem); +} + +extern __host__ cudaError_t CUDARTAPI cudaImportExternalSemaphore( + cudaExternalSemaphore_t *extSem_out, + const struct cudaExternalSemaphoreHandleDesc *semHandleDesc) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaExternalSemaphore_t *, + const struct cudaExternalSemaphoreHandleDesc *); + static auto func_ptr = LoadSymbol("cudaImportExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem_out, semHandleDesc); +} + +extern __host__ cudaError_t CUDARTAPI cudaSignalExternalSemaphoresAsync( + const cudaExternalSemaphore_t *extSemArray, + const struct cudaExternalSemaphoreSignalParams *paramsArray, + unsigned int numExtSems, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const cudaExternalSemaphore_t *, + const struct cudaExternalSemaphoreSignalParams *, + unsigned int, cudaStream_t); + static auto func_ptr = + LoadSymbol("cudaSignalExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaWaitExternalSemaphoresAsync( + const cudaExternalSemaphore_t *extSemArray, + const struct cudaExternalSemaphoreWaitParams *paramsArray, + unsigned int numExtSems, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const cudaExternalSemaphore_t *, + const struct cudaExternalSemaphoreWaitParams *, + unsigned int, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaWaitExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDestroyExternalSemaphore(cudaExternalSemaphore_t extSem) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaExternalSemaphore_t); + static auto func_ptr = LoadSymbol("cudaDestroyExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem); +} + +extern __host__ cudaError_t CUDARTAPI +cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim, void **args, + size_t sharedMem, cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, dim3, dim3, void **, + size_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaLaunchKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, gridDim, blockDim, args, sharedMem, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaLaunchCooperativeKernel( + const void *func, dim3 gridDim, dim3 blockDim, void **args, + size_t sharedMem, cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, dim3, dim3, void **, + size_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaLaunchCooperativeKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, gridDim, blockDim, args, sharedMem, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaLaunchCooperativeKernelMultiDevice( + struct cudaLaunchParams *launchParamsList, unsigned int numDevices, + unsigned int flags __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaLaunchParams *, + unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol("cudaLaunchCooperativeKernelMultiDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(launchParamsList, numDevices, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaFuncSetCacheConfig(const void *func, enum cudaFuncCache cacheConfig) { + using FuncPtr = 
cudaError_t(CUDARTAPI *)(const void *, enum cudaFuncCache); + static auto func_ptr = LoadSymbol("cudaFuncSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, cacheConfig); +} + +extern __host__ cudaError_t CUDARTAPI +cudaFuncSetSharedMemConfig(const void *func, enum cudaSharedMemConfig config) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const void *, enum cudaSharedMemConfig); + static auto func_ptr = LoadSymbol("cudaFuncSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, config); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaFuncGetAttributes(struct cudaFuncAttributes *attr, const void *func) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaFuncAttributes *, const void *); + static auto func_ptr = LoadSymbol("cudaFuncGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(attr, func); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaFuncSetAttribute(const void *func, enum cudaFuncAttribute attr, int value) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const void *, enum cudaFuncAttribute, int); + static auto func_ptr = LoadSymbol("cudaFuncSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, attr, value); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaSetDoubleForDevice(double *d) { + using FuncPtr = cudaError_t(CUDARTAPI *)(double *); + static auto func_ptr = LoadSymbol("cudaSetDoubleForDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(d); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaSetDoubleForHost(double *d) { + using FuncPtr = cudaError_t(CUDARTAPI *)(double *); + static auto func_ptr = LoadSymbol("cudaSetDoubleForHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(d); +} + +extern __host__ cudaError_t CUDARTAPI cudaLaunchHostFunc(cudaStream_t stream, + cudaHostFn_t fn, + void *userData) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, cudaHostFn_t, void *); + static auto func_ptr = LoadSymbol("cudaLaunchHostFunc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, fn, userData); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaOccupancyMaxActiveBlocksPerMultiprocessor(int *numBlocks, const void *func, + int blockSize, + size_t dynamicSMemSize) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, const void *, int, size_t); + static auto func_ptr = + LoadSymbol("cudaOccupancyMaxActiveBlocksPerMultiprocessor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(int *numBlocks, + const void *func, + int blockSize, + size_t dynamicSMemSize, + unsigned int flags) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int *, const void *, int, size_t, unsigned int); + static auto func_ptr = LoadSymbol( + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMallocManaged( + void **devPtr, size_t size, unsigned int flags __dv(cudaMemAttachGlobal)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaMallocManaged"); + if 
(!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, size, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaMalloc(void **devPtr, size_t size) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t); + static auto func_ptr = LoadSymbol("cudaMalloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, size); +} + +extern __host__ cudaError_t CUDARTAPI cudaMallocHost(void **ptr, size_t size) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t); + static auto func_ptr = LoadSymbol("cudaMallocHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size); +} + +extern __host__ cudaError_t CUDARTAPI cudaMallocPitch(void **devPtr, + size_t *pitch, + size_t width, + size_t height) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t *, size_t, size_t); + static auto func_ptr = LoadSymbol("cudaMallocPitch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, pitch, width, height); +} + +extern __host__ cudaError_t CUDARTAPI cudaMallocArray( + cudaArray_t *array, const struct cudaChannelFormatDesc *desc, size_t width, + size_t height __dv(0), unsigned int flags __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t *, + const struct cudaChannelFormatDesc *, + size_t, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaMallocArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(array, desc, width, height, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaFree(void *devPtr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *); + static auto func_ptr = LoadSymbol("cudaFree"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr); +} + +extern __host__ cudaError_t CUDARTAPI cudaFreeHost(void *ptr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *); + static auto func_ptr = LoadSymbol("cudaFreeHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr); +} + +extern __host__ cudaError_t CUDARTAPI cudaFreeArray(cudaArray_t array) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t); + static auto func_ptr = LoadSymbol("cudaFreeArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(array); +} + +extern __host__ cudaError_t CUDARTAPI +cudaFreeMipmappedArray(cudaMipmappedArray_t mipmappedArray) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaMipmappedArray_t); + static auto func_ptr = LoadSymbol("cudaFreeMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmappedArray); +} + +extern __host__ cudaError_t CUDARTAPI cudaHostAlloc(void **pHost, size_t size, + unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaHostAlloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHost, size, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaHostRegister(void *ptr, size_t size, + unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaHostRegister"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaHostUnregister(void *ptr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *); + static auto func_ptr = LoadSymbol("cudaHostUnregister"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr); +} + +extern __host__ 
cudaError_t CUDARTAPI +cudaHostGetDevicePointer(void **pDevice, void *pHost, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, void *, unsigned int); + static auto func_ptr = LoadSymbol("cudaHostGetDevicePointer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pDevice, pHost, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaHostGetFlags(unsigned int *pFlags, + void *pHost) { + using FuncPtr = cudaError_t(CUDARTAPI *)(unsigned int *, void *); + static auto func_ptr = LoadSymbol("cudaHostGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, pHost); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMalloc3D(struct cudaPitchedPtr *pitchedDevPtr, struct cudaExtent extent) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaPitchedPtr *, struct cudaExtent); + static auto func_ptr = LoadSymbol("cudaMalloc3D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pitchedDevPtr, extent); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMalloc3DArray(cudaArray_t *array, const struct cudaChannelFormatDesc *desc, + struct cudaExtent extent, unsigned int flags __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t *, + const struct cudaChannelFormatDesc *, + struct cudaExtent, unsigned int); + static auto func_ptr = LoadSymbol("cudaMalloc3DArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(array, desc, extent, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaMallocMipmappedArray( + cudaMipmappedArray_t *mipmappedArray, + const struct cudaChannelFormatDesc *desc, struct cudaExtent extent, + unsigned int numLevels, unsigned int flags __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaMipmappedArray_t *, const struct cudaChannelFormatDesc *, + struct cudaExtent, unsigned int, unsigned int); + static auto func_ptr = LoadSymbol("cudaMallocMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmappedArray, desc, extent, numLevels, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetMipmappedArrayLevel( + cudaArray_t *levelArray, cudaMipmappedArray_const_t mipmappedArray, + unsigned int level) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaArray_t *, cudaMipmappedArray_const_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaGetMipmappedArrayLevel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(levelArray, mipmappedArray, level); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemcpy3D(const struct cudaMemcpy3DParms *p) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol("cudaMemcpy3D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemcpy3DPeer(const struct cudaMemcpy3DPeerParms *p) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const struct cudaMemcpy3DPeerParms *); + static auto func_ptr = LoadSymbol("cudaMemcpy3DPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMemcpy3DAsync( + const struct cudaMemcpy3DParms *p, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const struct cudaMemcpy3DParms *, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpy3DAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy3DPeerAsync( + 
const struct cudaMemcpy3DPeerParms *p, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const struct cudaMemcpy3DPeerParms *, + cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpy3DPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemGetInfo(size_t *free, + size_t *total) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaMemGetInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(free, total); +} + +extern __host__ cudaError_t CUDARTAPI +cudaArrayGetInfo(struct cudaChannelFormatDesc *desc, struct cudaExtent *extent, + unsigned int *flags, cudaArray_t array) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaChannelFormatDesc *, + struct cudaExtent *, unsigned int *, + cudaArray_t); + static auto func_ptr = LoadSymbol("cudaArrayGetInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(desc, extent, flags, array); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy(void *dst, const void *src, + size_t count, + enum cudaMemcpyKind kind) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, const void *, size_t, + enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, count, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyPeer(void *dst, int dstDevice, + const void *src, + int srcDevice, + size_t count) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void *, int, const void *, int, size_t); + static auto func_ptr = LoadSymbol("cudaMemcpyPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dstDevice, src, srcDevice, count); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2D(void *dst, size_t dpitch, + const void *src, + size_t spitch, size_t width, + size_t height, + enum cudaMemcpyKind kind) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, const void *, size_t, + size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpy2D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dpitch, src, spitch, width, height, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DToArray( + cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, + size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, const void *, + size_t, size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpy2DToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffset, hOffset, src, spitch, width, height, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DFromArray( + void *dst, size_t dpitch, cudaArray_const_t src, size_t wOffset, + size_t hOffset, size_t width, size_t height, enum cudaMemcpyKind kind) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void *, size_t, cudaArray_const_t, size_t, + size_t, size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpy2DFromArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dpitch, src, wOffset, hOffset, width, height, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DArrayToArray( + cudaArray_t dst, size_t wOffsetDst, size_t hOffsetDst, + cudaArray_const_t src, size_t wOffsetSrc, size_t hOffsetSrc, size_t width, + 
size_t height, enum cudaMemcpyKind kind __dv(cudaMemcpyDeviceToDevice)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, + cudaArray_const_t, size_t, size_t, + size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpy2DArrayToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffsetDst, hOffsetDst, src, wOffsetSrc, hOffsetSrc, + width, height, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyToSymbol( + const void *symbol, const void *src, size_t count, size_t offset __dv(0), + enum cudaMemcpyKind kind __dv(cudaMemcpyHostToDevice)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, const void *, size_t, + size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpyToSymbol"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(symbol, src, count, offset, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyFromSymbol( + void *dst, const void *symbol, size_t count, size_t offset __dv(0), + enum cudaMemcpyKind kind __dv(cudaMemcpyDeviceToHost)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, const void *, size_t, size_t, + enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpyFromSymbol"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, symbol, count, offset, kind); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaMemcpyAsync(void *dst, const void *src, size_t count, + enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, const void *, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpyAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, count, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemcpyPeerAsync(void *dst, int dstDevice, const void *src, int srcDevice, + size_t count, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, int, const void *, int, + size_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpyPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dstDevice, src, srcDevice, count, stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMemcpy2DAsync( + void *dst, size_t dpitch, const void *src, size_t spitch, size_t width, + size_t height, enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void *, size_t, const void *, size_t, size_t, + size_t, enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpy2DAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dpitch, src, spitch, width, height, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DToArrayAsync( + cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, + size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind, + cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, + const void *, size_t, size_t, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpy2DToArrayAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffset, hOffset, src, spitch, width, height, kind, + stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DFromArrayAsync( + void *dst, size_t dpitch, cudaArray_const_t src, size_t wOffset, + 
size_t hOffset, size_t width, size_t height, enum cudaMemcpyKind kind, + cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, cudaArray_const_t, + size_t, size_t, size_t, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpy2DFromArrayAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dpitch, src, wOffset, hOffset, width, height, kind, + stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyToSymbolAsync( + const void *symbol, const void *src, size_t count, size_t offset, + enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const void *, const void *, size_t, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpyToSymbolAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(symbol, src, count, offset, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyFromSymbolAsync( + void *dst, const void *symbol, size_t count, size_t offset, + enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, const void *, size_t, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpyFromSymbolAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, symbol, count, offset, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemset(void *devPtr, int value, + size_t count) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, int, size_t); + static auto func_ptr = LoadSymbol("cudaMemset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, value, count); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemset2D(void *devPtr, size_t pitch, + int value, size_t width, + size_t height) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, int, size_t, size_t); + static auto func_ptr = LoadSymbol("cudaMemset2D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, pitch, value, width, height); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemset3D( + struct cudaPitchedPtr pitchedDevPtr, int value, struct cudaExtent extent) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaPitchedPtr, int, struct cudaExtent); + static auto func_ptr = LoadSymbol("cudaMemset3D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pitchedDevPtr, value, extent); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMemsetAsync( + void *devPtr, int value, size_t count, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, int, size_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemsetAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, value, count, stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaMemset2DAsync(void *devPtr, size_t pitch, int value, size_t width, + size_t height, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, int, size_t, size_t, + cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemset2DAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, pitch, value, width, height, stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaMemset3DAsync(struct cudaPitchedPtr pitchedDevPtr, int value, + struct cudaExtent extent, cudaStream_t stream __dv(0)) { + using FuncPtr = 
cudaError_t(CUDARTAPI *)(struct cudaPitchedPtr, int, + struct cudaExtent, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemset3DAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pitchedDevPtr, value, extent, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetSymbolAddress(void **devPtr, + const void *symbol) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, const void *); + static auto func_ptr = LoadSymbol("cudaGetSymbolAddress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, symbol); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetSymbolSize(size_t *size, + const void *symbol) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, const void *); + static auto func_ptr = LoadSymbol("cudaGetSymbolSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(size, symbol); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemPrefetchAsync(const void *devPtr, size_t count, int dstDevice, + cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const void *, size_t, int, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemPrefetchAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, dstDevice, stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemAdvise(const void *devPtr, size_t count, enum cudaMemoryAdvise advice, + int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, size_t, + enum cudaMemoryAdvise, int); + static auto func_ptr = LoadSymbol("cudaMemAdvise"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, advice, device); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemRangeGetAttribute( + void *data, size_t dataSize, enum cudaMemRangeAttribute attribute, + const void *devPtr, size_t count) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + void *, size_t, enum cudaMemRangeAttribute, const void *, size_t); + static auto func_ptr = LoadSymbol("cudaMemRangeGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSize, attribute, devPtr, count); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemRangeGetAttributes( + void **data, size_t *dataSizes, enum cudaMemRangeAttribute *attributes, + size_t numAttributes, const void *devPtr, size_t count) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void **, size_t *, enum cudaMemRangeAttribute *, + size_t, const void *, size_t); + static auto func_ptr = LoadSymbol("cudaMemRangeGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSizes, attributes, numAttributes, devPtr, count); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaMemcpyToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, + const void *src, size_t count, enum cudaMemcpyKind kind) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaArray_t, size_t, size_t, const void *, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpyToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffset, hOffset, src, count, kind); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaMemcpyFromArray(void *dst, cudaArray_const_t src, size_t wOffset, + size_t hOffset, size_t count, enum cudaMemcpyKind kind) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, cudaArray_const_t, size_t, + size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpyFromArray"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(dst, src, wOffset, hOffset, count, kind); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaMemcpyArrayToArray( + cudaArray_t dst, size_t wOffsetDst, size_t hOffsetDst, + cudaArray_const_t src, size_t wOffsetSrc, size_t hOffsetSrc, size_t count, + enum cudaMemcpyKind kind __dv(cudaMemcpyDeviceToDevice)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, cudaArray_const_t, + size_t, size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol("cudaMemcpyArrayToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffsetDst, hOffsetDst, src, wOffsetSrc, hOffsetSrc, + count, kind); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaMemcpyToArrayAsync( + cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, + size_t count, enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, const void *, + size_t, enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpyToArrayAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffset, hOffset, src, count, kind, stream); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaMemcpyFromArrayAsync(void *dst, cudaArray_const_t src, size_t wOffset, + size_t hOffset, size_t count, enum cudaMemcpyKind kind, + cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void *, cudaArray_const_t, size_t, size_t, + size_t, enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaMemcpyFromArrayAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, wOffset, hOffset, count, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaPointerGetAttributes( + struct cudaPointerAttributes *attributes, const void *ptr) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaPointerAttributes *, const void *); + static auto func_ptr = LoadSymbol("cudaPointerGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(attributes, ptr); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceCanAccessPeer(int *canAccessPeer, int device, int peerDevice) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, int, int); + static auto func_ptr = LoadSymbol("cudaDeviceCanAccessPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(canAccessPeer, device, peerDevice); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceEnablePeerAccess(int peerDevice, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int, unsigned int); + static auto func_ptr = LoadSymbol("cudaDeviceEnablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerDevice, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceDisablePeerAccess(int peerDevice) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int); + static auto func_ptr = LoadSymbol("cudaDeviceDisablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerDevice); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphicsUnregisterResource(cudaGraphicsResource_t resource) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphicsResource_t); + static auto func_ptr = LoadSymbol("cudaGraphicsUnregisterResource"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsResourceSetMapFlags( + 
cudaGraphicsResource_t resource, unsigned int flags) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphicsResource_t, unsigned int); + static auto func_ptr = LoadSymbol("cudaGraphicsResourceSetMapFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsMapResources( + int count, cudaGraphicsResource_t *resources, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int, cudaGraphicsResource_t *, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaGraphicsMapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsUnmapResources( + int count, cudaGraphicsResource_t *resources, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int, cudaGraphicsResource_t *, cudaStream_t); + static auto func_ptr = LoadSymbol("cudaGraphicsUnmapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsResourceGetMappedPointer( + void **devPtr, size_t *size, cudaGraphicsResource_t resource) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void **, size_t *, cudaGraphicsResource_t); + static auto func_ptr = + LoadSymbol("cudaGraphicsResourceGetMappedPointer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, size, resource); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsSubResourceGetMappedArray( + cudaArray_t *array, cudaGraphicsResource_t resource, + unsigned int arrayIndex, unsigned int mipLevel) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaArray_t *, cudaGraphicsResource_t, unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol("cudaGraphicsSubResourceGetMappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(array, resource, arrayIndex, mipLevel); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphicsResourceGetMappedMipmappedArray( + cudaMipmappedArray_t *mipmappedArray, cudaGraphicsResource_t resource) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaMipmappedArray_t *, cudaGraphicsResource_t); + static auto func_ptr = + LoadSymbol("cudaGraphicsResourceGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmappedArray, resource); +} + +extern __host__ cudaError_t CUDARTAPI cudaBindTexture( + size_t *offset, const struct textureReference *texref, const void *devPtr, + const struct cudaChannelFormatDesc *desc, size_t size __dv(UINT_MAX)) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + size_t *, const struct textureReference *, const void *, + const struct cudaChannelFormatDesc *, size_t); + static auto func_ptr = LoadSymbol("cudaBindTexture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(offset, texref, devPtr, desc, size); +} + +extern __host__ cudaError_t CUDARTAPI +cudaBindTexture2D(size_t *offset, const struct textureReference *texref, + const void *devPtr, const struct cudaChannelFormatDesc *desc, + size_t width, size_t height, size_t pitch) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + size_t *, const struct textureReference *, const void *, + const struct cudaChannelFormatDesc *, size_t, size_t, size_t); + static auto func_ptr = LoadSymbol("cudaBindTexture2D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(offset, texref, devPtr, desc, width, height, pitch); +} + 
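
All of the generated wrappers in these .inc files share one trampoline shape: declare a FuncPtr type matching the CUDA signature, resolve the symbol once through LoadSymbol into a function-local static, return GetSymbolNotFoundError() when the runtime library does not export it, and otherwise forward the arguments unchanged. The stand-alone sketch below is illustrative only, assuming a plain POSIX dlopen/dlsym loader; the names suffixed with "Sketch", the libcudart.so path, and the error constant are placeholders, not the stub's real helpers.

// Illustrative sketch, not part of the patch: the trampoline pattern used by
// the generated wrappers, approximated with dlopen/dlsym.
#include <dlfcn.h>
#include <cstdio>

namespace {

// Stand-in for the stub's dedicated "symbol not found" error value.
constexpr int kSymbolNotFoundSketch = 999;

// Resolves `name` from the CUDA runtime shared library, or returns nullptr if
// either the library or the symbol is unavailable at run time.
template <typename T>
T LoadSymbolSketch(const char* name) {
  static void* handle = dlopen("libcudart.so", RTLD_NOW | RTLD_LOCAL);
  return handle ? reinterpret_cast<T>(dlsym(handle, name)) : nullptr;
}

}  // namespace

// Same four-line shape as every wrapper above: declare the function-pointer
// type, resolve the symbol once (static local), then forward the call.
extern "C" int cudaDriverGetVersionSketch(int* driverVersion) {
  using FuncPtr = int (*)(int*);
  static auto func_ptr = LoadSymbolSketch<FuncPtr>("cudaDriverGetVersion");
  if (!func_ptr) return kSymbolNotFoundSketch;
  return func_ptr(driverVersion);
}

int main() {
  int version = 0;
  int status = cudaDriverGetVersionSketch(&version);
  std::printf("status=%d driver_version=%d\n", status, version);
  return 0;
}

Resolving through a function-local static keeps the per-call cost to one indirect call after the first use, which is why every generated wrapper repeats the same few lines.
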
+extern __host__ cudaError_t CUDARTAPI cudaBindTextureToArray( + const struct textureReference *texref, cudaArray_const_t array, + const struct cudaChannelFormatDesc *desc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + const struct textureReference *, cudaArray_const_t, + const struct cudaChannelFormatDesc *); + static auto func_ptr = LoadSymbol("cudaBindTextureToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texref, array, desc); +} + +extern __host__ cudaError_t CUDARTAPI +cudaBindTextureToMipmappedArray(const struct textureReference *texref, + cudaMipmappedArray_const_t mipmappedArray, + const struct cudaChannelFormatDesc *desc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + const struct textureReference *, cudaMipmappedArray_const_t, + const struct cudaChannelFormatDesc *); + static auto func_ptr = LoadSymbol("cudaBindTextureToMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texref, mipmappedArray, desc); +} + +extern __host__ cudaError_t CUDARTAPI +cudaUnbindTexture(const struct textureReference *texref) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const struct textureReference *); + static auto func_ptr = LoadSymbol("cudaUnbindTexture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texref); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetTextureAlignmentOffset( + size_t *offset, const struct textureReference *texref) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(size_t *, const struct textureReference *); + static auto func_ptr = LoadSymbol("cudaGetTextureAlignmentOffset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(offset, texref); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetTextureReference( + const struct textureReference **texref, const void *symbol) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const struct textureReference **, const void *); + static auto func_ptr = LoadSymbol("cudaGetTextureReference"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texref, symbol); +} + +extern __host__ cudaError_t CUDARTAPI cudaBindSurfaceToArray( + const struct surfaceReference *surfref, cudaArray_const_t array, + const struct cudaChannelFormatDesc *desc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + const struct surfaceReference *, cudaArray_const_t, + const struct cudaChannelFormatDesc *); + static auto func_ptr = LoadSymbol("cudaBindSurfaceToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfref, array, desc); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetSurfaceReference( + const struct surfaceReference **surfref, const void *symbol) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const struct surfaceReference **, const void *); + static auto func_ptr = LoadSymbol("cudaGetSurfaceReference"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfref, symbol); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetChannelDesc( + struct cudaChannelFormatDesc *desc, cudaArray_const_t array) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaChannelFormatDesc *, + cudaArray_const_t); + static auto func_ptr = LoadSymbol("cudaGetChannelDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(desc, array); +} + +extern __host__ cudaError_t CUDARTAPI cudaCreateTextureObject( + cudaTextureObject_t *pTexObject, const struct cudaResourceDesc *pResDesc, + const struct cudaTextureDesc *pTexDesc, + const struct cudaResourceViewDesc *pResViewDesc) { + using FuncPtr = 
cudaError_t(CUDARTAPI *)( + cudaTextureObject_t *, const struct cudaResourceDesc *, + const struct cudaTextureDesc *, const struct cudaResourceViewDesc *); + static auto func_ptr = LoadSymbol("cudaCreateTextureObject"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexObject, pResDesc, pTexDesc, pResViewDesc); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDestroyTextureObject(cudaTextureObject_t texObject) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaTextureObject_t); + static auto func_ptr = LoadSymbol("cudaDestroyTextureObject"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetTextureObjectResourceDesc( + struct cudaResourceDesc *pResDesc, cudaTextureObject_t texObject) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaResourceDesc *, cudaTextureObject_t); + static auto func_ptr = + LoadSymbol("cudaGetTextureObjectResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, texObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetTextureObjectTextureDesc( + struct cudaTextureDesc *pTexDesc, cudaTextureObject_t texObject) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaTextureDesc *, cudaTextureObject_t); + static auto func_ptr = LoadSymbol("cudaGetTextureObjectTextureDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexDesc, texObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetTextureObjectResourceViewDesc( + struct cudaResourceViewDesc *pResViewDesc, cudaTextureObject_t texObject) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaResourceViewDesc *, + cudaTextureObject_t); + static auto func_ptr = + LoadSymbol("cudaGetTextureObjectResourceViewDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResViewDesc, texObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaCreateSurfaceObject( + cudaSurfaceObject_t *pSurfObject, const struct cudaResourceDesc *pResDesc) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaSurfaceObject_t *, + const struct cudaResourceDesc *); + static auto func_ptr = LoadSymbol("cudaCreateSurfaceObject"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfObject, pResDesc); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDestroySurfaceObject(cudaSurfaceObject_t surfObject) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaSurfaceObject_t); + static auto func_ptr = LoadSymbol("cudaDestroySurfaceObject"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetSurfaceObjectResourceDesc( + struct cudaResourceDesc *pResDesc, cudaSurfaceObject_t surfObject) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaResourceDesc *, cudaSurfaceObject_t); + static auto func_ptr = + LoadSymbol("cudaGetSurfaceObjectResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, surfObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaDriverGetVersion(int *driverVersion) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *); + static auto func_ptr = LoadSymbol("cudaDriverGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(driverVersion); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaRuntimeGetVersion(int *runtimeVersion) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *); + static auto func_ptr = LoadSymbol("cudaRuntimeGetVersion"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(runtimeVersion); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphCreate(cudaGraph_t *pGraph, + unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t *, unsigned int); + static auto func_ptr = LoadSymbol("cudaGraphCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraph, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddKernelNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, + const struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphAddKernelNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphKernelNodeGetParams( + cudaGraphNode_t node, struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphKernelNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphKernelNodeSetParams( + cudaGraphNode_t node, const struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddMemcpyNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaMemcpy3DParms *pCopyParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, + const struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol("cudaGraphAddMemcpyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pCopyParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemcpyNodeGetParams( + cudaGraphNode_t node, struct cudaMemcpy3DParms *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol("cudaGraphMemcpyNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemcpyNodeSetParams( + cudaGraphNode_t node, const struct cudaMemcpy3DParms *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol("cudaGraphMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddMemsetNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaMemsetParams *pMemsetParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, + const struct cudaMemsetParams *); + static auto func_ptr = 
LoadSymbol("cudaGraphAddMemsetNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pMemsetParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemsetNodeGetParams( + cudaGraphNode_t node, struct cudaMemsetParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaMemsetParams *); + static auto func_ptr = LoadSymbol("cudaGraphMemsetNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemsetNodeSetParams( + cudaGraphNode_t node, const struct cudaMemsetParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaMemsetParams *); + static auto func_ptr = LoadSymbol("cudaGraphMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddHostNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, + const struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphAddHostNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphHostNodeGetParams( + cudaGraphNode_t node, struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphHostNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphHostNodeSetParams( + cudaGraphNode_t node, const struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddChildGraphNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, + size_t numDependencies, cudaGraph_t childGraph) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, cudaGraph_t); + static auto func_ptr = LoadSymbol("cudaGraphAddChildGraphNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + childGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphChildGraphNodeGetGraph(cudaGraphNode_t node, cudaGraph_t *pGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraph_t *); + static auto func_ptr = LoadSymbol("cudaGraphChildGraphNodeGetGraph"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pGraph); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddEmptyNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t); + static auto func_ptr = 
LoadSymbol("cudaGraphAddEmptyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphClone(cudaGraph_t *pGraphClone, cudaGraph_t originalGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t *, cudaGraph_t); + static auto func_ptr = LoadSymbol("cudaGraphClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphClone, originalGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphNodeFindInClone(cudaGraphNode_t *pNode, cudaGraphNode_t originalNode, + cudaGraph_t clonedGraph) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraphNode_t, cudaGraph_t); + static auto func_ptr = LoadSymbol("cudaGraphNodeFindInClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pNode, originalNode, clonedGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphNodeGetType(cudaGraphNode_t node, enum cudaGraphNodeType *pType) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, enum cudaGraphNodeType *); + static auto func_ptr = LoadSymbol("cudaGraphNodeGetType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pType); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetNodes(cudaGraph_t graph, + cudaGraphNode_t *nodes, + size_t *numNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphGetNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, nodes, numNodes); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetRootNodes( + cudaGraph_t graph, cudaGraphNode_t *pRootNodes, size_t *pNumRootNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphGetRootNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, pRootNodes, pNumRootNodes); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetEdges(cudaGraph_t graph, + cudaGraphNode_t *from, + cudaGraphNode_t *to, + size_t *numEdges) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, + cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphGetEdges"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numEdges); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphNodeGetDependencies( + cudaGraphNode_t node, cudaGraphNode_t *pDependencies, + size_t *pNumDependencies) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphNodeGetDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pDependencies, pNumDependencies); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphNodeGetDependentNodes( + cudaGraphNode_t node, cudaGraphNode_t *pDependentNodes, + size_t *pNumDependentNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol("cudaGraphNodeGetDependentNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pDependentNodes, pNumDependentNodes); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddDependencies(cudaGraph_t graph, const cudaGraphNode_t *from, + const cudaGraphNode_t *to, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, const 
cudaGraphNode_t *, + const cudaGraphNode_t *, size_t); + static auto func_ptr = LoadSymbol("cudaGraphAddDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphRemoveDependencies(cudaGraph_t graph, const cudaGraphNode_t *from, + const cudaGraphNode_t *to, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, const cudaGraphNode_t *, + const cudaGraphNode_t *, size_t); + static auto func_ptr = LoadSymbol("cudaGraphRemoveDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphDestroyNode(cudaGraphNode_t node) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t); + static auto func_ptr = LoadSymbol("cudaGraphDestroyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphInstantiate( + cudaGraphExec_t *pGraphExec, cudaGraph_t graph, cudaGraphNode_t *pErrorNode, + char *pLogBuffer, size_t bufferSize) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t *, cudaGraph_t, + cudaGraphNode_t *, char *, size_t); + static auto func_ptr = LoadSymbol("cudaGraphInstantiate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphExec, graph, pErrorNode, pLogBuffer, bufferSize); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphExecKernelNodeSetParams( + cudaGraphExec_t hGraphExec, cudaGraphNode_t node, + const struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraphNode_t, + const struct cudaKernelNodeParams *); + static auto func_ptr = + LoadSymbol("cudaGraphExecKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphExecMemcpyNodeSetParams( + cudaGraphExec_t hGraphExec, cudaGraphNode_t node, + const struct cudaMemcpy3DParms *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraphNode_t, + const struct cudaMemcpy3DParms *); + static auto func_ptr = + LoadSymbol("cudaGraphExecMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphExecMemsetNodeSetParams( + cudaGraphExec_t hGraphExec, cudaGraphNode_t node, + const struct cudaMemsetParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraphNode_t, + const struct cudaMemsetParams *); + static auto func_ptr = + LoadSymbol("cudaGraphExecMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphExecHostNodeSetParams(cudaGraphExec_t hGraphExec, cudaGraphNode_t node, + const struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraphNode_t, + const struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol("cudaGraphExecHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphExecUpdate(cudaGraphExec_t hGraphExec, cudaGraph_t hGraph, + cudaGraphNode_t *hErrorNode_out, + enum cudaGraphExecUpdateResult *updateResult_out) { + using FuncPtr = + 
cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraph_t, cudaGraphNode_t *,
+                               enum cudaGraphExecUpdateResult *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphExecUpdate");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(hGraphExec, hGraph, hErrorNode_out, updateResult_out);
+}
+
+extern __host__ cudaError_t CUDARTAPI cudaGraphLaunch(cudaGraphExec_t graphExec,
+                                                      cudaStream_t stream) {
+  using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaStream_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphLaunch");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(graphExec, stream);
+}
+
+extern __host__ cudaError_t CUDARTAPI
+cudaGraphExecDestroy(cudaGraphExec_t graphExec) {
+  using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphExecDestroy");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(graphExec);
+}
+
+extern __host__ cudaError_t CUDARTAPI cudaGraphDestroy(cudaGraph_t graph) {
+  using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphDestroy");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(graph);
+}
+
+extern __host__ cudaError_t CUDARTAPI cudaGetExportTable(
+    const void **ppExportTable, const cudaUUID_t *pExportTableId) {
+  using FuncPtr = cudaError_t(CUDARTAPI *)(const void **, const cudaUUID_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetExportTable");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(ppExportTable, pExportTableId);
+}
+
+}  // extern "C"
diff --git a/tensorflow/stream_executor/cuda/cuda_stub.cc b/tensorflow/stream_executor/cuda/cuda_stub.cc
index 3248c9ddefd..ebdc4a33db6 100644
--- a/tensorflow/stream_executor/cuda/cuda_stub.cc
+++ b/tensorflow/stream_executor/cuda/cuda_stub.cc
@@ -93,7 +93,16 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st
     CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS;
 typedef void(CUDA_CB* CUhostFn)(void* userData);
 
-// For now only one stub implementation is needed. If a function that is not
-// available in the given CUDA release, the corresponding wrapper returns
-// CUDA_ERROR_SHARED_OBJECT_INIT_FAILED.
+#if CUDA_VERSION <= 9000
+#include "tensorflow/stream_executor/cuda/cuda_9_0.inc"
+#elif CUDA_VERSION == 10000
 #include "tensorflow/stream_executor/cuda/cuda_10_0.inc"
+#elif CUDA_VERSION <= 10010
+#include "tensorflow/stream_executor/cuda/cuda_10_1.inc"
+#elif CUDA_VERSION <= 10020
+#include "tensorflow/stream_executor/cuda/cuda_10_2.inc"
+#elif CUDA_VERSION <= 11000
+#include "tensorflow/stream_executor/cuda/cuda_11_0.inc"
+#else
+#error "We have no wrapper for this version."
+#endif diff --git a/tensorflow/stream_executor/cuda/cudart_stub.cc b/tensorflow/stream_executor/cuda/cudart_stub.cc index 0c6b274f88b..3afe6780402 100644 --- a/tensorflow/stream_executor/cuda/cudart_stub.cc +++ b/tensorflow/stream_executor/cuda/cudart_stub.cc @@ -53,10 +53,16 @@ cudaError_t GetSymbolNotFoundError() { // A bunch of new symbols were introduced in version 10 #if CUDART_VERSION <= 9020 #include "tensorflow/stream_executor/cuda/cuda_runtime_9_0.inc" -#elif CUDART_VERSION < 10010 +#elif CUDART_VERSION == 10000 #include "tensorflow/stream_executor/cuda/cuda_runtime_10_0.inc" -#else +#elif CUDART_VERSION == 10010 #include "tensorflow/stream_executor/cuda/cuda_runtime_10_1.inc" +#elif CUDART_VERSION == 10020 +#include "tensorflow/stream_executor/cuda/cuda_runtime_10_2.inc" +#elif CUDART_VERSION == 11000 +#include "tensorflow/stream_executor/cuda/cuda_runtime_11_0.inc" +#else +#error "We have no wrapper for this version." #endif #undef __dv #undef __CUDA_DEPRECATED diff --git a/tensorflow/stream_executor/cuda/cufft_10_0.inc b/tensorflow/stream_executor/cuda/cufft_10_0.inc index ba726770ac3..48f80b05c5e 100644 --- a/tensorflow/stream_executor/cuda/cufft_10_0.inc +++ b/tensorflow/stream_executor/cuda/cufft_10_0.inc @@ -1,6 +1,7 @@ // Auto-generated, do not edit. extern "C" { + cufftResult CUFFTAPI cufftPlan1d(cufftHandle *plan, int nx, cufftType type, int batch) { using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle *, int, cufftType, int); diff --git a/tensorflow/stream_executor/cuda/cufft_9_0.inc b/tensorflow/stream_executor/cuda/cufft_9_0.inc new file mode 100644 index 00000000000..e6244f0705d --- /dev/null +++ b/tensorflow/stream_executor/cuda/cufft_9_0.inc @@ -0,0 +1,307 @@ +// Auto-generated, do not edit. + +extern "C" { + +cufftResult CUFFTAPI cufftPlan1d(cufftHandle *plan, int nx, cufftType type, + int batch) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle *, int, cufftType, int); + static auto func_ptr = LoadSymbol("cufftPlan1d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, nx, type, batch); +} + +cufftResult CUFFTAPI cufftPlan2d(cufftHandle *plan, int nx, int ny, + cufftType type) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle *, int, int, cufftType); + static auto func_ptr = LoadSymbol("cufftPlan2d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, nx, ny, type); +} + +cufftResult CUFFTAPI cufftPlan3d(cufftHandle *plan, int nx, int ny, int nz, + cufftType type) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle *, int, int, int, cufftType); + static auto func_ptr = LoadSymbol("cufftPlan3d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, nx, ny, nz, type); +} + +cufftResult CUFFTAPI cufftPlanMany(cufftHandle *plan, int rank, int *n, + int *inembed, int istride, int idist, + int *onembed, int ostride, int odist, + cufftType type, int batch) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle *, int, int *, int *, int, + int, int *, int, int, cufftType, int); + static auto func_ptr = LoadSymbol("cufftPlanMany"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, rank, n, inembed, istride, idist, onembed, ostride, + odist, type, batch); +} + +cufftResult CUFFTAPI cufftMakePlan1d(cufftHandle plan, int nx, cufftType type, + int batch, size_t *workSize) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, int, cufftType, int, size_t *); + static auto func_ptr = LoadSymbol("cufftMakePlan1d"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
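The CUDART_VERSION ladder added above selects exactly one generated wrapper set per toolkit and turns an unsupported toolkit into a build-time #error rather than a silent fallback to an older wrapper set. CUDART_VERSION encodes the release as major * 1000 + minor * 10 (9020 is 9.2, 10010 is 10.1, 11000 is 11.0), which is what the comparisons key on. A hypothetical guard in the same spirit, shown only to illustrate the encoding:

#include <cuda_runtime_api.h>

// Hypothetical illustration: surface an unexpected toolkit early, mirroring
// the #error branch of the include ladder above.
static_assert(CUDART_VERSION >= 9000 && CUDART_VERSION <= 11000,
              "no dlopen wrapper set has been generated for this CUDA runtime");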
return func_ptr(plan, nx, type, batch, workSize); +} + +cufftResult CUFFTAPI cufftMakePlan2d(cufftHandle plan, int nx, int ny, + cufftType type, size_t *workSize) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, int, int, cufftType, size_t *); + static auto func_ptr = LoadSymbol("cufftMakePlan2d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, nx, ny, type, workSize); +} + +cufftResult CUFFTAPI cufftMakePlan3d(cufftHandle plan, int nx, int ny, int nz, + cufftType type, size_t *workSize) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, int, int, int, cufftType, size_t *); + static auto func_ptr = LoadSymbol("cufftMakePlan3d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, nx, ny, nz, type, workSize); +} + +cufftResult CUFFTAPI cufftMakePlanMany(cufftHandle plan, int rank, int *n, + int *inembed, int istride, int idist, + int *onembed, int ostride, int odist, + cufftType type, int batch, + size_t *workSize) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, int, int *, int *, int, int, int *, + int, int, cufftType, int, size_t *); + static auto func_ptr = LoadSymbol("cufftMakePlanMany"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, rank, n, inembed, istride, idist, onembed, ostride, + odist, type, batch, workSize); +} + +cufftResult CUFFTAPI cufftMakePlanMany64( + cufftHandle plan, int rank, long long int *n, long long int *inembed, + long long int istride, long long int idist, long long int *onembed, + long long int ostride, long long int odist, cufftType type, + long long int batch, size_t *workSize) { + using FuncPtr = cufftResult(CUFFTAPI *)( + cufftHandle, int, long long *, long long *, long long, long long, + long long *, long long, long long, cufftType, long long, size_t *); + static auto func_ptr = LoadSymbol("cufftMakePlanMany64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, rank, n, inembed, istride, idist, onembed, ostride, + odist, type, batch, workSize); +} + +cufftResult CUFFTAPI cufftGetSizeMany64( + cufftHandle plan, int rank, long long int *n, long long int *inembed, + long long int istride, long long int idist, long long int *onembed, + long long int ostride, long long int odist, cufftType type, + long long int batch, size_t *workSize) { + using FuncPtr = cufftResult(CUFFTAPI *)( + cufftHandle, int, long long *, long long *, long long, long long, + long long *, long long, long long, cufftType, long long, size_t *); + static auto func_ptr = LoadSymbol("cufftGetSizeMany64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, rank, n, inembed, istride, idist, onembed, ostride, + odist, type, batch, workSize); +} + +cufftResult CUFFTAPI cufftEstimate1d(int nx, cufftType type, int batch, + size_t *workSize) { + using FuncPtr = cufftResult(CUFFTAPI *)(int, cufftType, int, size_t *); + static auto func_ptr = LoadSymbol("cufftEstimate1d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(nx, type, batch, workSize); +} + +cufftResult CUFFTAPI cufftEstimate2d(int nx, int ny, cufftType type, + size_t *workSize) { + using FuncPtr = cufftResult(CUFFTAPI *)(int, int, cufftType, size_t *); + static auto func_ptr = LoadSymbol("cufftEstimate2d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(nx, ny, type, workSize); +} + +cufftResult CUFFTAPI cufftEstimate3d(int nx, int ny, int nz, cufftType type, + size_t *workSize) { + using FuncPtr = cufftResult(CUFFTAPI *)(int, int, int, cufftType, 
size_t *); + static auto func_ptr = LoadSymbol("cufftEstimate3d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(nx, ny, nz, type, workSize); +} + +cufftResult CUFFTAPI cufftEstimateMany(int rank, int *n, int *inembed, + int istride, int idist, int *onembed, + int ostride, int odist, cufftType type, + int batch, size_t *workSize) { + using FuncPtr = cufftResult(CUFFTAPI *)(int, int *, int *, int, int, int *, + int, int, cufftType, int, size_t *); + static auto func_ptr = LoadSymbol("cufftEstimateMany"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rank, n, inembed, istride, idist, onembed, ostride, odist, + type, batch, workSize); +} + +cufftResult CUFFTAPI cufftCreate(cufftHandle *handle) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle *); + static auto func_ptr = LoadSymbol("cufftCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cufftResult CUFFTAPI cufftGetSize1d(cufftHandle handle, int nx, cufftType type, + int batch, size_t *workSize) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, int, cufftType, int, size_t *); + static auto func_ptr = LoadSymbol("cufftGetSize1d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nx, type, batch, workSize); +} + +cufftResult CUFFTAPI cufftGetSize2d(cufftHandle handle, int nx, int ny, + cufftType type, size_t *workSize) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, int, int, cufftType, size_t *); + static auto func_ptr = LoadSymbol("cufftGetSize2d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nx, ny, type, workSize); +} + +cufftResult CUFFTAPI cufftGetSize3d(cufftHandle handle, int nx, int ny, int nz, + cufftType type, size_t *workSize) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, int, int, int, cufftType, size_t *); + static auto func_ptr = LoadSymbol("cufftGetSize3d"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nx, ny, nz, type, workSize); +} + +cufftResult CUFFTAPI cufftGetSizeMany(cufftHandle handle, int rank, int *n, + int *inembed, int istride, int idist, + int *onembed, int ostride, int odist, + cufftType type, int batch, + size_t *workArea) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, int, int *, int *, int, int, int *, + int, int, cufftType, int, size_t *); + static auto func_ptr = LoadSymbol("cufftGetSizeMany"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rank, n, inembed, istride, idist, onembed, ostride, + odist, type, batch, workArea); +} + +cufftResult CUFFTAPI cufftGetSize(cufftHandle handle, size_t *workSize) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle, size_t *); + static auto func_ptr = LoadSymbol("cufftGetSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, workSize); +} + +cufftResult CUFFTAPI cufftSetWorkArea(cufftHandle plan, void *workArea) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle, void *); + static auto func_ptr = LoadSymbol("cufftSetWorkArea"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, workArea); +} + +cufftResult CUFFTAPI cufftSetAutoAllocation(cufftHandle plan, + int autoAllocate) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle, int); + static auto func_ptr = LoadSymbol("cufftSetAutoAllocation"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, autoAllocate); +} + +cufftResult CUFFTAPI cufftExecC2C(cufftHandle plan, cufftComplex 
*idata, + cufftComplex *odata, int direction) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, cufftComplex *, cufftComplex *, int); + static auto func_ptr = LoadSymbol("cufftExecC2C"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, idata, odata, direction); +} + +cufftResult CUFFTAPI cufftExecR2C(cufftHandle plan, cufftReal *idata, + cufftComplex *odata) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, cufftReal *, cufftComplex *); + static auto func_ptr = LoadSymbol("cufftExecR2C"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, idata, odata); +} + +cufftResult CUFFTAPI cufftExecC2R(cufftHandle plan, cufftComplex *idata, + cufftReal *odata) { + using FuncPtr = + cufftResult(CUFFTAPI *)(cufftHandle, cufftComplex *, cufftReal *); + static auto func_ptr = LoadSymbol("cufftExecC2R"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, idata, odata); +} + +cufftResult CUFFTAPI cufftExecZ2Z(cufftHandle plan, cufftDoubleComplex *idata, + cufftDoubleComplex *odata, int direction) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle, cufftDoubleComplex *, + cufftDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cufftExecZ2Z"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, idata, odata, direction); +} + +cufftResult CUFFTAPI cufftExecD2Z(cufftHandle plan, cufftDoubleReal *idata, + cufftDoubleComplex *odata) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle, cufftDoubleReal *, + cufftDoubleComplex *); + static auto func_ptr = LoadSymbol("cufftExecD2Z"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, idata, odata); +} + +cufftResult CUFFTAPI cufftExecZ2D(cufftHandle plan, cufftDoubleComplex *idata, + cufftDoubleReal *odata) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle, cufftDoubleComplex *, + cufftDoubleReal *); + static auto func_ptr = LoadSymbol("cufftExecZ2D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, idata, odata); +} + +cufftResult CUFFTAPI cufftSetStream(cufftHandle plan, cudaStream_t stream) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle, cudaStream_t); + static auto func_ptr = LoadSymbol("cufftSetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, stream); +} + +cufftResult CUFFTAPI cufftSetCompatibilityMode(cufftHandle plan, + cufftCompatibility mode) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle, cufftCompatibility); + static auto func_ptr = LoadSymbol("cufftSetCompatibilityMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan, mode); +} + +cufftResult CUFFTAPI cufftDestroy(cufftHandle plan) { + using FuncPtr = cufftResult(CUFFTAPI *)(cufftHandle); + static auto func_ptr = LoadSymbol("cufftDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(plan); +} + +cufftResult CUFFTAPI cufftGetVersion(int *version) { + using FuncPtr = cufftResult(CUFFTAPI *)(int *); + static auto func_ptr = LoadSymbol("cufftGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(version); +} + +cufftResult CUFFTAPI cufftGetProperty(libraryPropertyType type, int *value) { + using FuncPtr = cufftResult(CUFFTAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol("cufftGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cufft_stub.cc 
b/tensorflow/stream_executor/cuda/cufft_stub.cc index 68d7ec7634d..58af370eafd 100644 --- a/tensorflow/stream_executor/cuda/cufft_stub.cc +++ b/tensorflow/stream_executor/cuda/cufft_stub.cc @@ -47,4 +47,9 @@ T LoadSymbol(const char* symbol_name) { cufftResult GetSymbolNotFoundError() { return CUFFT_INTERNAL_ERROR; } } // namespace +#if CUFFT_VERSION < 10000 +#include "tensorflow/stream_executor/cuda/cufft_9_0.inc" +#else +// All CUDA-10+ implementations use the same API. #include "tensorflow/stream_executor/cuda/cufft_10_0.inc" +#endif diff --git a/tensorflow/stream_executor/cuda/cusolver_dense_10_2.inc b/tensorflow/stream_executor/cuda/cusolver_dense_10_2.inc new file mode 100644 index 00000000000..50fa2464fd6 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cusolver_dense_10_2.inc @@ -0,0 +1,3677 @@ +// Auto-generated, do not edit. + +extern "C" { + +cusolverStatus_t CUSOLVERAPI cusolverGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol("cusolverGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +cusolverStatus_t CUSOLVERAPI cusolverGetVersion(int *version) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(int *); + static auto func_ptr = LoadSymbol("cusolverGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(version); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreate(cusolverDnHandle_t *handle) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t *); + static auto func_ptr = LoadSymbol("cusolverDnCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroy(cusolverDnHandle_t handle) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t); + static auto func_ptr = LoadSymbol("cusolverDnDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSetStream(cusolverDnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cusolverDnSetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnGetStream(cusolverDnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cudaStream_t *); + static auto func_ptr = LoadSymbol("cusolverDnGetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSParamsCreate(cusolverDnIRSParams_t *params_ptr) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t *); + static auto func_ptr = LoadSymbol("cusolverDnIRSParamsCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params_ptr); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSParamsDestroy(cusolverDnIRSParams_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t); + static auto func_ptr = LoadSymbol("cusolverDnIRSParamsDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetTol( + cusolverDnIRSParams_t params, cudaDataType data_type, double val) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + 
cudaDataType, double); + static auto func_ptr = LoadSymbol("cusolverDnIRSParamsSetTol"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, data_type, val); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetTolInner( + cusolverDnIRSParams_t params, cudaDataType data_type, double val) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cudaDataType, double); + static auto func_ptr = LoadSymbol("cusolverDnIRSParamsSetTolInner"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, data_type, val); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverPrecisions( + cusolverDnIRSParams_t params, cudaDataType solver_main_precision, + cudaDataType solver_lowest_precision) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cudaDataType, cudaDataType); + static auto func_ptr = + LoadSymbol("cusolverDnIRSParamsSetSolverPrecisions"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, solver_main_precision, solver_lowest_precision); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetRefinementSolver( + cusolverDnIRSParams_t params, cusolverIRSRefinement_t refinement_solver) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cusolverIRSRefinement_t); + static auto func_ptr = + LoadSymbol("cusolverDnIRSParamsSetRefinementSolver"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, refinement_solver); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetMaxIters( + cusolverDnIRSParams_t params, cusolver_int_t maxiters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cusolver_int_t); + static auto func_ptr = LoadSymbol("cusolverDnIRSParamsSetMaxIters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, maxiters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetMaxItersInner( + cusolverDnIRSParams_t params, cusolver_int_t maxiters_inner) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cusolver_int_t); + static auto func_ptr = + LoadSymbol("cusolverDnIRSParamsSetMaxItersInner"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, maxiters_inner); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsGetNiters( + cusolverDnIRSParams_t params, cusolver_int_t *niters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnIRSParamsGetNiters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, niters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsGetOuterNiters( + cusolverDnIRSParams_t params, cusolver_int_t *outer_niters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cusolver_int_t *); + static auto func_ptr = + LoadSymbol("cusolverDnIRSParamsGetOuterNiters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, outer_niters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsGetMaxIters( + cusolverDnIRSParams_t params, cusolver_int_t *maxiters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnIRSParamsGetMaxIters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, maxiters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverMainPrecision( + cusolverDnIRSParams_t params, cudaDataType 
solver_main_precision) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cudaDataType); + static auto func_ptr = + LoadSymbol("cusolverDnIRSParamsSetSolverMainPrecision"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, solver_main_precision); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverLowestPrecision( + cusolverDnIRSParams_t params, cudaDataType solver_lowest_precision) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cudaDataType); + static auto func_ptr = + LoadSymbol("cusolverDnIRSParamsSetSolverLowestPrecision"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, solver_lowest_precision); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosDestroy( + cusolverDnIRSParams_t params, cusolverDnIRSInfos_t infos) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cusolverDnIRSInfos_t); + static auto func_ptr = LoadSymbol("cusolverDnIRSInfosDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, infos); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosCreate( + cusolverDnIRSParams_t params, cusolverDnIRSInfos_t *infos_ptr) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cusolverDnIRSInfos_t *); + static auto func_ptr = LoadSymbol("cusolverDnIRSInfosCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, infos_ptr); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetNiters( + cusolverDnIRSParams_t params, cusolverDnIRSInfos_t infos, + cusolver_int_t *niters) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnIRSParams_t, cusolverDnIRSInfos_t, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnIRSInfosGetNiters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, infos, niters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetOuterNiters( + cusolverDnIRSParams_t params, cusolverDnIRSInfos_t infos, + cusolver_int_t *outer_niters) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnIRSParams_t, cusolverDnIRSInfos_t, cusolver_int_t *); + static auto func_ptr = + LoadSymbol("cusolverDnIRSInfosGetOuterNiters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, infos, outer_niters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetMaxIters( + cusolverDnIRSParams_t params, cusolverDnIRSInfos_t infos, + cusolver_int_t *maxiters) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnIRSParams_t, cusolverDnIRSInfos_t, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnIRSInfosGetMaxIters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, infos, maxiters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosRequestResidual( + cusolverDnIRSParams_t params, cusolverDnIRSInfos_t infos) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cusolverDnIRSInfos_t); + static auto func_ptr = + LoadSymbol("cusolverDnIRSInfosRequestResidual"); + if (!func_ptr)
return GetSymbolNotFoundError(); + return func_ptr(params, infos); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetResidualHistory( + cusolverDnIRSParams_t params, cusolverDnIRSInfos_t infos, + void **residual_history) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnIRSParams_t, cusolverDnIRSInfos_t, void **); + static auto func_ptr = + LoadSymbol("cusolverDnIRSInfosGetResidualHistory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, infos, residual_history); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZZgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnZZgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZCgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnZCgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZKgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnZKgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCCgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, 
cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnCCgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCKgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnCKgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDDgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnDDgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDSgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnDSgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDHgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnDHgesv"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSSgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnSSgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSHgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnSHgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZZgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnZZgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZCgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnZCgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZKgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + 
cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnZKgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCCgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnCCgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCKgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnCKgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDDgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnDDgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDSgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnDSgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, 
dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDHgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnDHgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSSgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnSSgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSHgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnSHgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgesv( + cusolverDnHandle_t handle, cusolverDnIRSParams_t gesv_irs_params, + cusolverDnIRSInfos_t gesv_irs_infos, cudaDataType inout_data_type, + cusolver_int_t n, cusolver_int_t nrhs, void *dA, cusolver_int_t ldda, + cusolver_int_t *dipiv, void *dB, cusolver_int_t lddb, void *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *niters, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnIRSParams_t, cusolverDnIRSInfos_t, + cudaDataType, cusolver_int_t, cusolver_int_t, void *, cusolver_int_t, + cusolver_int_t *, void *, cusolver_int_t, void *, cusolver_int_t, void *, + size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol("cusolverDnIRSXgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, gesv_irs_params, gesv_irs_infos, inout_data_type, n, + nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, niters, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgesv_bufferSize( + cusolverDnHandle_t handle, cusolverDnIRSParams_t params, cusolver_int_t n, + cusolver_int_t nrhs, size_t *lwork_bytes) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI 
*)(cusolverDnHandle_t, cusolverDnIRSParams_t, + cusolver_int_t, cusolver_int_t, size_t *); + static auto func_ptr = LoadSymbol("cusolverDnIRSXgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, n, nrhs, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, float *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnDpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, double *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, + float *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, + double *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + cuComplex *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, 
Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const float *A, int lda, + float *B, int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const float *, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const double *A, + int lda, double *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const double *, int, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const cuComplex *A, + int lda, cuComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuComplex *, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrfBatched(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, float *Aarray[], + int lda, int *infoArray, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *[], int, int *, int); + static auto func_ptr = LoadSymbol("cusolverDnSpotrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, Aarray, lda, infoArray, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrfBatched(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, double *Aarray[], + int lda, int *infoArray, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *[], int, int *, int); + static auto func_ptr = LoadSymbol("cusolverDnDpotrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
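The cusolverDn Cholesky wrappers above (the potrf_bufferSize / potrf / potrs family) are driven with the usual query-allocate-factorize sequence. Below is a minimal caller-side sketch of that sequence, assuming a valid cusolverDn handle and device-resident buffers; it illustrates the call order only and is not code from this patch.

#include <cuda_runtime.h>
#include <cusolverDn.h>

// Sketch: single-precision Cholesky factorization through cusolverDn.
// dA is an n x n device matrix (leading dimension lda); dInfo is a device int.
cusolverStatus_t CholeskyFactorizeSketch(cusolverDnHandle_t handle, int n,
                                         float* dA, int lda, int* dInfo) {
  int lwork = 0;  // workspace size in elements, reported by the query call
  cusolverStatus_t st = cusolverDnSpotrf_bufferSize(
      handle, CUBLAS_FILL_MODE_LOWER, n, dA, lda, &lwork);
  if (st != CUSOLVER_STATUS_SUCCESS) return st;

  float* dWork = nullptr;
  if (cudaMalloc(&dWork, sizeof(float) * lwork) != cudaSuccess)
    return CUSOLVER_STATUS_ALLOC_FAILED;

  st = cusolverDnSpotrf(handle, CUBLAS_FILL_MODE_LOWER, n, dA, lda, dWork,
                        lwork, dInfo);
  cudaFree(dWork);
  return st;  // a nonzero *dInfo on the device flags a non-positive-definite A
}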
return func_ptr(handle, uplo, n, Aarray, lda, infoArray, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotrfBatched(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, cuComplex *Aarray[], + int lda, int *infoArray, + int batchSize) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + cuComplex *[], int, int *, int); + static auto func_ptr = LoadSymbol("cusolverDnCpotrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, Aarray, lda, infoArray, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotrfBatched( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + cuDoubleComplex *Aarray[], int lda, int *infoArray, int batchSize) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + cuDoubleComplex *[], int, int *, int); + static auto func_ptr = LoadSymbol("cusolverDnZpotrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, Aarray, lda, infoArray, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrsBatched( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, /* only support rhs = 1*/ + float *A[], int lda, float *B[], int ldb, int *d_info, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, float *[], int, float *[], + int, int *, int); + static auto func_ptr = LoadSymbol("cusolverDnSpotrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, d_info, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrsBatched( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, /* only support rhs = 1*/ + double *A[], int lda, double *B[], int ldb, int *d_info, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, double *[], int, + double *[], int, int *, int); + static auto func_ptr = LoadSymbol("cusolverDnDpotrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, d_info, batchSize); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCpotrsBatched(cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, /* only support rhs = 1*/ + cuComplex *A[], int lda, cuComplex *B[], int ldb, + int *d_info, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, cuComplex *[], int, + cuComplex *[], int, int *, int); + static auto func_ptr = LoadSymbol("cusolverDnCpotrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, d_info, batchSize); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZpotrsBatched(cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, /* only support rhs = 1*/ + cuDoubleComplex *A[], int lda, cuDoubleComplex *B[], + int ldb, int *d_info, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, cuDoubleComplex *[], int, + cuDoubleComplex *[], int, int *, int); + static auto func_ptr = LoadSymbol("cusolverDnZpotrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, d_info, batchSize); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSpotri_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, float *A, 
int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSpotri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnDpotri_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDpotri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCpotri_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCpotri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZpotri_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZpotri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, float *work, + int lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSpotri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, double *work, + int lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDpotri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + cuComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCpotri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZpotri"); + if 
(!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnStrtri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, float *, int, + int *); + static auto func_ptr = LoadSymbol("cusolverDnStrtri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDtrtri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, double *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDtrtri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCtrtri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCtrtri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZtrtri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZtrtri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnStrtri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + cublasDiagType_t diag, int n, + float *A, int lda, float *work, + int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, float *, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnStrtri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDtrtri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + cublasDiagType_t diag, int n, + double *A, int lda, double *work, + int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, double *, + int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDtrtri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCtrtri( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, cuComplex *A, int lda, cuComplex *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, cuComplex *, + int, cuComplex *, int, int *); + static 
auto func_ptr = LoadSymbol("cusolverDnCtrtri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZtrtri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + cublasDiagType_t diag, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, + cuDoubleComplex *, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZtrtri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSlauum_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSlauum_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnDlauum_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDlauum_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnClauum_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnClauum_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZlauum_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZlauum_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSlauum(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, float *work, + int lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSlauum"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDlauum(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, double *work, + int lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDlauum"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI 
cusolverDnClauum(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + cuComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnClauum"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZlauum(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZlauum"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, float *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, double *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgetrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgetrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuDoubleComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrf(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *Workspace, int *devIpiv, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrf(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *Workspace, int *devIpiv, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, double *, int, double *, int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, 
lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgetrf(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + cuComplex *Workspace, + int *devIpiv, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, cuComplex *, + int, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgetrf(cusolverDnHandle_t handle, int m, + int n, cuDoubleComplex *A, + int lda, + cuDoubleComplex *Workspace, + int *devIpiv, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, cuDoubleComplex *, + int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSlaswp(cusolverDnHandle_t handle, int n, + float *A, int lda, int k1, int k2, + const int *devIpiv, int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, float *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol("cusolverDnSlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDlaswp(cusolverDnHandle_t handle, int n, + double *A, int lda, int k1, + int k2, const int *devIpiv, + int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, double *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol("cusolverDnDlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnClaswp(cusolverDnHandle_t handle, int n, + cuComplex *A, int lda, int k1, + int k2, const int *devIpiv, + int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, cuComplex *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol("cusolverDnClaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZlaswp(cusolverDnHandle_t handle, int n, + cuDoubleComplex *A, int lda, + int k1, int k2, + const int *devIpiv, int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + cuDoubleComplex *, int, int, + int, const int *, int); + static auto func_ptr = LoadSymbol("cusolverDnZlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const float *A, int lda, + const int *devIpiv, float *B, + int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const float *, int, + const int *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const double *A, + int lda, 
const int *devIpiv, + double *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const double *, int, + const int *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const cuComplex *A, + int lda, const int *devIpiv, + cuComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const cuComplex *, int, + const int *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgetrs( + cusolverDnHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuDoubleComplex *A, int lda, const int *devIpiv, cuDoubleComplex *B, + int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, + int, const int *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgeqrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgeqrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgeqrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgeqrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgeqrf(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *TAU, float *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, float *, int, int *); + static auto 
func_ptr = LoadSymbol("cusolverDnSgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgeqrf(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *TAU, double *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, double *, + int, double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgeqrf(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + cuComplex *TAU, + cuComplex *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, cuComplex *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgeqrf(cusolverDnHandle_t handle, int m, + int n, cuDoubleComplex *A, + int lda, cuDoubleComplex *TAU, + cuDoubleComplex *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, cuDoubleComplex *, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const float *A, int lda, + const float *tau, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const double *A, int lda, + const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const cuComplex *A, int lda, + const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + int, const cuComplex *, int, + const cuComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int *); + static auto func_ptr = 
LoadSymbol("cusolverDnZungqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgqr(cusolverDnHandle_t handle, int m, + int n, int k, float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, float *, int, const float *, float *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgqr(cusolverDnHandle_t handle, int m, + int n, int k, double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, double *, int, const double *, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungqr(cusolverDnHandle_t handle, int m, + int n, int k, cuComplex *A, + int lda, const cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, cuComplex *, int, const cuComplex *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungqr( + cusolverDnHandle_t handle, int m, int n, int k, cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const float *A, int lda, const float *tau, + const float *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const float *, int, const float *, const float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSormqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const double *A, int lda, const double *tau, + const double *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const double *, int, const double *, const double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDormqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI 
cusolverDnCunmqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuComplex *A, int lda, const cuComplex *tau, + const cuComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuComplex *, int, const cuComplex *, const cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCunmqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, const cuDoubleComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZunmqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const float *A, int lda, const float *tau, float *C, + int ldc, float *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const float *, int, const float *, float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSormqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const double *A, int lda, const double *tau, double *C, + int ldc, double *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const double *, int, const double *, double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDormqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCunmqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuComplex *A, int lda, const cuComplex *tau, + cuComplex *C, int ldc, cuComplex *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuComplex *, int, const cuComplex *, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCunmqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, 
cuDoubleComplex *C, int ldc, + cuDoubleComplex *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZunmqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrf_bufferSize( + cusolverDnHandle_t handle, int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrf_bufferSize( + cusolverDnHandle_t handle, int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytrf_bufferSize( + cusolverDnHandle_t handle, int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrf_bufferSize( + cusolverDnHandle_t handle, int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, int *ipiv, + float *work, int lwork, + int *info) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, int *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, int *ipiv, + double *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *, double *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, int *ipiv, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCsytrf"); + if (!func_ptr) 
return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + int *ipiv, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrs_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const float *A, int lda, const int *ipiv, float *B, int ldb, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const float *, int, + const int *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrs_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrs_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const double *A, int lda, const int *ipiv, double *B, int ldb, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const double *, int, + const int *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrs_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytrs_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const cuComplex *A, int lda, const int *ipiv, cuComplex *B, int ldb, + int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuComplex *, int, + const int *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCsytrs_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrs_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const cuDoubleComplex *A, int lda, const int *ipiv, cuDoubleComplex *B, + int ldb, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + int, const int *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZsytrs_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const float *A, int lda, + const int *ipiv, float *B, + int ldb, float *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const float *, int, + const int *, float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, work, lwork, + info); +} + +cusolverStatus_t 
CUSOLVERAPI cusolverDnDsytrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const double *A, + int lda, const int *ipiv, + double *B, int ldb, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const double *, int, + const int *, double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCsytrs(cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, const cuComplex *A, int lda, const int *ipiv, + cuComplex *B, int ldb, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuComplex *, int, + const int *, cuComplex *, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCsytrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrs( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const cuDoubleComplex *A, int lda, const int *ipiv, cuDoubleComplex *B, + int ldb, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + int, const int *, cuDoubleComplex *, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZsytrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float *A, int lda, + const int *ipiv, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, const int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double *A, int lda, + const int *ipiv, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, const int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuComplex *A, + int lda, const int *ipiv, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + cuComplex *, int, const int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCsytri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, const int *ipiv, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, 
cublasFillMode_t, int, cuDoubleComplex *, int, + const int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZsytri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, + const int *ipiv, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, const int *, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, + const int *ipiv, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, const int *, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + const int *ipiv, cuComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, const int *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCsytri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytri( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, const int *ipiv, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + const int *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZsytri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgebrd_bufferSize( + cusolverDnHandle_t handle, 
int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgebrd(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *D, float *E, float *TAUQ, + float *TAUP, float *Work, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, float *, float *, + float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgebrd(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *D, double *E, + double *TAUQ, double *TAUP, + double *Work, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, double *, int, double *, double *, double *, + double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgebrd(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + float *D, float *E, + cuComplex *TAUQ, cuComplex *TAUP, + cuComplex *Work, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuComplex *, int, float *, float *, + cuComplex *, cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgebrd( + cusolverDnHandle_t handle, int m, int n, cuDoubleComplex *A, int lda, + double *D, double *E, cuDoubleComplex *TAUQ, cuDoubleComplex *TAUP, + cuDoubleComplex *Work, int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, double *, double *, + cuDoubleComplex *, cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const float *A, int lda, const float *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const float *, int, + const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const double *A, int lda, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const double *, int, + const double *, int *); + static auto func_ptr = 
LoadSymbol("cusolverDnDorgbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const cuComplex *A, int lda, const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const cuComplex *, + int, const cuComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, float *, int, + const float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, double *, int, + const double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, cuComplex *A, + int lda, const cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, cuComplex *, int, + const cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZungbr(cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, + int k, cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, + cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, cuDoubleComplex *, + int, const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + 
+cusolverStatus_t CUSOLVERAPI cusolverDnSsytrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const float *A, + int lda, const float *d, const float *e, const float *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const float *, int, + const float *, const float *, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const double *A, + int lda, const double *d, const double *e, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const double *, int, + const double *, const double *, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChetrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const cuComplex *A, + int lda, const float *d, const float *e, const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuComplex *, int, + const float *, const float *, const cuComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnChetrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhetrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, const double *d, const double *e, + const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + const double *, const double *, const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhetrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrd(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, float *d, + float *e, float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, float *, float *, + float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrd( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double *A, int lda, + double *d, double *e, double *tau, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, double *, + double *, double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChetrd(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, 
int lda, float *d, + float *e, cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, float *, + float *, cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnChetrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhetrd( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, double *d, double *e, cuDoubleComplex *tau, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + double *, double *, cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhetrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const float *A, + int lda, const float *tau, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const double *A, + int lda, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const cuComplex *A, + int lda, const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuComplex *, int, + const cuComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, const float *, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, const double *, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungtr( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuComplex *A, + int lda, const cuComplex *tau, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, + const cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, + cuDoubleComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const float *A, int lda, + const float *tau, const float *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const float *, int, const float *, const float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSormtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const double *A, int lda, + const double *tau, const double *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const double *, int, const double *, const double *, int, + int *); + static auto func_ptr = LoadSymbol("cusolverDnDormtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCunmtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const cuComplex *A, int lda, + const cuComplex *tau, const cuComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const cuComplex *, int, const cuComplex *, const cuComplex *, + int, int *); + static auto func_ptr 
= LoadSymbol("cusolverDnCunmtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, const cuDoubleComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const cuDoubleComplex *, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZunmtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, float *A, int lda, float *tau, + float *C, int ldc, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, float *, int, float *, float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSormtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, double *A, int lda, double *tau, + double *C, int ldc, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, double *, int, double *, double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDormtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCunmtr(cusolverDnHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, cublasOperation_t trans, int m, int n, + cuComplex *A, int lda, cuComplex *tau, cuComplex *C, int ldc, + cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, cuComplex *, int, cuComplex *, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCunmtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, cuDoubleComplex *A, int lda, + cuDoubleComplex *tau, cuDoubleComplex *C, int ldc, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, cuDoubleComplex *, int, cuDoubleComplex *, cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZunmtr"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvd( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, float *A, int lda, float *S, float *U, int ldu, float *VT, int ldvt, + float *work, int lwork, float *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, float *, int, + float *, float *, int, float *, int, float *, int, float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvd( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, double *A, int lda, double *S, double *U, int ldu, double *VT, + int ldvt, double *work, int lwork, double *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, double *, int, + double *, double *, int, double *, int, double *, int, double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgesvd(cusolverDnHandle_t handle, signed char jobu, signed char jobvt, + int m, int n, cuComplex *A, int lda, float *S, cuComplex *U, + int ldu, cuComplex *VT, int ldvt, cuComplex *work, int lwork, + float *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, cuComplex *, int, + float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, float *, + int *); + static auto func_ptr = LoadSymbol("cusolverDnCgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, 
jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgesvd(cusolverDnHandle_t handle, signed char jobu, signed char jobvt, + int m, int n, cuDoubleComplex *A, int lda, double *S, + cuDoubleComplex *U, int ldu, cuDoubleComplex *VT, int ldvt, + cuDoubleComplex *work, int lwork, double *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, cuDoubleComplex *, + int, double *, cuDoubleComplex *, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsyevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsyevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCheevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZheevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevd( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, float *A, int lda, float *W, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsyevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, 
lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevd( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, double *A, int lda, double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsyevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevd(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float *W, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCheevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevd(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + double *W, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZheevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, const float *A, int lda, float vl, float vu, + int il, int iu, int *meig, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, const float *, int, float, float, int, int, int *, + const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsyevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, const double *A, int lda, double vl, + double vu, int il, int iu, int *meig, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, const double *, int, double, double, int, int, + int *, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsyevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, const cuComplex *A, int lda, float vl, + float vu, int il, int iu, int *meig, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, 
int, const cuComplex *, int, float, float, int, int, + int *, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCheevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, const cuDoubleComplex *A, int lda, double vl, + double vu, int il, int iu, int *meig, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, const cuDoubleComplex *, int, double, double, int, + int, int *, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZheevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevdx( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, float *A, int lda, float vl, float vu, int il, + int iu, int *meig, float *W, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, float *, int, float, float, int, int, int *, + float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsyevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevdx( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, double *A, int lda, double vl, double vu, + int il, int iu, int *meig, double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, double *, int, double, double, int, int, int *, + double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsyevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCheevdx(cusolverDnHandle_t handle, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float vl, float vu, int il, int iu, + int *meig, float *W, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, cuComplex *, int, float, float, int, int, int *, + float *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCheevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevdx( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, cuDoubleComplex *A, int lda, double vl, + double vu, int il, int iu, int *meig, double *W, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, 
cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, cuDoubleComplex *, int, double, double, int, int, + int *, double *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZheevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, const float *A, + int lda, const float *B, int ldb, float vl, float vu, int il, int iu, + int *meig, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, const float *, int, + const float *, int, float, float, int, int, int *, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsygvdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, const double *A, + int lda, const double *B, int ldb, double vl, double vu, int il, int iu, + int *meig, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, const double *, int, + const double *, int, double, double, int, int, int *, const double *, + int *); + static auto func_ptr = LoadSymbol("cusolverDnDsygvdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, const cuComplex *A, + int lda, const cuComplex *B, int ldb, float vl, float vu, int il, int iu, + int *meig, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, const cuComplex *, int, + const cuComplex *, int, float, float, int, int, int *, const float *, + int *); + static auto func_ptr = LoadSymbol("cusolverDnChegvdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + double vl, double vu, int il, int iu, int *meig, const double *W, + int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, double, double, int, int, int *, + const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhegvdx_bufferSize"); + if (!func_ptr) 
return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvdx( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, float *A, int lda, + float *B, int ldb, float vl, float vu, int il, int iu, int *meig, float *W, + float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, float *, int, float *, int, + float, float, int, int, int *, float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsygvdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvdx( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, double *A, int lda, + double *B, int ldb, double vl, double vu, int il, int iu, int *meig, + double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, double *, int, double *, int, + double, double, int, int, int *, double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsygvdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvdx( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, cuComplex *A, + int lda, cuComplex *B, int ldb, float vl, float vu, int il, int iu, + int *meig, float *W, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, float, float, int, int, int *, float *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnChegvdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvdx( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, cuDoubleComplex *B, int ldb, double vl, double vu, int il, int iu, + int *meig, double *W, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, double, double, int, int, int *, double *, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhegvdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t 
itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const float *A, int lda, const float *B, + int ldb, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const float *, int, const float *, int, + const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsygvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const double *A, int lda, const double *B, + int ldb, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const double *, int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsygvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuComplex *A, int lda, + const cuComplex *B, int ldb, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuComplex *, int, const cuComplex *, int, + const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnChegvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhegvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, float *A, int lda, float *B, int ldb, + float *W, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, float *, int, float *, int, float *, float *, int, + int *); + static auto func_ptr = LoadSymbol("cusolverDnSsygvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, double *A, int lda, double *B, int ldb, + double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, 
cusolverEigMode_t, + cublasFillMode_t, int, double *, int, double *, int, double *, double *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsygvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, cuComplex *A, int lda, cuComplex *B, int ldb, + float *W, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuComplex *, int, cuComplex *, int, float *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnChegvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZhegvd(cusolverDnHandle_t handle, cusolverEigType_t itype, + cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb, + double *W, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhegvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreateSyevjInfo(syevjInfo_t *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t *); + static auto func_ptr = LoadSymbol("cusolverDnCreateSyevjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroySyevjInfo(syevjInfo_t info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDestroySyevjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetTolerance(syevjInfo_t info, + double tolerance) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, double); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjSetTolerance"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, tolerance); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetMaxSweeps(syevjInfo_t info, + int max_sweeps) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjSetMaxSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, max_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetSortEig(syevjInfo_t info, + int sort_eig) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjSetSortEig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, sort_eig); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjGetResidual( + cusolverDnHandle_t handle, syevjInfo_t info, double *residual) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, + syevjInfo_t, double *); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjGetResidual"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, info, residual); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjGetSweeps( + cusolverDnHandle_t handle, syevjInfo_t info, int *executed_sweeps) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, syevjInfo_t, int *); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjGetSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, executed_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnSsyevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnDsyevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const float *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnCheevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnZheevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, float *A, int lda, float *W, float *work, int lwork, int *info, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnSsyevjBatched"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, double *A, int lda, double *W, double *work, int lwork, int *info, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnDsyevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, float *W, cuComplex *work, int lwork, + int *info, syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnCheevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, double *W, cuDoubleComplex *work, + int lwork, int *info, syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *, + syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnZheevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSsyevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDsyevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + 
cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnCheevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZheevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + float *A, int lda, float *W, + float *work, int lwork, int *info, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSsyevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + double *A, int lda, double *W, + double *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDsyevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float *W, + cuComplex *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnCheevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, double *W, cuDoubleComplex *work, + int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *, + syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZheevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const float *A, int lda, const float 
*B, + int ldb, const float *W, int *lwork, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const float *, int, const float *, int, + const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSsygvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const double *A, int lda, const double *B, + int ldb, const double *W, int *lwork, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const double *, int, const double *, int, + const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDsygvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuComplex *A, int lda, + const cuComplex *B, int ldb, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuComplex *, int, const cuComplex *, int, + const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnChegvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZhegvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, float *A, int lda, float *B, int ldb, + float *W, float *work, int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, float *, int, float *, int, float *, float *, int, + int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSsygvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, double *A, int lda, double *B, int ldb, + double *W, double *work, int lwork, 
int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, double *, int, double *, int, double *, double *, + int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDsygvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, cuComplex *A, int lda, cuComplex *B, int ldb, + float *W, cuComplex *work, int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuComplex *, int, cuComplex *, int, float *, + cuComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnChegvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb, double *W, cuDoubleComplex *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZhegvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreateGesvdjInfo(gesvdjInfo_t *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t *); + static auto func_ptr = LoadSymbol("cusolverDnCreateGesvdjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroyGesvdjInfo(gesvdjInfo_t info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDestroyGesvdjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetTolerance(gesvdjInfo_t info, + double tolerance) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, double); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjSetTolerance"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, tolerance); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetMaxSweeps(gesvdjInfo_t info, + int max_sweeps) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjSetMaxSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, max_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetSortEig(gesvdjInfo_t info, + int sort_svd) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjSetSortEig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, sort_svd); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjGetResidual( + 
cusolverDnHandle_t handle, gesvdjInfo_t info, double *residual) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, + gesvdjInfo_t, double *); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjGetResidual"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, residual); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjGetSweeps( + cusolverDnHandle_t handle, gesvdjInfo_t info, int *executed_sweeps) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, gesvdjInfo_t, int *); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjGetSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, executed_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const float *A, int lda, const float *S, const float *U, int ldu, + const float *V, int ldv, int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const float *, int, + const float *, const float *, int, const float *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnSgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const double *A, int lda, const double *S, const double *U, int ldu, + const double *V, int ldv, int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const double *, int, + const double *, const double *, int, const double *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnDgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const cuComplex *A, int lda, const float *S, const cuComplex *U, int ldu, + const cuComplex *V, int ldv, int *lwork, gesvdjInfo_t params, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const cuComplex *, int, + const float *, const cuComplex *, int, const cuComplex *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnCgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const cuDoubleComplex *A, int lda, const double *S, + const cuDoubleComplex *U, int ldu, const cuDoubleComplex *V, int ldv, + int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const cuDoubleComplex *, + int, const double *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnZgesvdjBatched_bufferSize"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, float *A, + int lda, float *S, float *U, int ldu, float *V, int ldv, float *work, + int lwork, int *info, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, float *, int, float *, + float *, int, float *, int, float *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnSgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, double *A, + int lda, double *S, double *U, int ldu, double *V, int ldv, double *work, + int lwork, int *info, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, double *, int, double *, + double *, int, double *, int, double *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnDgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + cuComplex *A, int lda, float *S, cuComplex *U, int ldu, cuComplex *V, + int ldv, cuComplex *work, int lwork, int *info, gesvdjInfo_t params, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, cuComplex *, int, + float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnCgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + cuDoubleComplex *A, int lda, double *S, cuDoubleComplex *U, int ldu, + cuDoubleComplex *V, int ldv, cuDoubleComplex *work, int lwork, int *info, + gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnZgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const float *A, int lda, const float *S, const float *U, int ldu, + const float *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const float *, int, + const float *, const float *, int, const float *, int, int *, + gesvdjInfo_t); + static auto func_ptr = 
LoadSymbol("cusolverDnSgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const double *A, int lda, const double *S, const double *U, int ldu, + const double *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const double *, int, + const double *, const double *, int, const double *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const cuComplex *A, int lda, const float *S, const cuComplex *U, int ldu, + const cuComplex *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const cuComplex *, + int, const float *, const cuComplex *, int, const cuComplex *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnCgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const cuDoubleComplex *A, int lda, const double *S, + const cuDoubleComplex *U, int ldu, const cuDoubleComplex *V, int ldv, + int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, + const cuDoubleComplex *, int, const double *, const cuDoubleComplex *, + int, const cuDoubleComplex *, int, int *, gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + float *A, int lda, float *S, float *U, int ldu, float *V, int ldv, + float *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, float *, int, + float *, float *, int, float *, int, float *, int, int *, gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + double *A, int lda, double *S, double *U, int ldu, double *V, int ldv, + double *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, double *, int, + double *, double *, int, double *, int, double *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDgesvdj"); + if (!func_ptr) 
return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + cuComplex *A, int lda, float *S, cuComplex *U, int ldu, cuComplex *V, + int ldv, cuComplex *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, cuComplex *, int, + float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnCgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + cuDoubleComplex *A, int lda, double *S, cuDoubleComplex *U, int ldu, + cuDoubleComplex *V, int ldv, cuDoubleComplex *work, int lwork, int *info, + gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, cuDoubleComplex *, + int, double *, cuDoubleComplex *, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *, gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const float *d_A, int lda, long long int strideA, const float *d_S, + long long int strideS, const float *d_U, int ldu, long long int strideU, + const float *d_V, int ldv, long long int strideV, int *lwork, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const float *, int, + long long, const float *, long long, const float *, int, long long, + const float *, int, long long, int *, int); + static auto func_ptr = + LoadSymbol("cusolverDnSgesvdaStridedBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, lwork, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const double *d_A, int lda, long long int strideA, const double *d_S, + long long int strideS, const double *d_U, int ldu, long long int strideU, + const double *d_V, int ldv, long long int strideV, int *lwork, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const double *, int, + long long, const double *, long long, const double *, int, long long, + const double *, int, long long, int *, int); + static auto func_ptr = + LoadSymbol("cusolverDnDgesvdaStridedBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, lwork, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const 
cuComplex *d_A, int lda, long long int strideA, const float *d_S, + long long int strideS, const cuComplex *d_U, int ldu, long long int strideU, + const cuComplex *d_V, int ldv, long long int strideV, int *lwork, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const cuComplex *, + int, long long, const float *, long long, const cuComplex *, int, + long long, const cuComplex *, int, long long, int *, int); + static auto func_ptr = + LoadSymbol("cusolverDnCgesvdaStridedBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, lwork, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const cuDoubleComplex *d_A, int lda, long long int strideA, + const double *d_S, long long int strideS, const cuDoubleComplex *d_U, + int ldu, long long int strideU, const cuDoubleComplex *d_V, int ldv, + long long int strideV, int *lwork, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, + const cuDoubleComplex *, int, long long, const double *, long long, + const cuDoubleComplex *, int, long long, const cuDoubleComplex *, int, + long long, int *, int); + static auto func_ptr = + LoadSymbol("cusolverDnZgesvdaStridedBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, lwork, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdaStridedBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const float *d_A, int lda, long long int strideA, float *d_S, + long long int strideS, float *d_U, int ldu, long long int strideU, + float *d_V, int ldv, long long int strideV, float *d_work, int lwork, + int *d_info, double *h_R_nrmF, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const float *, int, + long long, float *, long long, float *, int, long long, float *, int, + long long, float *, int, int *, double *, int); + static auto func_ptr = LoadSymbol("cusolverDnSgesvdaStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, d_work, lwork, d_info, + h_R_nrmF, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdaStridedBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const double *d_A, int lda, long long int strideA, double *d_S, + long long int strideS, double *d_U, int ldu, long long int strideU, + double *d_V, int ldv, long long int strideV, double *d_work, int lwork, + int *d_info, double *h_R_nrmF, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const double *, int, + long long, double *, long long, double *, int, long long, double *, int, + long long, double *, int, int *, double *, int); + static auto func_ptr = LoadSymbol("cusolverDnDgesvdaStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, 
d_work, lwork, d_info, + h_R_nrmF, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdaStridedBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const cuComplex *d_A, int lda, long long int strideA, float *d_S, + long long int strideS, cuComplex *d_U, int ldu, long long int strideU, + cuComplex *d_V, int ldv, long long int strideV, cuComplex *d_work, + int lwork, int *d_info, double *h_R_nrmF, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const cuComplex *, + int, long long, float *, long long, cuComplex *, int, long long, + cuComplex *, int, long long, cuComplex *, int, int *, double *, int); + static auto func_ptr = LoadSymbol("cusolverDnCgesvdaStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, d_work, lwork, d_info, + h_R_nrmF, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdaStridedBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const cuDoubleComplex *d_A, int lda, long long int strideA, double *d_S, + long long int strideS, cuDoubleComplex *d_U, int ldu, long long int strideU, + cuDoubleComplex *d_V, int ldv, long long int strideV, + cuDoubleComplex *d_work, int lwork, int *d_info, double *h_R_nrmF, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, + const cuDoubleComplex *, int, long long, double *, long long, + cuDoubleComplex *, int, long long, cuDoubleComplex *, int, long long, + cuDoubleComplex *, int, int *, double *, int); + static auto func_ptr = LoadSymbol("cusolverDnZgesvdaStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, d_work, lwork, d_info, + h_R_nrmF, batchSize); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cusolver_dense_9_0.inc b/tensorflow/stream_executor/cuda/cusolver_dense_9_0.inc new file mode 100644 index 00000000000..fab9afff8e4 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cusolver_dense_9_0.inc @@ -0,0 +1,2185 @@ +// Auto-generated, do not edit. 
+ +extern "C" { + +cusolverStatus_t CUSOLVERAPI cusolverGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol("cusolverGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreate(cusolverDnHandle_t *handle) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t *); + static auto func_ptr = LoadSymbol("cusolverDnCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroy(cusolverDnHandle_t handle) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t); + static auto func_ptr = LoadSymbol("cusolverDnDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSetStream(cusolverDnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cusolverDnSetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnGetStream(cusolverDnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cudaStream_t *); + static auto func_ptr = LoadSymbol("cusolverDnGetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, float *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnDpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, double *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, + float *Workspace, int 
Lwork, + int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, + double *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + cuComplex *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const float *A, int lda, + float *B, int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const float *, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const double *A, + int lda, double *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const double *, int, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const cuComplex *A, + int lda, cuComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuComplex *, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, + const 
cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, float *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, double *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgetrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgetrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuDoubleComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrf(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *Workspace, int *devIpiv, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrf(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *Workspace, int *devIpiv, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, double *, int, double *, int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgetrf(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + cuComplex *Workspace, + int *devIpiv, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, cuComplex *, + int, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI 
cusolverDnZgetrf(cusolverDnHandle_t handle, int m, + int n, cuDoubleComplex *A, + int lda, + cuDoubleComplex *Workspace, + int *devIpiv, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, cuDoubleComplex *, + int *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSlaswp(cusolverDnHandle_t handle, int n, + float *A, int lda, int k1, int k2, + const int *devIpiv, int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, float *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol("cusolverDnSlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDlaswp(cusolverDnHandle_t handle, int n, + double *A, int lda, int k1, + int k2, const int *devIpiv, + int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, double *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol("cusolverDnDlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnClaswp(cusolverDnHandle_t handle, int n, + cuComplex *A, int lda, int k1, + int k2, const int *devIpiv, + int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, cuComplex *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol("cusolverDnClaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZlaswp(cusolverDnHandle_t handle, int n, + cuDoubleComplex *A, int lda, + int k1, int k2, + const int *devIpiv, int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + cuDoubleComplex *, int, int, + int, const int *, int); + static auto func_ptr = LoadSymbol("cusolverDnZlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const float *A, int lda, + const int *devIpiv, float *B, + int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const float *, int, + const int *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const double *A, + int lda, const int *devIpiv, + double *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const double *, int, + const int *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const cuComplex *A, + int 
lda, const int *devIpiv, + cuComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const cuComplex *, int, + const int *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgetrs( + cusolverDnHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuDoubleComplex *A, int lda, const int *devIpiv, cuDoubleComplex *B, + int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, + int, const int *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgeqrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgeqrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgeqrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgeqrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgeqrf(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *TAU, float *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgeqrf(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *TAU, double *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, double *, + int, double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgeqrf"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgeqrf(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + cuComplex *TAU, + cuComplex *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, cuComplex *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgeqrf(cusolverDnHandle_t handle, int m, + int n, cuDoubleComplex *A, + int lda, cuDoubleComplex *TAU, + cuDoubleComplex *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, cuDoubleComplex *, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const float *A, int lda, + const float *tau, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const double *A, int lda, + const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const cuComplex *A, int lda, + const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + int, const cuComplex *, int, + const cuComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgqr(cusolverDnHandle_t handle, int m, + int n, int k, float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, float *, int, const float *, float *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgqr"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgqr(cusolverDnHandle_t handle, int m, + int n, int k, double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, double *, int, const double *, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungqr(cusolverDnHandle_t handle, int m, + int n, int k, cuComplex *A, + int lda, const cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, cuComplex *, int, const cuComplex *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungqr( + cusolverDnHandle_t handle, int m, int n, int k, cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const float *A, int lda, const float *tau, + const float *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const float *, int, const float *, const float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSormqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const double *A, int lda, const double *tau, + const double *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const double *, int, const double *, const double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDormqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCunmqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuComplex *A, int lda, const cuComplex *tau, + const cuComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuComplex *, int, const cuComplex *, const cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCunmqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, 
side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, const cuDoubleComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZunmqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const float *A, int lda, const float *tau, float *C, + int ldc, float *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const float *, int, const float *, float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSormqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const double *A, int lda, const double *tau, double *C, + int ldc, double *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const double *, int, const double *, double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDormqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCunmqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuComplex *A, int lda, const cuComplex *tau, + cuComplex *C, int ldc, cuComplex *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuComplex *, int, const cuComplex *, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCunmqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, cuDoubleComplex *C, int ldc, + cuDoubleComplex *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZunmqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI 
cusolverDnSsytrf_bufferSize( + cusolverDnHandle_t handle, int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrf_bufferSize( + cusolverDnHandle_t handle, int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytrf_bufferSize( + cusolverDnHandle_t handle, int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrf_bufferSize( + cusolverDnHandle_t handle, int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, int *ipiv, + float *work, int lwork, + int *info) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, int *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, int *ipiv, + double *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *, double *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, int *ipiv, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + int *ipiv, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZsytrf"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgebrd(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *D, float *E, float *TAUQ, + float *TAUP, float *Work, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, float *, float *, + float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgebrd(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *D, double *E, + double *TAUQ, double *TAUP, + double *Work, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, double *, int, double *, double *, double *, + double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgebrd(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + float *D, float *E, + cuComplex *TAUQ, cuComplex *TAUP, + cuComplex *Work, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuComplex *, int, float *, float *, + cuComplex *, cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgebrd( + cusolverDnHandle_t handle, int m, int n, cuDoubleComplex *A, int lda, + double *D, double *E, cuDoubleComplex *TAUQ, cuDoubleComplex *TAUP, + cuDoubleComplex *Work, int Lwork, int *devInfo) { + using FuncPtr = 
cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, double *, double *, + cuDoubleComplex *, cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const float *A, int lda, const float *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const float *, int, + const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const double *A, int lda, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const cuComplex *A, int lda, const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const cuComplex *, + int, const cuComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, float *, int, + const float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, double *, int, + const double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgbr"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, cuComplex *A, + int lda, const cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, cuComplex *, int, + const cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZungbr(cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, + int k, cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, + cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, cuDoubleComplex *, + int, const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const float *A, + int lda, const float *d, const float *e, const float *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const float *, int, + const float *, const float *, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const double *A, + int lda, const double *d, const double *e, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const double *, int, + const double *, const double *, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChetrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const cuComplex *A, + int lda, const float *d, const float *e, const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuComplex *, int, + const float *, const float *, const cuComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnChetrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhetrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, const double *d, const double *e, + const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + const double *, const double *, const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhetrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return 
func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrd(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, float *d, + float *e, float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, float *, float *, + float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsytrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrd( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double *A, int lda, + double *d, double *e, double *tau, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, double *, + double *, double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsytrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChetrd(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float *d, + float *e, cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, float *, + float *, cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnChetrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhetrd( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, double *d, double *e, cuDoubleComplex *tau, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + double *, double *, cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhetrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const float *A, + int lda, const float *tau, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const double *A, + int lda, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const cuComplex *A, + int lda, const cuComplex *tau, int *lwork) { + using FuncPtr = 
cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuComplex *, int, + const cuComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, const float *, + float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSorgtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, const double *, + double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDorgtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungtr( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuComplex *A, + int lda, const cuComplex *tau, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, + const cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCungtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, + cuDoubleComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZungtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const float *A, int lda, + const float *tau, const float *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const float *, int, const float *, const float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSormtr_bufferSize"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const double *A, int lda, + const double *tau, const double *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const double *, int, const double *, const double *, int, + int *); + static auto func_ptr = LoadSymbol("cusolverDnDormtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCunmtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const cuComplex *A, int lda, + const cuComplex *tau, const cuComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const cuComplex *, int, const cuComplex *, const cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCunmtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, const cuDoubleComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const cuDoubleComplex *, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZunmtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, float *A, int lda, float *tau, + float *C, int ldc, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, float *, int, float *, float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSormtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, double *A, int lda, double *tau, + double *C, int ldc, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, double *, int, double *, double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDormtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t 
CUSOLVERAPI +cusolverDnCunmtr(cusolverDnHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, cublasOperation_t trans, int m, int n, + cuComplex *A, int lda, cuComplex *tau, cuComplex *C, int ldc, + cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, cuComplex *, int, cuComplex *, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCunmtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, cuDoubleComplex *A, int lda, + cuDoubleComplex *tau, cuDoubleComplex *C, int ldc, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, cuDoubleComplex *, int, cuDoubleComplex *, cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZunmtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvd( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, float *A, int lda, float *S, float *U, int ldu, float *VT, int ldvt, + float *work, int lwork, float *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, float *, int, + float *, float *, int, float *, int, float *, int, float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + 
+cusolverStatus_t CUSOLVERAPI cusolverDnDgesvd( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, double *A, int lda, double *S, double *U, int ldu, double *VT, + int ldvt, double *work, int lwork, double *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, double *, int, + double *, double *, int, double *, int, double *, int, double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgesvd(cusolverDnHandle_t handle, signed char jobu, signed char jobvt, + int m, int n, cuComplex *A, int lda, float *S, cuComplex *U, + int ldu, cuComplex *VT, int ldvt, cuComplex *work, int lwork, + float *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, cuComplex *, int, + float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, float *, + int *); + static auto func_ptr = LoadSymbol("cusolverDnCgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgesvd(cusolverDnHandle_t handle, signed char jobu, signed char jobvt, + int m, int n, cuDoubleComplex *A, int lda, double *S, + cuDoubleComplex *U, int ldu, cuDoubleComplex *VT, int ldvt, + cuDoubleComplex *work, int lwork, double *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, cuDoubleComplex *, + int, double *, cuDoubleComplex *, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsyevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsyevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const 
float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnCheevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZheevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevd( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, float *A, int lda, float *W, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsyevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevd( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, double *A, int lda, double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsyevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevd(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float *W, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnCheevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevd(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + double *W, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZheevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const float *A, int lda, const float *B, + int ldb, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const float *, int, const float *, int, + const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnSsygvd_bufferSize"); + if (!func_ptr) 
return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const double *A, int lda, const double *B, + int ldb, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const double *, int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsygvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuComplex *A, int lda, + const cuComplex *B, int ldb, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuComplex *, int, const cuComplex *, int, + const float *, int *); + static auto func_ptr = LoadSymbol("cusolverDnChegvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhegvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, float *A, int lda, float *B, int ldb, + float *W, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, float *, int, float *, int, float *, float *, int, + int *); + static auto func_ptr = LoadSymbol("cusolverDnSsygvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, double *A, int lda, double *B, int ldb, + double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, double *, int, double *, int, double *, double *, + int, int *); + static auto func_ptr = LoadSymbol("cusolverDnDsygvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + 
cublasFillMode_t uplo, int n, cuComplex *A, int lda, cuComplex *B, int ldb, + float *W, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuComplex *, int, cuComplex *, int, float *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnChegvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZhegvd(cusolverDnHandle_t handle, cusolverEigType_t itype, + cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb, + double *W, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol("cusolverDnZhegvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreateSyevjInfo(syevjInfo_t *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t *); + static auto func_ptr = LoadSymbol("cusolverDnCreateSyevjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroySyevjInfo(syevjInfo_t info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDestroySyevjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetTolerance(syevjInfo_t info, + double tolerance) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, double); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjSetTolerance"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, tolerance); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetMaxSweeps(syevjInfo_t info, + int max_sweeps) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjSetMaxSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, max_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetSortEig(syevjInfo_t info, + int sort_eig) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjSetSortEig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, sort_eig); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjGetResidual( + cusolverDnHandle_t handle, syevjInfo_t info, double *residual) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, + syevjInfo_t, double *); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjGetResidual"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, residual); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjGetSweeps( + cusolverDnHandle_t handle, syevjInfo_t info, int *executed_sweeps) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, syevjInfo_t, int *); + static auto func_ptr = LoadSymbol("cusolverDnXsyevjGetSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, 
executed_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnSsyevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnDsyevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const float *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnCheevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnZheevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, float *A, int lda, float *W, float *work, int lwork, int *info, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnSsyevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, double *A, int lda, double *W, double *work, int lwork, int *info, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, 
cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnDsyevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, float *W, cuComplex *work, int lwork, + int *info, syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnCheevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, double *W, cuDoubleComplex *work, + int lwork, int *info, syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *, + syevjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnZheevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSsyevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDsyevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnCheevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, 
+ int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZheevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + float *A, int lda, float *W, + float *work, int lwork, int *info, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSsyevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + double *A, int lda, double *W, + double *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDsyevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float *W, + cuComplex *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnCheevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, double *W, cuDoubleComplex *work, + int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *, + syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZheevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const float *A, int lda, const float *B, + int ldb, const float *W, int *lwork, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const float *, int, const float *, int, + const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSsygvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, 
lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const double *A, int lda, const double *B, + int ldb, const double *W, int *lwork, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const double *, int, const double *, int, + const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDsygvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuComplex *A, int lda, + const cuComplex *B, int ldb, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuComplex *, int, const cuComplex *, int, + const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnChegvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZhegvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, float *A, int lda, float *B, int ldb, + float *W, float *work, int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, float *, int, float *, int, float *, float *, int, + int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSsygvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, double *A, int lda, double *B, int ldb, + double *W, double *work, int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, double *, int, double *, int, double *, double *, + int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDsygvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, 
params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, cuComplex *A, int lda, cuComplex *B, int ldb, + float *W, cuComplex *work, int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuComplex *, int, cuComplex *, int, float *, + cuComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnChegvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb, double *W, cuDoubleComplex *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZhegvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreateGesvdjInfo(gesvdjInfo_t *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t *); + static auto func_ptr = LoadSymbol("cusolverDnCreateGesvdjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroyGesvdjInfo(gesvdjInfo_t info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDestroyGesvdjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetTolerance(gesvdjInfo_t info, + double tolerance) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, double); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjSetTolerance"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, tolerance); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetMaxSweeps(gesvdjInfo_t info, + int max_sweeps) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjSetMaxSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, max_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetSortEig(gesvdjInfo_t info, + int sort_svd) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjSetSortEig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, sort_svd); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjGetResidual( + cusolverDnHandle_t handle, gesvdjInfo_t info, double *residual) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, + gesvdjInfo_t, double *); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjGetResidual"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, residual); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjGetSweeps( + cusolverDnHandle_t handle, gesvdjInfo_t info, int 
*executed_sweeps) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, gesvdjInfo_t, int *); + static auto func_ptr = LoadSymbol("cusolverDnXgesvdjGetSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, executed_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const float *A, int lda, const float *S, const float *U, int ldu, + const float *V, int ldv, int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const float *, int, + const float *, const float *, int, const float *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnSgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const double *A, int lda, const double *S, const double *U, int ldu, + const double *V, int ldv, int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const double *, int, + const double *, const double *, int, const double *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnDgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const cuComplex *A, int lda, const float *S, const cuComplex *U, int ldu, + const cuComplex *V, int ldv, int *lwork, gesvdjInfo_t params, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const cuComplex *, int, + const float *, const cuComplex *, int, const cuComplex *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnCgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const cuDoubleComplex *A, int lda, const double *S, + const cuDoubleComplex *U, int ldu, const cuDoubleComplex *V, int ldv, + int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const cuDoubleComplex *, + int, const double *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol("cusolverDnZgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, float *A, + int lda, float *S, float *U, int ldu, float *V, int ldv, float *work, + int lwork, int *info, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI 
*)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, float *, int, float *, + float *, int, float *, int, float *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnSgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, double *A, + int lda, double *S, double *U, int ldu, double *V, int ldv, double *work, + int lwork, int *info, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, double *, int, double *, + double *, int, double *, int, double *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnDgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + cuComplex *A, int lda, float *S, cuComplex *U, int ldu, cuComplex *V, + int ldv, cuComplex *work, int lwork, int *info, gesvdjInfo_t params, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, cuComplex *, int, + float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnCgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + cuDoubleComplex *A, int lda, double *S, cuDoubleComplex *U, int ldu, + cuDoubleComplex *V, int ldv, cuDoubleComplex *work, int lwork, int *info, + gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol("cusolverDnZgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const float *A, int lda, const float *S, const float *U, int ldu, + const float *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const float *, int, + const float *, const float *, int, const float *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const double *A, int lda, const double *S, const double *U, int ldu, + const double *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = 
cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const double *, int, + const double *, const double *, int, const double *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const cuComplex *A, int lda, const float *S, const cuComplex *U, int ldu, + const cuComplex *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const cuComplex *, + int, const float *, const cuComplex *, int, const cuComplex *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnCgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const cuDoubleComplex *A, int lda, const double *S, + const cuDoubleComplex *U, int ldu, const cuDoubleComplex *V, int ldv, + int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, + const cuDoubleComplex *, int, const double *, const cuDoubleComplex *, + int, const cuDoubleComplex *, int, int *, gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnZgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + float *A, int lda, float *S, float *U, int ldu, float *V, int ldv, + float *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, float *, int, + float *, float *, int, float *, int, float *, int, int *, gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnSgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + double *A, int lda, double *S, double *U, int ldu, double *V, int ldv, + double *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, double *, int, + double *, double *, int, double *, int, double *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol("cusolverDnDgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + cuComplex *A, int lda, float *S, cuComplex *U, int ldu, cuComplex *V, + int ldv, cuComplex *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + 
      cusolverDnHandle_t, cusolverEigMode_t, int, int, int, cuComplex *, int,
+      float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, int *,
+      gesvdjInfo_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgesvdj");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work,
+                  lwork, info, params);
+}
+
+cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdj(
+    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n,
+    cuDoubleComplex *A, int lda, double *S, cuDoubleComplex *U, int ldu,
+    cuDoubleComplex *V, int ldv, cuDoubleComplex *work, int lwork, int *info,
+    gesvdjInfo_t params) {
+  using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(
+      cusolverDnHandle_t, cusolverEigMode_t, int, int, int, cuDoubleComplex *,
+      int, double *, cuDoubleComplex *, int, cuDoubleComplex *, int,
+      cuDoubleComplex *, int, int *, gesvdjInfo_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgesvdj");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work,
+                  lwork, info, params);
+}
+
+}  // extern "C"
diff --git a/tensorflow/stream_executor/cuda/cusolver_stub.cc b/tensorflow/stream_executor/cuda/cusolver_stub.cc
index f92af64fcf1..a4b9cc37f9b 100644
--- a/tensorflow/stream_executor/cuda/cusolver_stub.cc
+++ b/tensorflow/stream_executor/cuda/cusolver_stub.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cusolverDn.h"
 #include "tensorflow/stream_executor/lib/env.h"
 #include "tensorflow/stream_executor/platform/dso_loader.h"
@@ -50,8 +51,16 @@ cusolverStatus_t GetSymbolNotFoundError() {
 }
 }  // namespace
 
-#if CUDA_VERSION < 10010
+#if CUDA_VERSION < 10000
+#include "tensorflow/stream_executor/cuda/cusolver_dense_9_0.inc"
+#elif CUDA_VERSION == 10000
 #include "tensorflow/stream_executor/cuda/cusolver_dense_10_0.inc"
-#else
+#elif CUDA_VERSION == 10010
 #include "tensorflow/stream_executor/cuda/cusolver_dense_10_1.inc"
+#elif CUDA_VERSION == 10020
+#include "tensorflow/stream_executor/cuda/cusolver_dense_10_2.inc"
+#elif CUDA_VERSION == 11000
+#include "tensorflow/stream_executor/cuda/cusolver_dense_11_0.inc"
+#else
+#error "We don't have a wrapper for this version."
 #endif
diff --git a/tensorflow/stream_executor/cuda/cusparse_10_1.inc b/tensorflow/stream_executor/cuda/cusparse_10_1.inc
index 09b3ad11138..c63300697fe 100644
--- a/tensorflow/stream_executor/cuda/cusparse_10_1.inc
+++ b/tensorflow/stream_executor/cuda/cusparse_10_1.inc
@@ -116,14 +116,6 @@ cusparseStatus_t CUSPARSEAPI cusparseSetMatType(cusparseMatDescr_t descrA,
   return func_ptr(descrA, type);
 }
 
-cusparseMatrixType_t CUSPARSEAPI
-cusparseGetMatType(const cusparseMatDescr_t descrA) {
-  using FuncPtr = cusparseMatrixType_t(CUSPARSEAPI *)(const cusparseMatDescr_t);
-  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetMatType");
-  if (!func_ptr) return GetSymbolNotFoundError();
-  return func_ptr(descrA);
-}
-
 cusparseStatus_t CUSPARSEAPI
 cusparseSetMatFillMode(cusparseMatDescr_t descrA, cusparseFillMode_t fillMode) {
   using FuncPtr =
@@ -133,14 +125,6 @@ cusparseSetMatFillMode(cusparseMatDescr_t descrA, cusparseFillMode_t fillMode) {
   return func_ptr(descrA, fillMode);
 }
 
-cusparseFillMode_t CUSPARSEAPI
-cusparseGetMatFillMode(const cusparseMatDescr_t descrA) {
-  using FuncPtr = cusparseFillMode_t(CUSPARSEAPI *)(const cusparseMatDescr_t);
-  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetMatFillMode");
-  if (!func_ptr) return GetSymbolNotFoundError();
-  return func_ptr(descrA);
-}
-
 cusparseStatus_t CUSPARSEAPI
 cusparseSetMatDiagType(cusparseMatDescr_t descrA, cusparseDiagType_t diagType) {
   using FuncPtr =
@@ -150,14 +134,6 @@ cusparseSetMatDiagType(cusparseMatDescr_t descrA, cusparseDiagType_t diagType) {
   return func_ptr(descrA, diagType);
 }
 
-cusparseDiagType_t CUSPARSEAPI
-cusparseGetMatDiagType(const cusparseMatDescr_t descrA) {
-  using FuncPtr = cusparseDiagType_t(CUSPARSEAPI *)(const cusparseMatDescr_t);
-  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetMatDiagType");
-  if (!func_ptr) return GetSymbolNotFoundError();
-  return func_ptr(descrA);
-}
-
 cusparseStatus_t CUSPARSEAPI cusparseSetMatIndexBase(cusparseMatDescr_t descrA,
                                                      cusparseIndexBase_t base) {
   using FuncPtr =
@@ -167,14 +143,6 @@ cusparseStatus_t CUSPARSEAPI cusparseSetMatIndexBase(cusparseMatDescr_t descrA,
   return func_ptr(descrA, base);
 }
 
-cusparseIndexBase_t CUSPARSEAPI
-cusparseGetMatIndexBase(const cusparseMatDescr_t descrA) {
-  using FuncPtr = cusparseIndexBase_t(CUSPARSEAPI *)(const cusparseMatDescr_t);
-  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetMatIndexBase");
-  if (!func_ptr) return GetSymbolNotFoundError();
-  return func_ptr(descrA);
-}
-
 cusparseStatus_t CUSPARSEAPI
 cusparseCreateSolveAnalysisInfo(cusparseSolveAnalysisInfo_t *info) {
   using FuncPtr =
diff --git a/tensorflow/stream_executor/cuda/cusparse_10_2.inc b/tensorflow/stream_executor/cuda/cusparse_10_2.inc
new file mode 100644
index 00000000000..c63300697fe
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cusparse_10_2.inc
@@ -0,0 +1,8226 @@
+// Auto-generated, do not edit.
+ +extern "C" { + +cusparseStatus_t CUSPARSEAPI cusparseCreate(cusparseHandle_t *handle) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t *); + static auto func_ptr = LoadSymbol("cusparseCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroy(cusparseHandle_t handle) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t); + static auto func_ptr = LoadSymbol("cusparseDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusparseStatus_t CUSPARSEAPI cusparseGetVersion(cusparseHandle_t handle, + int *version) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int *); + static auto func_ptr = LoadSymbol("cusparseGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, version); +} + +cusparseStatus_t CUSPARSEAPI cusparseGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol("cusparseGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +const char *CUSPARSEAPI cusparseGetErrorName(cusparseStatus_t status) { + using FuncPtr = const char *(CUSPARSEAPI *)(cusparseStatus_t); + static auto func_ptr = LoadSymbol("cusparseGetErrorName"); + if (!func_ptr) return "cusparseGetErrorName symbol not found."; + return func_ptr(status); +} + +const char *CUSPARSEAPI cusparseGetErrorString(cusparseStatus_t status) { + using FuncPtr = const char *(CUSPARSEAPI *)(cusparseStatus_t); + static auto func_ptr = LoadSymbol("cusparseGetErrorString"); + if (!func_ptr) return "cusparseGetErrorString symbol not found."; + return func_ptr(status); +} + +cusparseStatus_t CUSPARSEAPI cusparseSetStream(cusparseHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cusparseSetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusparseStatus_t CUSPARSEAPI cusparseGetStream(cusparseHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, cudaStream_t *); + static auto func_ptr = LoadSymbol("cusparseGetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusparseStatus_t CUSPARSEAPI +cusparseGetPointerMode(cusparseHandle_t handle, cusparsePointerMode_t *mode) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, + cusparsePointerMode_t *); + static auto func_ptr = LoadSymbol("cusparseGetPointerMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSetPointerMode(cusparseHandle_t handle, cusparsePointerMode_t mode) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, cusparsePointerMode_t); + static auto func_ptr = LoadSymbol("cusparseSetPointerMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateMatDescr(cusparseMatDescr_t *descrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t *); + static auto func_ptr = LoadSymbol("cusparseCreateMatDescr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA); +} + +cusparseStatus_t CUSPARSEAPI 
+cusparseDestroyMatDescr(cusparseMatDescr_t descrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t); + static auto func_ptr = LoadSymbol("cusparseDestroyMatDescr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCopyMatDescr(cusparseMatDescr_t dest, const cusparseMatDescr_t src) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, + const cusparseMatDescr_t); + static auto func_ptr = LoadSymbol("cusparseCopyMatDescr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dest, src); +} + +cusparseStatus_t CUSPARSEAPI cusparseSetMatType(cusparseMatDescr_t descrA, + cusparseMatrixType_t type) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, cusparseMatrixType_t); + static auto func_ptr = LoadSymbol("cusparseSetMatType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA, type); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSetMatFillMode(cusparseMatDescr_t descrA, cusparseFillMode_t fillMode) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, cusparseFillMode_t); + static auto func_ptr = LoadSymbol("cusparseSetMatFillMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA, fillMode); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSetMatDiagType(cusparseMatDescr_t descrA, cusparseDiagType_t diagType) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, cusparseDiagType_t); + static auto func_ptr = LoadSymbol("cusparseSetMatDiagType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA, diagType); +} + +cusparseStatus_t CUSPARSEAPI cusparseSetMatIndexBase(cusparseMatDescr_t descrA, + cusparseIndexBase_t base) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseSetMatIndexBase"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA, base); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateSolveAnalysisInfo(cusparseSolveAnalysisInfo_t *info) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseSolveAnalysisInfo_t *); + static auto func_ptr = LoadSymbol("cusparseCreateSolveAnalysisInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroySolveAnalysisInfo(cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSolveAnalysisInfo_t); + static auto func_ptr = + LoadSymbol("cusparseDestroySolveAnalysisInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseGetLevelInfo(cusparseHandle_t handle, cusparseSolveAnalysisInfo_t info, + int *nlevels, int **levelPtr, int **levelInd) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseSolveAnalysisInfo_t, int *, int **, int **); + static auto func_ptr = LoadSymbol("cusparseGetLevelInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, nlevels, levelPtr, levelInd); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsrsv2Info(csrsv2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrsv2Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateCsrsv2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsrsv2Info(csrsv2Info_t info) { + 
using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrsv2Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyCsrsv2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsric02Info(csric02Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csric02Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateCsric02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsric02Info(csric02Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csric02Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyCsric02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateBsric02Info(bsric02Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsric02Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateBsric02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyBsric02Info(bsric02Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsric02Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyBsric02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsrilu02Info(csrilu02Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrilu02Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateCsrilu02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsrilu02Info(csrilu02Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrilu02Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyCsrilu02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateBsrilu02Info(bsrilu02Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrilu02Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateBsrilu02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyBsrilu02Info(bsrilu02Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrilu02Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyBsrilu02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateBsrsv2Info(bsrsv2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrsv2Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateBsrsv2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyBsrsv2Info(bsrsv2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrsv2Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyBsrsv2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateBsrsm2Info(bsrsm2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrsm2Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateBsrsm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyBsrsm2Info(bsrsm2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI 
*)(bsrsm2Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyBsrsm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateHybMat(cusparseHybMat_t *hybA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHybMat_t *); + static auto func_ptr = LoadSymbol("cusparseCreateHybMat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hybA); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyHybMat(cusparseHybMat_t hybA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHybMat_t); + static auto func_ptr = LoadSymbol("cusparseDestroyHybMat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hybA); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsru2csrInfo(csru2csrInfo_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csru2csrInfo_t *); + static auto func_ptr = LoadSymbol("cusparseCreateCsru2csrInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsru2csrInfo(csru2csrInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csru2csrInfo_t); + static auto func_ptr = LoadSymbol("cusparseDestroyCsru2csrInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateColorInfo(cusparseColorInfo_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseColorInfo_t *); + static auto func_ptr = LoadSymbol("cusparseCreateColorInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroyColorInfo(cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseColorInfo_t); + static auto func_ptr = LoadSymbol("cusparseDestroyColorInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseSetColorAlgs(cusparseColorInfo_t info, + cusparseColorAlg_t alg) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseColorInfo_t, cusparseColorAlg_t); + static auto func_ptr = LoadSymbol("cusparseSetColorAlgs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, alg); +} + +cusparseStatus_t CUSPARSEAPI cusparseGetColorAlgs(cusparseColorInfo_t info, + cusparseColorAlg_t *alg) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseColorInfo_t, + cusparseColorAlg_t *); + static auto func_ptr = LoadSymbol("cusparseGetColorAlgs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, alg); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreatePruneInfo(pruneInfo_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(pruneInfo_t *); + static auto func_ptr = LoadSymbol("cusparseCreatePruneInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyPruneInfo(pruneInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(pruneInfo_t); + static auto func_ptr = LoadSymbol("cusparseDestroyPruneInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseSaxpyi(cusparseHandle_t handle, int nnz, + const float *alpha, + const float *xVal, const int *xInd, + float *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, const float *, const int *, float *, + cusparseIndexBase_t); + static auto 
func_ptr = LoadSymbol("cusparseSaxpyi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, alpha, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDaxpyi(cusparseHandle_t handle, int nnz, + const double *alpha, + const double *xVal, const int *xInd, + double *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const double *, const int *, + double *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseDaxpyi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, alpha, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCaxpyi(cusparseHandle_t handle, int nnz, + const cuComplex *alpha, + const cuComplex *xVal, + const int *xInd, cuComplex *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const cuComplex *, const int *, + cuComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseCaxpyi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, alpha, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZaxpyi(cusparseHandle_t handle, int nnz, + const cuDoubleComplex *alpha, + const cuDoubleComplex *xVal, + const int *xInd, cuDoubleComplex *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + const int *, cuDoubleComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseZaxpyi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, alpha, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSdoti(cusparseHandle_t handle, int nnz, + const float *xVal, const int *xInd, + const float *y, + float *resultDevHostPtr, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, const int *, const float *, float *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseSdoti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, resultDevHostPtr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDdoti(cusparseHandle_t handle, int nnz, + const double *xVal, const int *xInd, + const double *y, + double *resultDevHostPtr, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const int *, const double *, + double *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseDdoti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, resultDevHostPtr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCdoti(cusparseHandle_t handle, int nnz, + const cuComplex *xVal, + const int *xInd, const cuComplex *y, + cuComplex *resultDevHostPtr, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const int *, const cuComplex *, + cuComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseCdoti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, resultDevHostPtr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZdoti(cusparseHandle_t handle, int nnz, + const cuDoubleComplex *xVal, + const int *xInd, + const cuDoubleComplex *y, + 
cuDoubleComplex *resultDevHostPtr, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const int *, + const cuDoubleComplex *, cuDoubleComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseZdoti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, resultDevHostPtr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCdotci(cusparseHandle_t handle, int nnz, + const cuComplex *xVal, + const int *xInd, const cuComplex *y, + cuComplex *resultDevHostPtr, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const int *, const cuComplex *, + cuComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseCdotci"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, resultDevHostPtr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZdotci(cusparseHandle_t handle, int nnz, + const cuDoubleComplex *xVal, + const int *xInd, + const cuDoubleComplex *y, + cuDoubleComplex *resultDevHostPtr, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const int *, + const cuDoubleComplex *, cuDoubleComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseZdotci"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, resultDevHostPtr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgthr(cusparseHandle_t handle, int nnz, + const float *y, float *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, float *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseSgthr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgthr(cusparseHandle_t handle, int nnz, + const double *y, double *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, double *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseDgthr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgthr(cusparseHandle_t handle, int nnz, + const cuComplex *y, cuComplex *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, cuComplex *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseCgthr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgthr(cusparseHandle_t handle, int nnz, + const cuDoubleComplex *y, + cuDoubleComplex *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, cuDoubleComplex *, + const int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseZgthr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI 
cusparseSgthrz(cusparseHandle_t handle, int nnz, + float *y, float *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, float *, float *, + const int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseSgthrz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgthrz(cusparseHandle_t handle, int nnz, + double *y, double *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, double *, double *, + const int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseDgthrz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgthrz(cusparseHandle_t handle, int nnz, + cuComplex *y, cuComplex *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cuComplex *, cuComplex *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseCgthrz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgthrz(cusparseHandle_t handle, int nnz, + cuDoubleComplex *y, + cuDoubleComplex *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cuDoubleComplex *, cuDoubleComplex *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseZgthrz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSsctr(cusparseHandle_t handle, int nnz, + const float *xVal, const int *xInd, + float *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, + const float *, const int *, + float *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseSsctr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDsctr(cusparseHandle_t handle, int nnz, + const double *xVal, const int *xInd, + double *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const int *, double *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseDsctr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsctr(cusparseHandle_t handle, int nnz, + const cuComplex *xVal, + const int *xInd, cuComplex *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const int *, cuComplex *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseCsctr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZsctr(cusparseHandle_t handle, int nnz, + const cuDoubleComplex *xVal, + const int *xInd, cuDoubleComplex *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const int *, + 
cuDoubleComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseZsctr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSroti(cusparseHandle_t handle, int nnz, + float *xVal, const int *xInd, + float *y, const float *c, + const float *s, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, float *, const int *, float *, const float *, + const float *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseSroti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, c, s, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDroti(cusparseHandle_t handle, int nnz, + double *xVal, const int *xInd, + double *y, const double *c, + const double *s, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, double *, const int *, double *, const double *, + const double *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseDroti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, c, s, idxBase); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSgemvi(cusparseHandle_t handle, cusparseOperation_t transA, int m, + int n, const float *alpha, const float *A, int lda, int nnz, + const float *xVal, const int *xInd, const float *beta, float *y, + cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const float *, + const float *, int, int, const float *, const int *, const float *, + float *, cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol("cusparseSgemvi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, A, lda, nnz, xVal, xInd, beta, y, + idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSgemvi_bufferSize(cusparseHandle_t handle, cusparseOperation_t transA, + int m, int n, int nnz, int *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseSgemvi_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDgemvi(cusparseHandle_t handle, cusparseOperation_t transA, int m, + int n, const double *alpha, const double *A, int lda, int nnz, + const double *xVal, const int *xInd, const double *beta, + double *y, cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const double *, + const double *, int, int, const double *, const int *, const double *, + double *, cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol("cusparseDgemvi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, A, lda, nnz, xVal, xInd, beta, y, + idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDgemvi_bufferSize(cusparseHandle_t handle, cusparseOperation_t transA, + int m, int n, int nnz, int *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseDgemvi_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
return func_ptr(handle, transA, m, n, nnz, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgemvi( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, + const cuComplex *alpha, const cuComplex *A, int lda, int nnz, + const cuComplex *xVal, const int *xInd, const cuComplex *beta, cuComplex *y, + cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuComplex *, + const cuComplex *, int, int, const cuComplex *, const int *, + const cuComplex *, cuComplex *, cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol("cusparseCgemvi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, A, lda, nnz, xVal, xInd, beta, y, + idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCgemvi_bufferSize(cusparseHandle_t handle, cusparseOperation_t transA, + int m, int n, int nnz, int *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseCgemvi_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgemvi( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, + const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, int nnz, + const cuDoubleComplex *xVal, const int *xInd, const cuDoubleComplex *beta, + cuDoubleComplex *y, cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, int, const cuDoubleComplex *, const int *, + const cuDoubleComplex *, cuDoubleComplex *, cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol("cusparseZgemvi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, A, lda, nnz, xVal, xInd, beta, y, + idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZgemvi_bufferSize(cusparseHandle_t handle, cusparseOperation_t transA, + int m, int n, int nnz, int *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseZgemvi_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrmv( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, int nnz, + const float *alpha, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const float *x, const float *beta, float *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + const float *, const float *, float *); + static auto func_ptr = LoadSymbol("cusparseScsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDcsrmv(cusparseHandle_t handle, cusparseOperation_t transA, int m, + int n, int nnz, const double *alpha, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + 
const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *x, const double *beta, double *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + const double *, const double *, double *); + static auto func_ptr = LoadSymbol("cusparseDcsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCcsrmv(cusparseHandle_t handle, cusparseOperation_t transA, int m, + int n, int nnz, const cuComplex *alpha, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cuComplex *x, const cuComplex *beta, cuComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + const cuComplex *, const cuComplex *, cuComplex *); + static auto func_ptr = LoadSymbol("cusparseCcsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrmv( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *x, + const cuDoubleComplex *beta, cuDoubleComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, + const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZcsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrmvEx_bufferSize( + cusparseHandle_t handle, cusparseAlgMode_t alg, cusparseOperation_t transA, + int m, int n, int nnz, const void *alpha, cudaDataType alphatype, + const cusparseMatDescr_t descrA, const void *csrValA, + cudaDataType csrValAtype, const int *csrRowPtrA, const int *csrColIndA, + const void *x, cudaDataType xtype, const void *beta, cudaDataType betatype, + void *y, cudaDataType ytype, cudaDataType executiontype, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseAlgMode_t, cusparseOperation_t, int, int, int, + const void *, cudaDataType, const cusparseMatDescr_t, const void *, + cudaDataType, const int *, const int *, const void *, cudaDataType, + const void *, cudaDataType, void *, cudaDataType, cudaDataType, size_t *); + static auto func_ptr = LoadSymbol("cusparseCsrmvEx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, alg, transA, m, n, nnz, alpha, alphatype, descrA, + csrValA, csrValAtype, csrRowPtrA, csrColIndA, x, xtype, beta, + betatype, y, ytype, executiontype, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrmvEx( + cusparseHandle_t handle, cusparseAlgMode_t 
alg, cusparseOperation_t transA, + int m, int n, int nnz, const void *alpha, cudaDataType alphatype, + const cusparseMatDescr_t descrA, const void *csrValA, + cudaDataType csrValAtype, const int *csrRowPtrA, const int *csrColIndA, + const void *x, cudaDataType xtype, const void *beta, cudaDataType betatype, + void *y, cudaDataType ytype, cudaDataType executiontype, void *buffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseAlgMode_t, cusparseOperation_t, int, int, int, + const void *, cudaDataType, const cusparseMatDescr_t, const void *, + cudaDataType, const int *, const int *, const void *, cudaDataType, + const void *, cudaDataType, void *, cudaDataType, cudaDataType, void *); + static auto func_ptr = LoadSymbol("cusparseCsrmvEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, alg, transA, m, n, nnz, alpha, alphatype, descrA, + csrValA, csrValAtype, csrRowPtrA, csrColIndA, x, xtype, beta, + betatype, y, ytype, executiontype, buffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrmv_mp( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, int nnz, + const float *alpha, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const float *x, const float *beta, float *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + const float *, const float *, float *); + static auto func_ptr = LoadSymbol("cusparseScsrmv_mp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDcsrmv_mp(cusparseHandle_t handle, cusparseOperation_t transA, int m, + int n, int nnz, const double *alpha, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *x, const double *beta, double *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + const double *, const double *, double *); + static auto func_ptr = LoadSymbol("cusparseDcsrmv_mp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrmv_mp( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuComplex *x, const cuComplex *beta, + cuComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + const cuComplex *, const cuComplex *, cuComplex *); + static auto func_ptr = LoadSymbol("cusparseCcsrmv_mp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrmv_mp( + cusparseHandle_t handle, cusparseOperation_t transA, int 
m, int n, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *x, + const cuDoubleComplex *beta, cuDoubleComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, + const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZcsrmv_mp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseShybmv( + cusparseHandle_t handle, cusparseOperation_t transA, const float *alpha, + const cusparseMatDescr_t descrA, const cusparseHybMat_t hybA, + const float *x, const float *beta, float *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const float *, + const cusparseMatDescr_t, const cusparseHybMat_t, const float *, + const float *, float *); + static auto func_ptr = LoadSymbol("cusparseShybmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, alpha, descrA, hybA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseDhybmv( + cusparseHandle_t handle, cusparseOperation_t transA, const double *alpha, + const cusparseMatDescr_t descrA, const cusparseHybMat_t hybA, + const double *x, const double *beta, double *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const double *, + const cusparseMatDescr_t, const cusparseHybMat_t, const double *, + const double *, double *); + static auto func_ptr = LoadSymbol("cusparseDhybmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, alpha, descrA, hybA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseChybmv( + cusparseHandle_t handle, cusparseOperation_t transA, const cuComplex *alpha, + const cusparseMatDescr_t descrA, const cusparseHybMat_t hybA, + const cuComplex *x, const cuComplex *beta, cuComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cuComplex *, + const cusparseMatDescr_t, const cusparseHybMat_t, const cuComplex *, + const cuComplex *, cuComplex *); + static auto func_ptr = LoadSymbol("cusparseChybmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, alpha, descrA, hybA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZhybmv(cusparseHandle_t handle, cusparseOperation_t transA, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, const cuDoubleComplex *x, + const cuDoubleComplex *beta, cuDoubleComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cuDoubleComplex *, + const cusparseMatDescr_t, const cusparseHybMat_t, const cuDoubleComplex *, + const cuDoubleComplex *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZhybmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, alpha, descrA, hybA, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nb, int nnzb, const float *alpha, + const 
cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const float *x, const float *beta, float *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, int, const float *, const float *, float *); + static auto func_ptr = LoadSymbol("cusparseSbsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockDim, + x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nb, int nnzb, const double *alpha, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const double *x, const double *beta, double *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + const double *, const cusparseMatDescr_t, const double *, const int *, + const int *, int, const double *, const double *, double *); + static auto func_ptr = LoadSymbol("cusparseDbsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockDim, + x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCbsrmv(cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nb, int nnzb, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, int blockDim, const cuComplex *x, + const cuComplex *beta, cuComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, int, const cuComplex *, const cuComplex *, + cuComplex *); + static auto func_ptr = LoadSymbol("cusparseCbsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockDim, + x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nb, int nnzb, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, int blockDim, const cuDoubleComplex *x, + const cuDoubleComplex *beta, cuDoubleComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZbsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockDim, + x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSbsrxmv(cusparseHandle_t handle, 
cusparseDirection_t dirA, + cusparseOperation_t transA, int sizeOfMask, int mb, int nb, + int nnzb, const float *alpha, const cusparseMatDescr_t descrA, + const float *bsrSortedValA, const int *bsrSortedMaskPtrA, + const int *bsrSortedRowPtrA, const int *bsrSortedEndPtrA, + const int *bsrSortedColIndA, int blockDim, const float *x, + const float *beta, float *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, const int *, const int *, int, const float *, const float *, + float *); + static auto func_ptr = LoadSymbol("cusparseSbsrxmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, sizeOfMask, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedMaskPtrA, bsrSortedRowPtrA, + bsrSortedEndPtrA, bsrSortedColIndA, blockDim, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDbsrxmv(cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int sizeOfMask, int mb, int nb, + int nnzb, const double *alpha, const cusparseMatDescr_t descrA, + const double *bsrSortedValA, const int *bsrSortedMaskPtrA, + const int *bsrSortedRowPtrA, const int *bsrSortedEndPtrA, + const int *bsrSortedColIndA, int blockDim, const double *x, + const double *beta, double *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, const int *, const int *, int, const double *, + const double *, double *); + static auto func_ptr = LoadSymbol("cusparseDbsrxmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, sizeOfMask, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedMaskPtrA, bsrSortedRowPtrA, + bsrSortedEndPtrA, bsrSortedColIndA, blockDim, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrxmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int sizeOfMask, int mb, int nb, int nnzb, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *bsrSortedValA, const int *bsrSortedMaskPtrA, + const int *bsrSortedRowPtrA, const int *bsrSortedEndPtrA, + const int *bsrSortedColIndA, int blockDim, const cuComplex *x, + const cuComplex *beta, cuComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const int *, const int *, int, + const cuComplex *, const cuComplex *, cuComplex *); + static auto func_ptr = LoadSymbol("cusparseCbsrxmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, sizeOfMask, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedMaskPtrA, bsrSortedRowPtrA, + bsrSortedEndPtrA, bsrSortedColIndA, blockDim, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrxmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int sizeOfMask, int mb, int nb, int nnzb, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *bsrSortedValA, const int *bsrSortedMaskPtrA, + const int *bsrSortedRowPtrA, const int *bsrSortedEndPtrA, + const int *bsrSortedColIndA, int blockDim, const cuDoubleComplex *x, + const 
cuDoubleComplex *beta, cuDoubleComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, const int *, + const int *, int, const cuDoubleComplex *, const cuDoubleComplex *, + cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZbsrxmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, sizeOfMask, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedMaskPtrA, bsrSortedRowPtrA, + bsrSortedEndPtrA, bsrSortedColIndA, blockDim, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrsv_analysisEx( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const void *csrSortedValA, + cudaDataType csrSortedValAtype, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + cudaDataType executiontype) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const void *, cudaDataType, const int *, const int *, + cusparseSolveAnalysisInfo_t, cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCsrsv_analysisEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedValAtype, csrSortedRowPtrA, csrSortedColIndA, info, + executiontype); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseScsrsv_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseDcsrsv_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsv_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseCcsrsv_analysis"); + if (!func_ptr) 
return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseZcsrsv_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrsv_solveEx( + cusparseHandle_t handle, cusparseOperation_t transA, int m, + const void *alpha, cudaDataType alphatype, const cusparseMatDescr_t descrA, + const void *csrSortedValA, cudaDataType csrSortedValAtype, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info, const void *f, cudaDataType ftype, + void *x, cudaDataType xtype, cudaDataType executiontype) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const void *, cudaDataType, + const cusparseMatDescr_t, const void *, cudaDataType, const int *, + const int *, cusparseSolveAnalysisInfo_t, const void *, cudaDataType, + void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCsrsv_solveEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, alpha, alphatype, descrA, csrSortedValA, + csrSortedValAtype, csrSortedRowPtrA, csrSortedColIndA, info, + f, ftype, x, xtype, executiontype); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, + const float *alpha, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + const float *f, float *x) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + cusparseSolveAnalysisInfo_t, const float *, float *); + static auto func_ptr = LoadSymbol("cusparseScsrsv_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, + const double *alpha, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + const double *f, double *x) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + cusparseSolveAnalysisInfo_t, const double *, double *); + static auto func_ptr = LoadSymbol("cusparseDcsrsv_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x); +} + +cusparseStatus_t 
CUSPARSEAPI cusparseCcsrsv_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + const cuComplex *f, cuComplex *x) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + cusparseSolveAnalysisInfo_t, const cuComplex *, cuComplex *); + static auto func_ptr = LoadSymbol("cusparseCcsrsv_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + const cuDoubleComplex *f, cuDoubleComplex *x) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cuDoubleComplex *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, cusparseSolveAnalysisInfo_t, const cuDoubleComplex *, + cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZcsrsv_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrsv2_zeroPivot(cusparseHandle_t handle, + csrsv2Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseXcsrsv2_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv2_bufferSize( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseScsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv2_bufferSize( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseDcsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t 
CUSPARSEAPI cusparseCcsrsv2_bufferSize( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseCcsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv2_bufferSize( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseZcsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, csrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseScsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, csrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDcsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, csrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCcsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + 
return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, csrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseZcsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv2_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, csrsv2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv2_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, csrsv2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsv2_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, csrsv2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv2_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + cusparseSolvePolicy_t policy, void 
*pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, csrsv2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv2_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const float *alpha, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrsv2Info_t info, const float *f, float *x, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + csrsv2Info_t, const float *, float *, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv2_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const double *alpha, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrsv2Info_t info, const double *f, double *x, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + csrsv2Info_t, const double *, double *, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsv2_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrsv2Info_t info, const cuComplex *f, + cuComplex *x, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + csrsv2Info_t, const cuComplex *, cuComplex *, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol("cusparseCcsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv2_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrsv2Info_t info, const cuDoubleComplex *f, + cuDoubleComplex *x, cusparseSolvePolicy_t 
policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, csrsv2Info_t, const cuDoubleComplex *, cuDoubleComplex *, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXbsrsv2_zeroPivot(cusparseHandle_t handle, + bsrsv2Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseXbsrsv2_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsv2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, float *, const int *, const int *, int, + bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseSbsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsv2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, double *, const int *, const int *, int, + bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseDbsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsv2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, cuComplex *, const int *, const int *, int, + bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseCbsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsv2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t 
dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, cuDoubleComplex *, const int *, const int *, + int, bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseZbsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockSize, + bsrsv2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, float *, const int *, const int *, int, + bsrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseSbsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockSize, info, + pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockSize, + bsrsv2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, double *, const int *, const int *, int, + bsrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDbsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockSize, info, + pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockSize, + bsrsv2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, cuComplex *, const int *, const int *, int, + bsrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCbsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockSize, info, + pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockSize, + bsrsv2Info_t info, 
size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, cuDoubleComplex *, const int *, const int *, + int, bsrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseZbsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockSize, info, + pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsv2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + bsrsv2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseSbsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsv2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + bsrsv2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDbsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsv2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, bsrsv2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCbsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsv2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = 
cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, bsrsv2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZbsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsv2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, const float *alpha, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, const float *f, float *x, cusparseSolvePolicy_t policy, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, int, bsrsv2Info_t, const float *, float *, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseSbsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, alpha, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, f, x, + policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsv2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, const double *alpha, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, const double *f, double *x, cusparseSolvePolicy_t policy, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const double *, const cusparseMatDescr_t, const double *, const int *, + const int *, int, bsrsv2Info_t, const double *, double *, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDbsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, alpha, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, f, x, + policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsv2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, const cuComplex *alpha, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, const cuComplex *f, cuComplex *x, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, int, bsrsv2Info_t, const cuComplex *, + cuComplex *, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCbsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, alpha, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, f, x, + policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsv2_solve( + cusparseHandle_t handle, cusparseDirection_t 
dirA, + cusparseOperation_t transA, int mb, int nnzb, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, const cuDoubleComplex *f, cuDoubleComplex *x, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, bsrsv2Info_t, + const cuDoubleComplex *, cuDoubleComplex *, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol("cusparseZbsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, alpha, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, f, x, + policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseShybsv_analysis(cusparseHandle_t handle, cusparseOperation_t transA, + const cusparseMatDescr_t descrA, cusparseHybMat_t hybA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cusparseMatDescr_t, + cusparseHybMat_t, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseShybsv_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, descrA, hybA, info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDhybsv_analysis(cusparseHandle_t handle, cusparseOperation_t transA, + const cusparseMatDescr_t descrA, cusparseHybMat_t hybA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cusparseMatDescr_t, + cusparseHybMat_t, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseDhybsv_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, descrA, hybA, info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseChybsv_analysis(cusparseHandle_t handle, cusparseOperation_t transA, + const cusparseMatDescr_t descrA, cusparseHybMat_t hybA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cusparseMatDescr_t, + cusparseHybMat_t, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseChybsv_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, descrA, hybA, info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZhybsv_analysis(cusparseHandle_t handle, cusparseOperation_t transA, + const cusparseMatDescr_t descrA, cusparseHybMat_t hybA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cusparseMatDescr_t, + cusparseHybMat_t, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseZhybsv_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, descrA, hybA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseShybsv_solve( + cusparseHandle_t handle, cusparseOperation_t trans, const float *alpha, + const cusparseMatDescr_t descrA, const cusparseHybMat_t hybA, + cusparseSolveAnalysisInfo_t info, const float *f, float *x) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const float *, + const cusparseMatDescr_t, const cusparseHybMat_t, + 
cusparseSolveAnalysisInfo_t, const float *, float *); + static auto func_ptr = LoadSymbol("cusparseShybsv_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, alpha, descrA, hybA, info, f, x); +} + +cusparseStatus_t CUSPARSEAPI cusparseChybsv_solve( + cusparseHandle_t handle, cusparseOperation_t trans, const cuComplex *alpha, + const cusparseMatDescr_t descrA, const cusparseHybMat_t hybA, + cusparseSolveAnalysisInfo_t info, const cuComplex *f, cuComplex *x) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cuComplex *, + const cusparseMatDescr_t, const cusparseHybMat_t, + cusparseSolveAnalysisInfo_t, const cuComplex *, cuComplex *); + static auto func_ptr = LoadSymbol("cusparseChybsv_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, alpha, descrA, hybA, info, f, x); +} + +cusparseStatus_t CUSPARSEAPI cusparseDhybsv_solve( + cusparseHandle_t handle, cusparseOperation_t trans, const double *alpha, + const cusparseMatDescr_t descrA, const cusparseHybMat_t hybA, + cusparseSolveAnalysisInfo_t info, const double *f, double *x) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const double *, + const cusparseMatDescr_t, const cusparseHybMat_t, + cusparseSolveAnalysisInfo_t, const double *, double *); + static auto func_ptr = LoadSymbol("cusparseDhybsv_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, alpha, descrA, hybA, info, f, x); +} + +cusparseStatus_t CUSPARSEAPI cusparseZhybsv_solve( + cusparseHandle_t handle, cusparseOperation_t trans, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, cusparseSolveAnalysisInfo_t info, + const cuDoubleComplex *f, cuDoubleComplex *x) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cuDoubleComplex *, + const cusparseMatDescr_t, const cusparseHybMat_t, + cusparseSolveAnalysisInfo_t, const cuDoubleComplex *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZhybsv_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, alpha, descrA, hybA, info, f, x); +} + +cusparseStatus_t CUSPARSEAPI +cusparseScsrmm(cusparseHandle_t handle, cusparseOperation_t transA, int m, + int n, int k, int nnz, const float *alpha, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const float *B, int ldb, const float *beta, float *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cusparseScsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, k, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrmm( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, int k, + int nnz, const double *alpha, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const double *B, int ldb, const double *beta, + double *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, 
cusparseOperation_t, int, int, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cusparseDcsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, k, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrmm( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, int k, + int nnz, const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuComplex *B, int ldb, + const cuComplex *beta, cuComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int, + const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const cuComplex *, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseCcsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, k, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrmm( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, int k, + int nnz, const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, cuDoubleComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int, + const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol("cusparseZcsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, k, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI +cusparseScsrmm2(cusparseHandle_t handle, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int n, int k, int nnz, + const float *alpha, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const float *B, int ldb, + const float *beta, float *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cusparseScsrmm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDcsrmm2(cusparseHandle_t handle, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int n, int k, int nnz, + const double *alpha, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const double *B, int ldb, + const double *beta, double *C, int ldc) { + using FuncPtr = 
cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, int, const double *, double *, + int); + static auto func_ptr = LoadSymbol("cusparseDcsrmm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCcsrmm2(cusparseHandle_t handle, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int n, int k, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuComplex *B, int ldb, + const cuComplex *beta, cuComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const cuComplex *, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseCcsrmm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrmm2( + cusparseHandle_t handle, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int n, int k, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, cuDoubleComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol("cusparseZcsrmm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrmm( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int kb, int nnzb, const float *alpha, const cusparseMatDescr_t descrA, + const float *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, const int blockSize, const float *B, + const int ldb, const float *beta, float *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + const int, const float *, const int, const float *, float *, int); + static auto func_ptr = LoadSymbol("cusparseSbsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, kb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockSize, + B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrmm( + cusparseHandle_t 
handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int kb, int nnzb, const double *alpha, const cusparseMatDescr_t descrA, + const double *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, const int blockSize, const double *B, + const int ldb, const double *beta, double *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + const int, const double *, const int, const double *, double *, int); + static auto func_ptr = LoadSymbol("cusparseDbsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, kb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockSize, + B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrmm( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int kb, int nnzb, const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, const int blockSize, const cuComplex *B, + const int ldb, const cuComplex *beta, cuComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + const int, const cuComplex *, const int, const cuComplex *, cuComplex *, + int); + static auto func_ptr = LoadSymbol("cusparseCbsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, kb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockSize, + B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrmm( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int kb, int nnzb, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, + const int blockSize, const cuDoubleComplex *B, const int ldb, + const cuDoubleComplex *beta, cuDoubleComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, const int, const cuDoubleComplex *, const int, + const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cusparseZbsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, kb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockSize, + B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgemmi( + cusparseHandle_t handle, int m, int n, int k, int nnz, const float *alpha, + const float *A, int lda, const float *cscValB, const int *cscColPtrB, + const int *cscRowIndB, const float *beta, float *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int, const float *, const float *, int, + const float *, const int 
*, const int *, const float *, float *, int); + static auto func_ptr = LoadSymbol("cusparseSgemmi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, + cscRowIndB, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgemmi( + cusparseHandle_t handle, int m, int n, int k, int nnz, const double *alpha, + const double *A, int lda, const double *cscValB, const int *cscColPtrB, + const int *cscRowIndB, const double *beta, double *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int, const double *, const double *, int, + const double *, const int *, const int *, const double *, double *, int); + static auto func_ptr = LoadSymbol("cusparseDgemmi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, + cscRowIndB, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgemmi( + cusparseHandle_t handle, int m, int n, int k, int nnz, + const cuComplex *alpha, const cuComplex *A, int lda, + const cuComplex *cscValB, const int *cscColPtrB, const int *cscRowIndB, + const cuComplex *beta, cuComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, const int *, const int *, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseCgemmi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, + cscRowIndB, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZgemmi(cusparseHandle_t handle, int m, int n, int k, int nnz, + const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *cscValB, const int *cscColPtrB, + const int *cscRowIndB, const cuDoubleComplex *beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, const int *, + const int *, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cusparseZgemmi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, + cscRowIndB, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsm_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseScsrsm_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsm_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const 
cusparseMatDescr_t, + const double *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseDcsrsm_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsm_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseCcsrsm_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsm_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseZcsrsm_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsm_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, + const float *alpha, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + const float *B, int ldb, float *X, int ldx) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + cusparseSolveAnalysisInfo_t, const float *, int, float *, int); + static auto func_ptr = LoadSymbol("cusparseScsrsm_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, B, ldb, X, ldx); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsm_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, + const double *alpha, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + const double *B, int ldb, double *X, int ldx) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + cusparseSolveAnalysisInfo_t, const double *, int, double *, int); + static auto func_ptr = LoadSymbol("cusparseDcsrsm_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, B, ldb, X, ldx); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsm_solve( + 
cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + const cuComplex *B, int ldb, cuComplex *X, int ldx) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + cusparseSolveAnalysisInfo_t, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseCcsrsm_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, B, ldb, X, ldx); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsm_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + const cuDoubleComplex *B, int ldb, cuDoubleComplex *X, int ldx) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, cusparseSolveAnalysisInfo_t, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cusparseZcsrsm_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, B, ldb, X, ldx); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsrsm2Info(csrsm2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrsm2Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateCsrsm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsrsm2Info(csrsm2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrsm2Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyCsrsm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrsm2_zeroPivot(cusparseHandle_t handle, + csrsm2Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, csrsm2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseXcsrsm2_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsm2_bufferSizeExt( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const float *alpha, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const float *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, const float *, int, csrsm2Info_t, cusparseSolvePolicy_t, + size_t *); + static auto func_ptr = LoadSymbol("cusparseScsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return 
func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsm2_bufferSizeExt( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const double *alpha, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const double *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, int, csrsm2Info_t, + cusparseSolvePolicy_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDcsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsm2_bufferSizeExt( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuComplex *B, int ldb, csrsm2Info_t info, + cusparseSolvePolicy_t policy, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const cuComplex *, int, csrsm2Info_t, + cusparseSolvePolicy_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCcsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsm2_bufferSizeExt( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *B, int ldb, + csrsm2Info_t info, cusparseSolvePolicy_t policy, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cuDoubleComplex *, int, csrsm2Info_t, cusparseSolvePolicy_t, + size_t *); + static auto func_ptr = LoadSymbol("cusparseZcsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsm2_analysis( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const float *alpha, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int 
*csrSortedRowPtrA, const int *csrSortedColIndA, const float *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, const float *, int, csrsm2Info_t, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol("cusparseScsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsm2_analysis( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const double *alpha, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const double *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, int, csrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsm2_analysis( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuComplex *B, int ldb, csrsm2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const cuComplex *, int, csrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsm2_analysis( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *B, int ldb, + csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cuDoubleComplex *, int, csrsm2Info_t, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol("cusparseZcsrsm2_analysis"); + if 
(!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsm2_solve( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const float *alpha, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float *B, int ldb, + csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, float *, int, csrsm2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsm2_solve( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const double *alpha, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, double *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, double *, int, csrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsm2_solve( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cuComplex *B, int ldb, csrsm2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, cuComplex *, int, csrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsm2_solve( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cuDoubleComplex *B, int ldb, csrsm2Info_t info, + cusparseSolvePolicy_t 
policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, cuDoubleComplex *, int, + csrsm2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXbsrsm2_zeroPivot(cusparseHandle_t handle, + bsrsm2Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseXbsrsm2_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsm2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, int, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseSbsrsm2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsm2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, int, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseDbsrsm2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsm2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, int, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseCbsrsm2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, 
dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsm2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseZbsrsm2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsm2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, int, bsrsm2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseSbsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsm2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, int, bsrsm2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDbsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsm2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, int, bsrsm2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCbsrsm2_bufferSizeExt"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsm2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrsm2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseZbsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsm2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, const float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, bsrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseSbsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsm2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, const double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, bsrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDbsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsm2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, const cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const 
cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, bsrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCbsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsm2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, bsrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZbsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsm2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const float *alpha, const cusparseMatDescr_t descrA, + const float *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + const float *B, int ldb, float *X, int ldx, cusparseSolvePolicy_t policy, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + bsrsm2Info_t, const float *, int, float *, int, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol("cusparseSbsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, alpha, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, B, ldb, X, ldx, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsm2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const double *alpha, const cusparseMatDescr_t descrA, + const double *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + const double *B, int ldb, double *X, int ldx, cusparseSolvePolicy_t policy, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + bsrsm2Info_t, const double *, int, double *, int, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol("cusparseDbsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, alpha, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, B, ldb, X, ldx, policy, pBuffer); +} + +cusparseStatus_t 
CUSPARSEAPI cusparseCbsrsm2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + const cuComplex *B, int ldb, cuComplex *X, int ldx, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, bsrsm2Info_t, const cuComplex *, int, cuComplex *, int, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCbsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, alpha, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, B, ldb, X, ldx, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsm2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + const cuDoubleComplex *B, int ldb, cuDoubleComplex *X, int ldx, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, bsrsm2Info_t, const cuDoubleComplex *, int, + cuDoubleComplex *, int, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZbsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, alpha, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, B, ldb, X, ldx, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrilu0Ex( + cusparseHandle_t handle, cusparseOperation_t trans, int m, + const cusparseMatDescr_t descrA, void *csrSortedValA_ValM, + cudaDataType csrSortedValA_ValMtype, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseSolveAnalysisInfo_t info, + cudaDataType executiontype) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t, + void *, cudaDataType, const int *, const int *, + cusparseSolveAnalysisInfo_t, cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCsrilu0Ex"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM, + csrSortedValA_ValMtype, csrSortedRowPtrA, csrSortedColIndA, + info, executiontype); +} + +cusparseStatus_t CUSPARSEAPI +cusparseScsrilu0(cusparseHandle_t handle, cusparseOperation_t trans, int m, + const cusparseMatDescr_t descrA, float *csrSortedValA_ValM, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t, + float *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto 
func_ptr = LoadSymbol("cusparseScsrilu0"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDcsrilu0(cusparseHandle_t handle, cusparseOperation_t trans, int m, + const cusparseMatDescr_t descrA, double *csrSortedValA_ValM, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t, + double *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseDcsrilu0"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCcsrilu0(cusparseHandle_t handle, cusparseOperation_t trans, int m, + const cusparseMatDescr_t descrA, cuComplex *csrSortedValA_ValM, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseCcsrilu0"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu0( + cusparseHandle_t handle, cusparseOperation_t trans, int m, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrSortedValA_ValM, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseZcsrilu0"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_numericBoost( + cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double *tol, + float *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, csrilu02Info_t, int, double *, float *); + static auto func_ptr = LoadSymbol("cusparseScsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_numericBoost( + cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double *tol, + double *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, csrilu02Info_t, int, double *, double *); + static auto func_ptr = LoadSymbol("cusparseDcsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_numericBoost( + cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double *tol, + cuComplex *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, csrilu02Info_t, int, double *, cuComplex *); + static auto func_ptr = 
LoadSymbol("cusparseCcsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_numericBoost( + cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double *tol, + cuDoubleComplex *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, csrilu02Info_t, int, double *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZcsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrilu02_zeroPivot( + cusparseHandle_t handle, csrilu02Info_t info, int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseXcsrilu02_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseScsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseDcsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseCcsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csrilu02Info_t, int *); + static 
auto func_ptr = LoadSymbol("cusparseZcsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedVal, const int *csrSortedRowPtr, const int *csrSortedColInd, + csrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseScsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDcsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCcsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseZcsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsrilu02_analysis"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, csrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csrilu02Info_t, 
cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02_numericBoost( + cusparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double *tol, + float *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, bsrilu02Info_t, int, double *, float *); + static auto func_ptr = LoadSymbol("cusparseSbsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02_numericBoost( + cusparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double *tol, + double *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, bsrilu02Info_t, int, double *, double *); + static auto func_ptr = LoadSymbol("cusparseDbsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02_numericBoost( + cusparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double *tol, + cuComplex *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, bsrilu02Info_t, int, double *, cuComplex *); + static auto func_ptr = LoadSymbol("cusparseCbsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02_numericBoost( + cusparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double *tol, + cuDoubleComplex *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, bsrilu02Info_t, int, double *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol("cusparseZbsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return 
func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseXbsrilu02_zeroPivot( + cusparseHandle_t handle, bsrilu02Info_t info, int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseXbsrilu02_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseSbsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseDbsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseCbsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseZbsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t 
CUSPARSEAPI cusparseSbsrilu02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseSbsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDbsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCbsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrilu02Info_t, + size_t *); + static auto func_ptr = LoadSymbol("cusparseZbsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const 
int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseSbsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDbsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCbsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZbsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseSbsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + 
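Every wrapper in this stub follows the same lazy-resolution shape: a FuncPtr alias spelling out the cuSPARSE entry point's signature, a one-time LoadSymbol lookup cached in a function-local static, and either a verbatim forward of the arguments or a GetSymbolNotFoundError() failure. The sketch below is a minimal, self-contained illustration of that pattern and is not part of the patch: the entry point cusparseExampleOp is hypothetical, and the dlopen/dlsym-based LoadSymbol merely stands in for the real helper, whose implementation lies outside this diff.

// Illustrative sketch of the stub pattern used by the wrappers in this file.
// "cusparseExampleOp" and this LoadSymbol are stand-ins, not real API.
#include <dlfcn.h>

using cusparseStatus_t = int;                    // stand-in for the real status enum
constexpr cusparseStatus_t kSymbolNotFound = 1;  // plays the role of GetSymbolNotFoundError()

template <typename FuncPtr>
FuncPtr LoadSymbol(const char *symbol_name) {
  // Search the global symbol table of the already-loaded libraries once.
  static void *handle = dlopen(nullptr, RTLD_LAZY | RTLD_GLOBAL);
  return handle ? reinterpret_cast<FuncPtr>(dlsym(handle, symbol_name)) : nullptr;
}

// Same structure as the generated wrappers: declare, resolve once, forward.
cusparseStatus_t cusparseExampleOp(int n, const float *alpha, float *y) {
  using FuncPtr = cusparseStatus_t (*)(int, const float *, float *);
  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseExampleOp");
  if (!func_ptr) return kSymbolNotFound;
  return func_ptr(n, alpha, y);
}

Because func_ptr is a function-local static, the symbol lookup happens on the first call only; later calls pay just an indirect call, and a missing symbol is reported per entry point rather than at library load time.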
+cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02(
+    cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb,
+    const cusparseMatDescr_t descrA, double *bsrSortedVal,
+    const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim,
+    bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) {
+  using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(
+      cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t,
+      double *, const int *, const int *, int, bsrilu02Info_t,
+      cusparseSolvePolicy_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrilu02");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr,
+                  bsrSortedColInd, blockDim, info, policy, pBuffer);
+}
+
+cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02(
+    cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb,
+    const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal,
+    const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim,
+    bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) {
+  using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(
+      cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t,
+      cuComplex *, const int *, const int *, int, bsrilu02Info_t,
+      cusparseSolvePolicy_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrilu02");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr,
+                  bsrSortedColInd, blockDim, info, policy, pBuffer);
+}
+
+cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02(
+    cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb,
+    const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal,
+    const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim,
+    bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) {
+  using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(
+      cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t,
+      cuDoubleComplex *, const int *, const int *, int, bsrilu02Info_t,
+      cusparseSolvePolicy_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrilu02");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr,
+                  bsrSortedColInd, blockDim, info, policy, pBuffer);
+}
+
+cusparseStatus_t CUSPARSEAPI cusparseScsric0(cusparseHandle_t handle,
+                                             cusparseOperation_t trans, int m,
+                                             const cusparseMatDescr_t descrA,
+                                             float *csrSortedValA_ValM,
+                                             const int *csrSortedRowPtrA,
+                                             const int *csrSortedColIndA,
+                                             cusparseSolveAnalysisInfo_t info) {
+  using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(
+      cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t,
+      float *, const int *, const int *, cusparseSolveAnalysisInfo_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsric0");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM,
+                  csrSortedRowPtrA, csrSortedColIndA, info);
+}
+
+cusparseStatus_t CUSPARSEAPI cusparseDcsric0(cusparseHandle_t handle,
+                                             cusparseOperation_t trans, int m,
+                                             const cusparseMatDescr_t descrA,
+                                             double *csrSortedValA_ValM,
+                                             const int *csrSortedRowPtrA,
+                                             const int *csrSortedColIndA,
+                                             cusparseSolveAnalysisInfo_t info) {
+  using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(
+      cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t,
+      double *, const int *, const int *, cusparseSolveAnalysisInfo_t);
+  static auto func_ptr =
LoadSymbol("cusparseDcsric0"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric0(cusparseHandle_t handle, + cusparseOperation_t trans, int m, + const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA_ValM, + const int *csrSortedRowPtrA, + const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseCcsric0"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric0( + cusparseHandle_t handle, cusparseOperation_t trans, int m, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrSortedValA_ValM, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + cusparseSolveAnalysisInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, cusparseSolveAnalysisInfo_t); + static auto func_ptr = LoadSymbol("cusparseZcsric0"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, descrA, csrSortedValA_ValM, + csrSortedRowPtrA, csrSortedColIndA, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsric02_zeroPivot(cusparseHandle_t handle, + csric02Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, csric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseXcsric02_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsric02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseScsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsric02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseDcsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, 
csric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseCcsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseZcsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsric02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedVal, const int *csrSortedRowPtr, const int *csrSortedColInd, + csric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csric02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseScsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsric02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csric02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDcsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csric02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCcsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const 
cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csric02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseZcsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsric02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsric02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, csric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsric02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void 
*pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsric02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXbsric02_zeroPivot(cusparseHandle_t handle, + bsric02Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseXbsric02_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsric02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseSbsric02_bufferSize"); + if 
(!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsric02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseDbsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsric02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseCbsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsric02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol("cusparseZbsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsric02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsric02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseSbsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsric02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsric02Info_t info, size_t 
*pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsric02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDbsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsric02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsric02Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCbsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsric02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsric02Info_t, + size_t *); + static auto func_ptr = LoadSymbol("cusparseZbsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsric02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pInputBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseSbsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pInputBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsric02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pInputBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDbsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, 
mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pInputBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsric02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pInputBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCbsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pInputBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsric02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pInputBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZbsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pInputBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsric02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseSbsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsric02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseDbsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsric02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, 
const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseCbsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsric02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol("cusparseZbsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv(cusparseHandle_t handle, int m, + int n, const float *dl, + const float *d, const float *du, + float *B, int ldb) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, int, + const float *, const float *, + const float *, float *, int); + static auto func_ptr = LoadSymbol("cusparseSgtsv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv(cusparseHandle_t handle, int m, + int n, const double *dl, + const double *d, const double *du, + double *B, int ldb) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, double *, int); + static auto func_ptr = LoadSymbol("cusparseDgtsv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv(cusparseHandle_t handle, int m, + int n, const cuComplex *dl, + const cuComplex *d, + const cuComplex *du, cuComplex *B, + int ldb) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseCgtsv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv(cusparseHandle_t handle, int m, + int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, + const cuDoubleComplex *du, + cuDoubleComplex *B, int ldb) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cusparseZgtsv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *dl, const float *d, + const float *du, const float *B, int ldb, size_t 
*bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + const float *, int, size_t *); + static auto func_ptr = LoadSymbol("cusparseSgtsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *dl, const double *d, + const double *du, const double *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, const double *, int, size_t *); + static auto func_ptr = LoadSymbol("cusparseDgtsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, const cuComplex *B, int ldb, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, int, size_t *); + static auto func_ptr = LoadSymbol("cusparseCgtsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, + const cuDoubleComplex *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, + int, size_t *); + static auto func_ptr = LoadSymbol("cusparseZgtsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2(cusparseHandle_t handle, int m, + int n, const float *dl, + const float *d, const float *du, + float *B, int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + float *, int, void *); + static auto func_ptr = LoadSymbol("cusparseSgtsv2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2(cusparseHandle_t handle, int m, + int n, const double *dl, + const double *d, const double *du, + double *B, int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, double *, int, void *); + static auto func_ptr = LoadSymbol("cusparseDgtsv2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2(cusparseHandle_t handle, int m, + int n, const cuComplex *dl, + const cuComplex *d, + const cuComplex *du, cuComplex *B, + int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int, void *); + 
static auto func_ptr = LoadSymbol("cusparseCgtsv2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2(cusparseHandle_t handle, int m, + int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, + const cuDoubleComplex *du, + cuDoubleComplex *B, int ldb, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *, int, + void *); + static auto func_ptr = LoadSymbol("cusparseZgtsv2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSgtsv_nopivot(cusparseHandle_t handle, int m, int n, const float *dl, + const float *d, const float *du, float *B, int ldb) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, int, + const float *, const float *, + const float *, float *, int); + static auto func_ptr = LoadSymbol("cusparseSgtsv_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDgtsv_nopivot(cusparseHandle_t handle, int m, int n, const double *dl, + const double *d, const double *du, double *B, int ldb) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, double *, int); + static auto func_ptr = LoadSymbol("cusparseDgtsv_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv_nopivot( + cusparseHandle_t handle, int m, int n, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, cuComplex *B, int ldb) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseCgtsv_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZgtsv_nopivot(cusparseHandle_t handle, int m, int n, + const cuDoubleComplex *dl, const cuDoubleComplex *d, + const cuDoubleComplex *du, cuDoubleComplex *B, int ldb) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cusparseZgtsv_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2_nopivot_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *dl, const float *d, + const float *du, const float *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + const float *, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseSgtsv2_nopivot_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2_nopivot_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *dl, const double *d, + const double *du, const double *B, int ldb, 
size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, const double *, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseDgtsv2_nopivot_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2_nopivot_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, const cuComplex *B, int ldb, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseCgtsv2_nopivot_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2_nopivot_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, + const cuDoubleComplex *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, + int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseZgtsv2_nopivot_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2_nopivot( + cusparseHandle_t handle, int m, int n, const float *dl, const float *d, + const float *du, float *B, int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + float *, int, void *); + static auto func_ptr = LoadSymbol("cusparseSgtsv2_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2_nopivot( + cusparseHandle_t handle, int m, int n, const double *dl, const double *d, + const double *du, double *B, int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, double *, int, void *); + static auto func_ptr = LoadSymbol("cusparseDgtsv2_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2_nopivot( + cusparseHandle_t handle, int m, int n, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, cuComplex *B, int ldb, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int, void *); + static auto func_ptr = LoadSymbol("cusparseCgtsv2_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2_nopivot( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, cuDoubleComplex *B, + int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, 
int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *, int, + void *); + static auto func_ptr = LoadSymbol("cusparseZgtsv2_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsvStridedBatch( + cusparseHandle_t handle, int m, const float *dl, const float *d, + const float *du, float *x, int batchCount, int batchStride) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, const float *, const float *, + float *, int, int); + static auto func_ptr = LoadSymbol("cusparseSgtsvStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsvStridedBatch( + cusparseHandle_t handle, int m, const double *dl, const double *d, + const double *du, double *x, int batchCount, int batchStride) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const double *, const double *, + double *, int, int); + static auto func_ptr = LoadSymbol("cusparseDgtsvStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsvStridedBatch( + cusparseHandle_t handle, int m, const cuComplex *dl, const cuComplex *d, + const cuComplex *du, cuComplex *x, int batchCount, int batchStride) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int, int); + static auto func_ptr = LoadSymbol("cusparseCgtsvStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsvStridedBatch( + cusparseHandle_t handle, int m, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, cuDoubleComplex *x, + int batchCount, int batchStride) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + const cuDoubleComplex *, cuDoubleComplex *, int, int); + static auto func_ptr = LoadSymbol("cusparseZgtsvStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2StridedBatch_bufferSizeExt( + cusparseHandle_t handle, int m, const float *dl, const float *d, + const float *du, const float *x, int batchCount, int batchStride, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, const float *, const float *, + const float *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseSgtsv2StridedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, + bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2StridedBatch_bufferSizeExt( + cusparseHandle_t handle, int m, const double *dl, const double *d, + const double *du, const double *x, int batchCount, int batchStride, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const double *, const double *, + const double *, int, int, size_t *); + static auto func_ptr 
= + LoadSymbol("cusparseDgtsv2StridedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, + bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2StridedBatch_bufferSizeExt( + cusparseHandle_t handle, int m, const cuComplex *dl, const cuComplex *d, + const cuComplex *du, const cuComplex *x, int batchCount, int batchStride, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseCgtsv2StridedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, + bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2StridedBatch_bufferSizeExt( + cusparseHandle_t handle, int m, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, + const cuDoubleComplex *x, int batchCount, int batchStride, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseZgtsv2StridedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, + bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2StridedBatch( + cusparseHandle_t handle, int m, const float *dl, const float *d, + const float *du, float *x, int batchCount, int batchStride, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, const float *, const float *, + float *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseSgtsv2StridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDgtsv2StridedBatch(cusparseHandle_t handle, int m, const double *dl, + const double *d, const double *du, double *x, + int batchCount, int batchStride, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const double *, const double *, + double *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseDgtsv2StridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2StridedBatch( + cusparseHandle_t handle, int m, const cuComplex *dl, const cuComplex *d, + const cuComplex *du, cuComplex *x, int batchCount, int batchStride, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseCgtsv2StridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2StridedBatch( + cusparseHandle_t handle, int m, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, cuDoubleComplex *x, + int batchCount, int batchStride, void *pBuffer) { + using FuncPtr = 
cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + const cuDoubleComplex *, cuDoubleComplex *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseZgtsv2StridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const float *dl, const float *d, + const float *du, const float *x, int batchCount, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + const float *, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseSgtsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const double *dl, const double *d, + const double *du, const double *x, int batchCount, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, const double *, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseDgtsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, const cuComplex *x, int batchCount, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseCgtsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, + const cuDoubleComplex *x, int batchCount, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, + int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseZgtsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, float *dl, float *d, float *du, + float *x, int batchCount, void *pBuffer) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, int, float *, + float *, float *, float *, int, void *); + static auto func_ptr = LoadSymbol("cusparseSgtsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsvInterleavedBatch( + 
cusparseHandle_t handle, int algo, int m, double *dl, double *d, double *du, + double *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, int, + double *, double *, double *, + double *, int, void *); + static auto func_ptr = LoadSymbol("cusparseDgtsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, cuComplex *dl, cuComplex *d, + cuComplex *du, cuComplex *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, cuComplex *, cuComplex *, cuComplex *, + cuComplex *, int, void *); + static auto func_ptr = LoadSymbol("cusparseCgtsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, cuDoubleComplex *dl, + cuDoubleComplex *d, cuDoubleComplex *du, cuDoubleComplex *x, int batchCount, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, cuDoubleComplex *, cuDoubleComplex *, + cuDoubleComplex *, cuDoubleComplex *, int, void *); + static auto func_ptr = LoadSymbol("cusparseZgtsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgpsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const float *ds, const float *dl, + const float *d, const float *du, const float *dw, const float *x, + int batchCount, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + const float *, const float *, const float *, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseSgpsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgpsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const double *ds, + const double *dl, const double *d, const double *du, const double *dw, + const double *x, int batchCount, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, const double *, const double *, const double *, int, + size_t *); + static auto func_ptr = + LoadSymbol("cusparseDgpsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgpsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const cuComplex *ds, + const cuComplex *dl, const cuComplex *d, const cuComplex *du, + const cuComplex *dw, const cuComplex *x, int batchCount, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, const cuComplex *, + const cuComplex *, int, size_t *); + static auto func_ptr = + 
LoadSymbol("cusparseCgpsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgpsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const cuDoubleComplex *ds, + const cuDoubleComplex *dl, const cuDoubleComplex *d, + const cuDoubleComplex *du, const cuDoubleComplex *dw, + const cuDoubleComplex *x, int batchCount, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseZgpsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgpsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, float *ds, float *dl, float *d, + float *du, float *dw, float *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, float *, float *, float *, float *, float *, + float *, int, void *); + static auto func_ptr = LoadSymbol("cusparseSgpsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgpsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, double *ds, double *dl, double *d, + double *du, double *dw, double *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, double *, double *, double *, double *, + double *, double *, int, void *); + static auto func_ptr = LoadSymbol("cusparseDgpsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgpsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, cuComplex *ds, cuComplex *dl, + cuComplex *d, cuComplex *du, cuComplex *dw, cuComplex *x, int batchCount, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, cuComplex *, cuComplex *, cuComplex *, + cuComplex *, cuComplex *, cuComplex *, int, void *); + static auto func_ptr = LoadSymbol("cusparseCgpsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgpsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, cuDoubleComplex *ds, + cuDoubleComplex *dl, cuDoubleComplex *d, cuDoubleComplex *du, + cuDoubleComplex *dw, cuDoubleComplex *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, cuDoubleComplex *, cuDoubleComplex *, + cuDoubleComplex *, cuDoubleComplex *, cuDoubleComplex *, + cuDoubleComplex *, int, void *); + static auto func_ptr = LoadSymbol("cusparseZgpsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseXcsrgemmNnz(cusparseHandle_t handle, cusparseOperation_t transA, + 
cusparseOperation_t transB, int m, int n, int k, + const cusparseMatDescr_t descrA, const int nnzA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, const int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, int *csrSortedRowPtrC, + int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + const cusparseMatDescr_t, const int, const int *, const int *, + const cusparseMatDescr_t, const int, const int *, const int *, + const cusparseMatDescr_t, int *, int *); + static auto func_ptr = LoadSymbol("cusparseXcsrgemmNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, descrA, nnzA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrgemm( + cusparseHandle_t handle, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int n, int k, + const cusparseMatDescr_t descrA, const int nnzA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, const int nnzB, const float *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, float *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + const cusparseMatDescr_t, const int, const float *, const int *, + const int *, const cusparseMatDescr_t, const int, const float *, + const int *, const int *, const cusparseMatDescr_t, float *, const int *, + int *); + static auto func_ptr = LoadSymbol("cusparseScsrgemm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrgemm( + cusparseHandle_t handle, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int n, int k, + const cusparseMatDescr_t descrA, int nnzA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const double *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, double *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, double *, const int *, int *); + static auto func_ptr = LoadSymbol("cusparseDcsrgemm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrgemm( + cusparseHandle_t handle, 
cusparseOperation_t transA, + cusparseOperation_t transB, int m, int n, int k, + const cusparseMatDescr_t descrA, int nnzA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const cuComplex *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, cuComplex *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + const cusparseMatDescr_t, int, const cuComplex *, const int *, + const int *, const cusparseMatDescr_t, int, const cuComplex *, + const int *, const int *, const cusparseMatDescr_t, cuComplex *, + const int *, int *); + static auto func_ptr = LoadSymbol("cusparseCcsrgemm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrgemm( + cusparseHandle_t handle, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int n, int k, + const cusparseMatDescr_t descrA, int nnzA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const cuDoubleComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + cuDoubleComplex *csrSortedValC, const int *csrSortedRowPtrC, + int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, int, int, int, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cusparseMatDescr_t, int, const cuDoubleComplex *, + const int *, const int *, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, int *); + static auto func_ptr = LoadSymbol("cusparseZcsrgemm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsrgemm2Info(csrgemm2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrgemm2Info_t *); + static auto func_ptr = LoadSymbol("cusparseCreateCsrgemm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsrgemm2Info(csrgemm2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrgemm2Info_t); + static auto func_ptr = LoadSymbol("cusparseDestroyCsrgemm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrgemm2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int k, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, const float *beta, + const cusparseMatDescr_t descrD, int nnzD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, csrgemm2Info_t info, + size_t 
*pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const cusparseMatDescr_t, + int, const int *, const int *, const cusparseMatDescr_t, int, const int *, + const int *, const float *, const cusparseMatDescr_t, int, const int *, + const int *, csrgemm2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseScsrgemm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, beta, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrgemm2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int k, const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const double *beta, const cusparseMatDescr_t descrD, int nnzD, + const int *csrSortedRowPtrD, const int *csrSortedColIndD, + csrgemm2Info_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const cusparseMatDescr_t, + int, const int *, const int *, const cusparseMatDescr_t, int, const int *, + const int *, const double *, const cusparseMatDescr_t, int, const int *, + const int *, csrgemm2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDcsrgemm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, beta, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrgemm2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int k, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cuComplex *beta, const cusparseMatDescr_t descrD, int nnzD, + const int *csrSortedRowPtrD, const int *csrSortedColIndD, + csrgemm2Info_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, + const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int, const int *, const int *, + const cuComplex *, const cusparseMatDescr_t, int, const int *, + const int *, csrgemm2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCcsrgemm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, beta, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrgemm2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int k, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cuDoubleComplex *beta, const cusparseMatDescr_t descrD, int nnzD, + const int 
*csrSortedRowPtrD, const int *csrSortedColIndD, + csrgemm2Info_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int, const int *, const int *, + const cuDoubleComplex *, const cusparseMatDescr_t, int, const int *, + const int *, csrgemm2Info_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseZcsrgemm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, beta, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrgemm2Nnz( + cusparseHandle_t handle, int m, int n, int k, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrD, int nnzD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, const csrgemm2Info_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, int, + const int *, const int *, const cusparseMatDescr_t, int, const int *, + const int *, const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int *, int *, const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol("cusparseXcsrgemm2Nnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrgemm2( + cusparseHandle_t handle, int m, int n, int k, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const float *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, const float *beta, + const cusparseMatDescr_t descrD, int nnzD, const float *csrSortedValD, + const int *csrSortedRowPtrD, const int *csrSortedColIndD, + const cusparseMatDescr_t descrC, float *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, + const csrgemm2Info_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const cusparseMatDescr_t, + int, const float *, const int *, const int *, const cusparseMatDescr_t, + int, const float *, const int *, const int *, const float *, + const cusparseMatDescr_t, int, const float *, const int *, const int *, + const cusparseMatDescr_t, float *, const int *, int *, + const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsrgemm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, beta, + descrD, nnzD, csrSortedValD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedValC, 
csrSortedRowPtrC, + csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrgemm2( + cusparseHandle_t handle, int m, int n, int k, const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const double *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const double *beta, const cusparseMatDescr_t descrD, int nnzD, + const double *csrSortedValD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, const cusparseMatDescr_t descrC, + double *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + const csrgemm2Info_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const cusparseMatDescr_t, + int, const double *, const int *, const int *, const cusparseMatDescr_t, + int, const double *, const int *, const int *, const double *, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, double *, const int *, int *, + const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsrgemm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, beta, + descrD, nnzD, csrSortedValD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedValC, csrSortedRowPtrC, + csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrgemm2( + cusparseHandle_t handle, int m, int n, int k, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const cuComplex *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cuComplex *beta, const cusparseMatDescr_t descrD, int nnzD, + const cuComplex *csrSortedValD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, const cusparseMatDescr_t descrC, + cuComplex *csrSortedValC, const int *csrSortedRowPtrC, + int *csrSortedColIndC, const csrgemm2Info_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, + const cusparseMatDescr_t, int, const cuComplex *, const int *, + const int *, const cusparseMatDescr_t, int, const cuComplex *, + const int *, const int *, const cuComplex *, const cusparseMatDescr_t, + int, const cuComplex *, const int *, const int *, + const cusparseMatDescr_t, cuComplex *, const int *, int *, + const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsrgemm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, beta, + descrD, nnzD, csrSortedValD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedValC, csrSortedRowPtrC, + csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrgemm2( + cusparseHandle_t handle, int m, int n, int k, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int 
nnzB, + const cuDoubleComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cuDoubleComplex *beta, + const cusparseMatDescr_t descrD, int nnzD, + const cuDoubleComplex *csrSortedValD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, const cusparseMatDescr_t descrC, + cuDoubleComplex *csrSortedValC, const int *csrSortedRowPtrC, + int *csrSortedColIndC, const csrgemm2Info_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cusparseMatDescr_t, int, const cuDoubleComplex *, + const int *, const int *, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cusparseMatDescr_t, cuDoubleComplex *, const int *, + int *, const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsrgemm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, beta, + descrD, nnzD, csrSortedValD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedValC, csrSortedRowPtrC, + csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrgeamNnz( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, int, const int *, + const int *, const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int *, int *); + static auto func_ptr = LoadSymbol("cusparseXcsrgeamNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrgeam( + cusparseHandle_t handle, int m, int n, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const float *beta, + const cusparseMatDescr_t descrB, int nnzB, const float *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, float *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const cusparseMatDescr_t, int, + const float *, const int *, const int *, const float *, + const cusparseMatDescr_t, int, const float *, const int *, const int *, + const cusparseMatDescr_t, float *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseScsrgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrgeam( + cusparseHandle_t handle, int m, int n, 
const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *beta, const cusparseMatDescr_t descrB, int nnzB, + const double *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + double *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const cusparseMatDescr_t, int, + const double *, const int *, const int *, const double *, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, double *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDcsrgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrgeam( + cusparseHandle_t handle, int m, int n, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cuComplex *beta, const cusparseMatDescr_t descrB, int nnzB, + const cuComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + cuComplex *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cusparseMatDescr_t, + int, const cuComplex *, const int *, const int *, const cuComplex *, + const cusparseMatDescr_t, int, const cuComplex *, const int *, + const int *, const cusparseMatDescr_t, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseCcsrgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrgeam( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *beta, + const cusparseMatDescr_t descrB, int nnzB, + const cuDoubleComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + cuDoubleComplex *csrSortedValC, int *csrSortedRowPtrC, + int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cuDoubleComplex *, const cusparseMatDescr_t, int, + const cuDoubleComplex *, const int *, const int *, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseZcsrgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, 
csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrgeam2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const float *beta, + const cusparseMatDescr_t descrB, int nnzB, const float *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, const float *csrSortedValC, + const int *csrSortedRowPtrC, const int *csrSortedColIndC, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const cusparseMatDescr_t, int, + const float *, const int *, const int *, const float *, + const cusparseMatDescr_t, int, const float *, const int *, const int *, + const cusparseMatDescr_t, const float *, const int *, const int *, + size_t *); + static auto func_ptr = LoadSymbol("cusparseScsrgeam2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrgeam2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *beta, const cusparseMatDescr_t descrB, int nnzB, + const double *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + const double *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const cusparseMatDescr_t, int, + const double *, const int *, const int *, const double *, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, const double *, const int *, const int *, + size_t *); + static auto func_ptr = LoadSymbol("cusparseDcsrgeam2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrgeam2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cuComplex *beta, const cusparseMatDescr_t descrB, int nnzB, + const cuComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + const cuComplex *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cusparseMatDescr_t, + int, const cuComplex *, const int *, const int *, const cuComplex *, + const cusparseMatDescr_t, int, const cuComplex *, const int *, + const int *, const cusparseMatDescr_t, const cuComplex *, const int *, + const int 
*, size_t *); + static auto func_ptr = LoadSymbol("cusparseCcsrgeam2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrgeam2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *beta, + const cusparseMatDescr_t descrB, int nnzB, + const cuDoubleComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + const cuDoubleComplex *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cuDoubleComplex *, const cusparseMatDescr_t, int, + const cuDoubleComplex *, const int *, const int *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, size_t *); + static auto func_ptr = LoadSymbol("cusparseZcsrgeam2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrgeam2Nnz( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, int, const int *, + const int *, const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseXcsrgeam2Nnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, workspace); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrgeam2( + cusparseHandle_t handle, int m, int n, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const float *beta, + const cusparseMatDescr_t descrB, int nnzB, const float *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, float *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const cusparseMatDescr_t, int, + const float *, const int *, const int *, const float *, + const cusparseMatDescr_t, int, const float *, const int *, const int *, + const 
cusparseMatDescr_t, float *, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseScsrgeam2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrgeam2( + cusparseHandle_t handle, int m, int n, const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *beta, const cusparseMatDescr_t descrB, int nnzB, + const double *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + double *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const cusparseMatDescr_t, int, + const double *, const int *, const int *, const double *, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, double *, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseDcsrgeam2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrgeam2( + cusparseHandle_t handle, int m, int n, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cuComplex *beta, const cusparseMatDescr_t descrB, int nnzB, + const cuComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + cuComplex *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cusparseMatDescr_t, + int, const cuComplex *, const int *, const int *, const cuComplex *, + const cusparseMatDescr_t, int, const cuComplex *, const int *, + const int *, const cusparseMatDescr_t, cuComplex *, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseCcsrgeam2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrgeam2( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *beta, + const cusparseMatDescr_t descrB, int nnzB, + const cuDoubleComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + cuDoubleComplex *csrSortedValC, int *csrSortedRowPtrC, + int *csrSortedColIndC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const 
cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cuDoubleComplex *, const cusparseMatDescr_t, int, + const cuDoubleComplex *, const int *, const int *, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseZcsrgeam2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrcolor( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const float *fractionToColor, int *ncolors, + int *coloring, int *reordering, const cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, const float *, int *, int *, int *, + const cusparseColorInfo_t); + static auto func_ptr = LoadSymbol("cusparseScsrcolor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, fractionToColor, ncolors, coloring, + reordering, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrcolor( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const double *fractionToColor, int *ncolors, + int *coloring, int *reordering, const cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, int *, int *, int *, + const cusparseColorInfo_t); + static auto func_ptr = LoadSymbol("cusparseDcsrcolor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, fractionToColor, ncolors, coloring, + reordering, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrcolor( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const float *fractionToColor, int *ncolors, + int *coloring, int *reordering, const cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const float *, int *, int *, int *, + const cusparseColorInfo_t); + static auto func_ptr = LoadSymbol("cusparseCcsrcolor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, fractionToColor, ncolors, coloring, + reordering, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrcolor( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const double *fractionToColor, int *ncolors, + int *coloring, int *reordering, const cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, 
const double *, int *, + int *, int *, const cusparseColorInfo_t); + static auto func_ptr = LoadSymbol("cusparseZcsrcolor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, fractionToColor, ncolors, coloring, + reordering, info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSnnz(cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *A, int lda, + int *nnzPerRowCol, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, int, int *, int *); + static auto func_ptr = LoadSymbol("cusparseSnnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnnz(cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *A, int lda, + int *nnzPerRowCol, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, int, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDnnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCnnz(cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *A, int lda, + int *nnzPerRowCol, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, int, int *, int *); + static auto func_ptr = LoadSymbol("cusparseCnnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZnnz(cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *A, int lda, + int *nnzPerRowCol, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, int, int *, int *); + static auto func_ptr = LoadSymbol("cusparseZnnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseSnnz_compress( + cusparseHandle_t handle, int m, const cusparseMatDescr_t descr, + const float *csrSortedValA, const int *csrSortedRowPtrA, int *nnzPerRow, + int *nnzC, float tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cusparseMatDescr_t, const float *, + const int *, int *, int *, float); + static auto func_ptr = LoadSymbol("cusparseSnnz_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, descr, csrSortedValA, csrSortedRowPtrA, nnzPerRow, + nnzC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseDnnz_compress( + cusparseHandle_t handle, int m, const cusparseMatDescr_t descr, + const double *csrSortedValA, const int *csrSortedRowPtrA, int *nnzPerRow, + int *nnzC, double tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + 
cusparseHandle_t, int, const cusparseMatDescr_t, const double *, + const int *, int *, int *, double); + static auto func_ptr = LoadSymbol("cusparseDnnz_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, descr, csrSortedValA, csrSortedRowPtrA, nnzPerRow, + nnzC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseCnnz_compress( + cusparseHandle_t handle, int m, const cusparseMatDescr_t descr, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, int *nnzPerRow, + int *nnzC, cuComplex tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cusparseMatDescr_t, const cuComplex *, + const int *, int *, int *, cuComplex); + static auto func_ptr = LoadSymbol("cusparseCnnz_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, descr, csrSortedValA, csrSortedRowPtrA, nnzPerRow, + nnzC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseZnnz_compress( + cusparseHandle_t handle, int m, const cusparseMatDescr_t descr, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + int *nnzPerRow, int *nnzC, cuDoubleComplex tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cusparseMatDescr_t, const cuDoubleComplex *, + const int *, int *, int *, cuDoubleComplex); + static auto func_ptr = LoadSymbol("cusparseZnnz_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, descr, csrSortedValA, csrSortedRowPtrA, nnzPerRow, + nnzC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2csr_compress( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedColIndA, + const int *csrSortedRowPtrA, int nnzA, const int *nnzPerRow, + float *csrSortedValC, int *csrSortedColIndC, int *csrSortedRowPtrC, + float tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, int, const int *, float *, int *, int *, float); + static auto func_ptr = LoadSymbol("cusparseScsr2csr_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedColIndA, + csrSortedRowPtrA, nnzA, nnzPerRow, csrSortedValC, + csrSortedColIndC, csrSortedRowPtrC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2csr_compress( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedColIndA, + const int *csrSortedRowPtrA, int nnzA, const int *nnzPerRow, + double *csrSortedValC, int *csrSortedColIndC, int *csrSortedRowPtrC, + double tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, int, const int *, double *, int *, int *, + double); + static auto func_ptr = LoadSymbol("cusparseDcsr2csr_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedColIndA, + csrSortedRowPtrA, nnzA, nnzPerRow, csrSortedValC, + csrSortedColIndC, csrSortedRowPtrC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2csr_compress( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedColIndA, + const int *csrSortedRowPtrA, int nnzA, const int *nnzPerRow, + cuComplex *csrSortedValC, int *csrSortedColIndC, int *csrSortedRowPtrC, + cuComplex tol) { + 
using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, int, const int *, cuComplex *, int *, int *, + cuComplex); + static auto func_ptr = LoadSymbol("cusparseCcsr2csr_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedColIndA, + csrSortedRowPtrA, nnzA, nnzPerRow, csrSortedValC, + csrSortedColIndC, csrSortedRowPtrC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2csr_compress( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedColIndA, + const int *csrSortedRowPtrA, int nnzA, const int *nnzPerRow, + cuDoubleComplex *csrSortedValC, int *csrSortedColIndC, + int *csrSortedRowPtrC, cuDoubleComplex tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, const int *, + cuDoubleComplex *, int *, int *, cuDoubleComplex); + static auto func_ptr = LoadSymbol("cusparseZcsr2csr_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedColIndA, + csrSortedRowPtrA, nnzA, nnzPerRow, csrSortedValC, + csrSortedColIndC, csrSortedRowPtrC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseSdense2csr( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *A, int lda, const int *nnzPerRow, float *csrSortedValA, + int *csrSortedRowPtrA, int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, int, + const int *, float *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseSdense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseDdense2csr( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *A, int lda, const int *nnzPerRow, double *csrSortedValA, + int *csrSortedRowPtrA, int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, int, + const int *, double *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDdense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseCdense2csr( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *A, int lda, const int *nnzPerRow, cuComplex *csrSortedValA, + int *csrSortedRowPtrA, int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + int, const int *, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseCdense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseZdense2csr( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *A, int lda, const int *nnzPerRow, + cuDoubleComplex *csrSortedValA, int 
*csrSortedRowPtrA, + int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, int, const int *, cuDoubleComplex *, int *, + int *); + static auto func_ptr = LoadSymbol("cusparseZdense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, float *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float *, int); + static auto func_ptr = LoadSymbol("cusparseScsr2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, double *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, double *, int); + static auto func_ptr = LoadSymbol("cusparseDcsr2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cuComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseCcsr2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cuDoubleComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol("cusparseZcsr2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseSdense2csc( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *A, int lda, const int *nnzPerCol, float *cscSortedValA, + int *cscSortedRowIndA, int *cscSortedColPtrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, int, + const int *, float *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseSdense2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, 
nnzPerCol, cscSortedValA, + cscSortedRowIndA, cscSortedColPtrA); +} + +cusparseStatus_t CUSPARSEAPI cusparseDdense2csc( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *A, int lda, const int *nnzPerCol, double *cscSortedValA, + int *cscSortedRowIndA, int *cscSortedColPtrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, int, + const int *, double *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDdense2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerCol, cscSortedValA, + cscSortedRowIndA, cscSortedColPtrA); +} + +cusparseStatus_t CUSPARSEAPI cusparseCdense2csc( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *A, int lda, const int *nnzPerCol, cuComplex *cscSortedValA, + int *cscSortedRowIndA, int *cscSortedColPtrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + int, const int *, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseCdense2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerCol, cscSortedValA, + cscSortedRowIndA, cscSortedColPtrA); +} + +cusparseStatus_t CUSPARSEAPI cusparseZdense2csc( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *A, int lda, const int *nnzPerCol, + cuDoubleComplex *cscSortedValA, int *cscSortedRowIndA, + int *cscSortedColPtrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, int, const int *, cuDoubleComplex *, int *, + int *); + static auto func_ptr = LoadSymbol("cusparseZdense2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerCol, cscSortedValA, + cscSortedRowIndA, cscSortedColPtrA); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsc2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, float *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float *, int); + static auto func_ptr = LoadSymbol("cusparseScsc2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsc2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, double *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, double *, int); + static auto func_ptr = LoadSymbol("cusparseDcsc2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsc2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, cuComplex *A, int lda) { + using FuncPtr = 
cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseCcsc2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsc2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, cuDoubleComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol("cusparseZcsc2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle, + const int *cooRowInd, int nnz, + int m, int *csrSortedRowPtr, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const int *, int, int, int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseXcoo2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, cooRowInd, nnz, m, csrSortedRowPtr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle, + const int *csrSortedRowPtr, + int nnz, int m, int *cooRowInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const int *, int, int, int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseXcsr2coo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, csrSortedRowPtr, nnz, m, cooRowInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsr2cscEx( + cusparseHandle_t handle, int m, int n, int nnz, const void *csrSortedVal, + cudaDataType csrSortedValtype, const int *csrSortedRowPtr, + const int *csrSortedColInd, void *cscSortedVal, + cudaDataType cscSortedValtype, int *cscSortedRowInd, int *cscSortedColPtr, + cusparseAction_t copyValues, cusparseIndexBase_t idxBase, + cudaDataType executiontype) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const void *, cudaDataType, const int *, + const int *, void *, cudaDataType, int *, int *, cusparseAction_t, + cusparseIndexBase_t, cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCsr2cscEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrSortedVal, csrSortedValtype, + csrSortedRowPtr, csrSortedColInd, cscSortedVal, + cscSortedValtype, cscSortedRowInd, cscSortedColPtr, + copyValues, idxBase, executiontype); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2csc( + cusparseHandle_t handle, int m, int n, int nnz, const float *csrSortedVal, + const int *csrSortedRowPtr, const int *csrSortedColInd, float *cscSortedVal, + int *cscSortedRowInd, int *cscSortedColPtr, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const int *, const int *, + float *, int *, int *, cusparseAction_t, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseScsr2csc"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, cscSortedVal, cscSortedRowInd, + cscSortedColPtr, copyValues, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2csc( + cusparseHandle_t handle, int m, int n, int nnz, const double *csrSortedVal, + const int *csrSortedRowPtr, const int *csrSortedColInd, + double *cscSortedVal, int *cscSortedRowInd, int *cscSortedColPtr, + cusparseAction_t copyValues, cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const int *, const int *, + double *, int *, int *, cusparseAction_t, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseDcsr2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, cscSortedVal, cscSortedRowInd, + cscSortedColPtr, copyValues, idxBase); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCcsr2csc(cusparseHandle_t handle, int m, int n, int nnz, + const cuComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, cuComplex *cscSortedVal, + int *cscSortedRowInd, int *cscSortedColPtr, + cusparseAction_t copyValues, cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, const int *, + const int *, cuComplex *, int *, int *, cusparseAction_t, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseCcsr2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, cscSortedVal, cscSortedRowInd, + cscSortedColPtr, copyValues, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2csc( + cusparseHandle_t handle, int m, int n, int nnz, + const cuDoubleComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, cuDoubleComplex *cscSortedVal, + int *cscSortedRowInd, int *cscSortedColPtr, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, const int *, + const int *, cuDoubleComplex *, int *, int *, cusparseAction_t, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol("cusparseZcsr2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, cscSortedVal, cscSortedRowInd, + cscSortedColPtr, copyValues, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSdense2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *A, int lda, const int *nnzPerRow, cusparseHybMat_t hybA, + int userEllWidth, cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, int, + const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseSdense2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, hybA, userEllWidth, + partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseDdense2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *A, int lda, const int *nnzPerRow, cusparseHybMat_t hybA, + int userEllWidth, cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + 
cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, int, + const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseDdense2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, hybA, userEllWidth, + partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCdense2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *A, int lda, const int *nnzPerRow, cusparseHybMat_t hybA, + int userEllWidth, cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + int, const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseCdense2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, hybA, userEllWidth, + partitionType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZdense2hyb(cusparseHandle_t handle, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *A, + int lda, const int *nnzPerRow, cusparseHybMat_t hybA, + int userEllWidth, cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, int, const int *, cusparseHybMat_t, int, + cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseZdense2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, hybA, userEllWidth, + partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseShyb2dense(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + float *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + float *, int); + static auto func_ptr = LoadSymbol("cusparseShyb2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseDhyb2dense(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + double *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + double *, int); + static auto func_ptr = LoadSymbol("cusparseDhyb2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseChyb2dense(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + cuComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + cuComplex *, int); + static auto func_ptr = LoadSymbol("cusparseChyb2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseZhyb2dense(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + cuDoubleComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol("cusparseZhyb2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, 
descrA, hybA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseHybMat_t hybA, int userEllWidth, + cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseScsr2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, hybA, userEllWidth, partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseHybMat_t hybA, int userEllWidth, + cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseDcsr2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, hybA, userEllWidth, partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseHybMat_t hybA, int userEllWidth, + cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseCcsr2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, hybA, userEllWidth, partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cusparseHybMat_t hybA, int userEllWidth, + cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, cusparseHybMat_t, int, + cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseZcsr2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, hybA, userEllWidth, partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseShyb2csr(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + float *csrSortedValA, + int *csrSortedRowPtrA, + int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + float *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseShyb2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, csrSortedValA, csrSortedRowPtrA, + 
csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseDhyb2csr(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + double *csrSortedValA, + int *csrSortedRowPtrA, + int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + double *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDhyb2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseChyb2csr(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + cuComplex *csrSortedValA, + int *csrSortedRowPtrA, + int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseChyb2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseZhyb2csr(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + cuDoubleComplex *csrSortedValA, + int *csrSortedRowPtrA, + int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseZhyb2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsc2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, cusparseHybMat_t hybA, int userEllWidth, + cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseScsc2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, hybA, userEllWidth, partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsc2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, cusparseHybMat_t hybA, int userEllWidth, + cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseDcsc2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, hybA, userEllWidth, partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsc2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, cusparseHybMat_t hybA, int userEllWidth, + cusparseHybPartition_t partitionType) { + using FuncPtr = 
cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, cusparseHybMat_t, int, cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseCcsc2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, hybA, userEllWidth, partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsc2hyb( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, cusparseHybMat_t hybA, int userEllWidth, + cusparseHybPartition_t partitionType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, cusparseHybMat_t, int, + cusparseHybPartition_t); + static auto func_ptr = LoadSymbol("cusparseZcsc2hyb"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, hybA, userEllWidth, partitionType); +} + +cusparseStatus_t CUSPARSEAPI cusparseShyb2csc(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + float *cscSortedVal, + int *cscSortedRowInd, + int *cscSortedColPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + float *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseShyb2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, cscSortedVal, cscSortedRowInd, + cscSortedColPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseDhyb2csc(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + double *cscSortedVal, + int *cscSortedRowInd, + int *cscSortedColPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + double *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDhyb2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, cscSortedVal, cscSortedRowInd, + cscSortedColPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseChyb2csc(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + cuComplex *cscSortedVal, + int *cscSortedRowInd, + int *cscSortedColPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseChyb2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, cscSortedVal, cscSortedRowInd, + cscSortedColPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseZhyb2csc(cusparseHandle_t handle, + const cusparseMatDescr_t descrA, + const cusparseHybMat_t hybA, + cuDoubleComplex *cscSortedVal, + int *cscSortedRowInd, + int *cscSortedColPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const cusparseMatDescr_t, const cusparseHybMat_t, + cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseZhyb2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, descrA, hybA, cscSortedVal, cscSortedRowInd, + cscSortedColPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsr2bsrNnz( + cusparseHandle_t handle, cusparseDirection_t 
dirA, int m, int n, + const cusparseMatDescr_t descrA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, int blockDim, const cusparseMatDescr_t descrC, + int *bsrSortedRowPtrC, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const int *, const int *, int, const cusparseMatDescr_t, int *, int *); + static auto func_ptr = LoadSymbol("cusparseXcsr2bsrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedRowPtrC, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2bsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, float *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, const cusparseMatDescr_t, + float *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseScsr2bsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedValC, + bsrSortedRowPtrC, bsrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2bsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, double *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, const cusparseMatDescr_t, + double *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDcsr2bsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedValC, + bsrSortedRowPtrC, bsrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2bsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, cuComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, + const cusparseMatDescr_t, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseCcsr2bsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedValC, + bsrSortedRowPtrC, bsrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2bsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, 
int blockDim, + const cusparseMatDescr_t descrC, cuDoubleComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseZcsr2bsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedValC, + bsrSortedRowPtrC, bsrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, float *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, const cusparseMatDescr_t, + float *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseSbsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, blockDim, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, double *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, const cusparseMatDescr_t, + double *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDbsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, blockDim, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, cuComplex *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, + const cusparseMatDescr_t, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseCbsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, blockDim, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, 
int blockDim, + const cusparseMatDescr_t descrC, cuDoubleComplex *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseZbsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, blockDim, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsc_bufferSize( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const float *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const int *, const int *, + int, int, int *); + static auto func_ptr = LoadSymbol("cusparseSgebsr2gebsc_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsc_bufferSize( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const double *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const int *, const int *, + int, int, int *); + static auto func_ptr = LoadSymbol("cusparseDgebsr2gebsc_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsc_bufferSize( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, const int *, + const int *, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseCgebsr2gebsc_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsc_bufferSize( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, const int *, + const int *, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseZgebsr2gebsc_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsc_bufferSizeExt( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const float 
*bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const int *, const int *, + int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseSgebsr2gebsc_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsc_bufferSizeExt( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const double *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const int *, const int *, + int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseDgebsr2gebsc_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsc_bufferSizeExt( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, const int *, + const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseCgebsr2gebsc_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsc_bufferSizeExt( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, const int *, + const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseZgebsr2gebsc_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsc( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const float *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, float *bscVal, + int *bscRowInd, int *bscColPtr, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const int *, const int *, + int, int, float *, int *, int *, cusparseAction_t, cusparseIndexBase_t, + void *); + static auto func_ptr = LoadSymbol("cusparseSgebsr2gebsc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, bscVal, bscRowInd, + bscColPtr, copyValues, idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsc( + cusparseHandle_t handle, int mb, int 
nb, int nnzb, + const double *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + double *bscVal, int *bscRowInd, int *bscColPtr, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const int *, const int *, + int, int, double *, int *, int *, cusparseAction_t, cusparseIndexBase_t, + void *); + static auto func_ptr = LoadSymbol("cusparseDgebsr2gebsc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, bscVal, bscRowInd, + bscColPtr, copyValues, idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsc( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + cuComplex *bscVal, int *bscRowInd, int *bscColPtr, + cusparseAction_t copyValues, cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, const int *, + const int *, int, int, cuComplex *, int *, int *, cusparseAction_t, + cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol("cusparseCgebsr2gebsc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, bscVal, bscRowInd, + bscColPtr, copyValues, idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsc( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + cuDoubleComplex *bscVal, int *bscRowInd, int *bscColPtr, + cusparseAction_t copyValues, cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, const int *, + const int *, int, int, cuDoubleComplex *, int *, int *, cusparseAction_t, + cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol("cusparseZgebsr2gebsc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, bscVal, bscRowInd, + bscColPtr, copyValues, idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, int rowBlockDim, int colBlockDim, + const cusparseMatDescr_t descrC, int *csrSortedRowPtrC, + int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const int *, const int *, int, int, const cusparseMatDescr_t, int *, + int *); + static auto func_ptr = LoadSymbol("cusparseXgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int 
*bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDim, + int colBlockDim, const cusparseMatDescr_t descrC, float *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, int, + const cusparseMatDescr_t, float *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseSgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDim, + int colBlockDim, const cusparseMatDescr_t descrC, double *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, int, + const cusparseMatDescr_t, double *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseDgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDim, + int colBlockDim, const cusparseMatDescr_t descrC, cuComplex *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, int, + const cusparseMatDescr_t, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseCgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDim, + int colBlockDim, const cusparseMatDescr_t descrC, + cuDoubleComplex *csrSortedValC, int *csrSortedRowPtrC, + int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, int, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol("cusparseZgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + 
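+// All of the conversion wrappers in this file share the same lazy-binding
+// shape: declare the entry point's function-pointer type, resolve the real
+// cuSPARSE symbol once with LoadSymbol, return GetSymbolNotFoundError() when
+// the library or the symbol cannot be found, and otherwise forward the
+// arguments unchanged. A minimal sketch of that shape (cusparseFoo is a
+// placeholder name used only for illustration, not a real cuSPARSE entry
+// point; LoadSymbol is assumed to take the function-pointer type as its
+// template argument, as elsewhere in these stubs):
+//
+//   cusparseStatus_t CUSPARSEAPI cusparseFoo(cusparseHandle_t handle) {
+//     using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t);
+//     static auto func_ptr = LoadSymbol<FuncPtr>("cusparseFoo");
+//     if (!func_ptr) return GetSymbolNotFoundError();
+//     return func_ptr(handle);
+//   }
+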
+cusparseStatus_t CUSPARSEAPI cusparseScsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseScsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseDcsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseCcsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseZcsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const 
cusparseMatDescr_t, + const float *, const int *, const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseScsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseDcsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseCcsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseZcsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsr2gebsrNnz( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrC, + int *bsrSortedRowPtrC, int rowBlockDim, int colBlockDim, + int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const int *, const int *, const cusparseMatDescr_t, int *, int, int, + int *, void *); + static auto func_ptr = LoadSymbol("cusparseXcsr2gebsrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedRowPtrC, rowBlockDim, + colBlockDim, nnzTotalDevHostPtr, pBuffer); 
+} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrC, float *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDim, + int colBlockDim, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, const cusparseMatDescr_t, + float *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseScsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDim, colBlockDim, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrC, double *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDim, + int colBlockDim, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, const cusparseMatDescr_t, + double *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseDcsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDim, colBlockDim, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrC, cuComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDim, + int colBlockDim, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, const cusparseMatDescr_t, + cuComplex *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseCcsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDim, colBlockDim, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrC, cuDoubleComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDim, + int colBlockDim, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *, int, int, + void *); + 
static auto func_ptr = LoadSymbol("cusparseZcsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDim, colBlockDim, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + int, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseSgebsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + int, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseDgebsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, int, int, int, int *); + static auto func_ptr = LoadSymbol("cusparseCgebsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, int, 
int, int, int *); + static auto func_ptr = LoadSymbol("cusparseZgebsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + int, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseSgebsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + int, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseDgebsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, int, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseCgebsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const 
int *, int, int, int, int, size_t *); + static auto func_ptr = + LoadSymbol("cusparseZgebsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseXgebsr2gebsrNnz( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, int rowBlockDimA, int colBlockDimA, + const cusparseMatDescr_t descrC, int *bsrSortedRowPtrC, int rowBlockDimC, + int colBlockDimC, int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const int *, const int *, int, int, + const cusparseMatDescr_t, int *, int, int, int *, void *); + static auto func_ptr = LoadSymbol("cusparseXgebsr2gebsrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDimA, colBlockDimA, descrC, + bsrSortedRowPtrC, rowBlockDimC, colBlockDimC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, const cusparseMatDescr_t descrC, float *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDimC, + int colBlockDimC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + int, const cusparseMatDescr_t, float *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseSgebsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDimC, colBlockDimC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, const cusparseMatDescr_t descrC, double *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDimC, + int colBlockDimC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + int, const cusparseMatDescr_t, double *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseDgebsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDimC, colBlockDimC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, 
int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, const cusparseMatDescr_t descrC, cuComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDimC, + int colBlockDimC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, int, const cusparseMatDescr_t, cuComplex *, int *, int *, int, int, + void *); + static auto func_ptr = LoadSymbol("cusparseCgebsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDimC, colBlockDimC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, const cusparseMatDescr_t descrC, + cuDoubleComplex *bsrSortedValC, int *bsrSortedRowPtrC, + int *bsrSortedColIndC, int rowBlockDimC, int colBlockDimC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, int, const cusparseMatDescr_t, cuDoubleComplex *, int *, + int *, int, int, void *); + static auto func_ptr = LoadSymbol("cusparseZgebsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDimC, colBlockDimC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateIdentityPermutation(cusparseHandle_t handle, int n, int *p) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, int *); + static auto func_ptr = + LoadSymbol("cusparseCreateIdentityPermutation"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, p); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcoosort_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, const int *cooRowsA, + const int *cooColsA, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const int *, const int *, size_t *); + static auto func_ptr = LoadSymbol("cusparseXcoosort_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, cooRowsA, cooColsA, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcoosortByRow(cusparseHandle_t handle, + int m, int n, int nnz, + int *cooRowsA, int *cooColsA, + int *P, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int *, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseXcoosortByRow"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, cooRowsA, cooColsA, P, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcoosortByColumn(cusparseHandle_t handle, + int m, int n, int nnz, + int *cooRowsA, + int *cooColsA, int *P, + void *pBuffer) { + 
using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int *, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseXcoosortByColumn"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, cooRowsA, cooColsA, P, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrsort_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, const int *csrRowPtrA, + const int *csrColIndA, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const int *, const int *, size_t *); + static auto func_ptr = LoadSymbol("cusparseXcsrsort_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrRowPtrA, csrColIndA, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrsort(cusparseHandle_t handle, int m, + int n, int nnz, + const cusparseMatDescr_t descrA, + const int *csrRowPtrA, + int *csrColIndA, int *P, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const int *, + int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseXcsrsort"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrRowPtrA, csrColIndA, P, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcscsort_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, const int *cscColPtrA, + const int *cscRowIndA, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const int *, const int *, size_t *); + static auto func_ptr = LoadSymbol("cusparseXcscsort_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, cscColPtrA, cscRowIndA, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcscsort(cusparseHandle_t handle, int m, + int n, int nnz, + const cusparseMatDescr_t descrA, + const int *cscColPtrA, + int *cscRowIndA, int *P, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const int *, + int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseXcscsort"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, cscColPtrA, cscRowIndA, P, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsru2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, float *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, float *, const int *, int *, + csru2csrInfo_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseScsru2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsru2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, double *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, double *, const int *, int *, + csru2csrInfo_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseDcsru2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, 
n, nnz, csrVal, csrRowPtr, csrColInd, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsru2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, cuComplex *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, cuComplex *, const int *, int *, + csru2csrInfo_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCcsru2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsru2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, cuDoubleComplex *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, cuDoubleComplex *, const int *, int *, + csru2csrInfo_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseZcsru2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsru2csr( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, float *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, float *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsru2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsru2csr( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, double *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, double *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsru2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsru2csr( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, cuComplex *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsru2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsru2csr( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = 
LoadSymbol("cusparseZcsru2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2csru( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, float *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, float *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol("cusparseScsr2csru"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2csru( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, double *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, double *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol("cusparseDcsr2csru"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2csru( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, cuComplex *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol("cusparseCcsr2csru"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2csru( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol("cusparseZcsr2csru"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + const float *threshold, const cusparseMatDescr_t descrC, + const float *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + size_t *); + static auto func_ptr = + LoadSymbol("cusparseSpruneDense2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + const double 
*threshold, const cusparseMatDescr_t descrC, + const double *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + size_t *); + static auto func_ptr = + LoadSymbol("cusparseDpruneDense2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csrNnz( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + const float *threshold, const cusparseMatDescr_t descrC, int *csrRowPtrC, + int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, const float *, + const cusparseMatDescr_t, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseSpruneDense2csrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrRowPtrC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csrNnz( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + const double *threshold, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, const double *, + const cusparseMatDescr_t, int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseDpruneDense2csrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csr( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + const float *threshold, const cusparseMatDescr_t descrC, + float *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, const float *, + const cusparseMatDescr_t, float *, const int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseSpruneDense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csr( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + const double *threshold, const cusparseMatDescr_t descrC, + double *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, const double *, + const cusparseMatDescr_t, double *, const int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseDpruneDense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int 
*csrSortedColIndA, + const float *threshold, const cusparseMatDescr_t descrC, + const float *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, const float *, const cusparseMatDescr_t, + const float *, const int *, const int *, size_t *); + static auto func_ptr = + LoadSymbol("cusparseSpruneCsr2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *threshold, const cusparseMatDescr_t descrC, + const double *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, const cusparseMatDescr_t, + const double *, const int *, const int *, size_t *); + static auto func_ptr = + LoadSymbol("cusparseDpruneCsr2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csrNnz( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const float *threshold, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, const float *, const cusparseMatDescr_t, int *, + int *, void *); + static auto func_ptr = LoadSymbol("cusparseSpruneCsr2csrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csrNnz( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *threshold, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, const cusparseMatDescr_t, int *, + int *, void *); + static auto func_ptr = LoadSymbol("cusparseDpruneCsr2csrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI 
cusparseSpruneCsr2csr( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const float *threshold, const cusparseMatDescr_t descrC, + float *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, const float *, const cusparseMatDescr_t, + float *, const int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseSpruneCsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csr( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *threshold, const cusparseMatDescr_t descrC, + double *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, const cusparseMatDescr_t, + double *, const int *, int *, void *); + static auto func_ptr = LoadSymbol("cusparseDpruneCsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csrByPercentage_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + float percentage, const cusparseMatDescr_t descrC, + const float *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, pruneInfo_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, float, + const cusparseMatDescr_t, const float *, const int *, const int *, + pruneInfo_t, size_t *); + static auto func_ptr = + LoadSymbol("cusparseSpruneDense2csrByPercentage_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csrByPercentage_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + float percentage, const cusparseMatDescr_t descrC, + const double *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, pruneInfo_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, float, + const cusparseMatDescr_t, const double *, const int *, const int *, + pruneInfo_t, size_t *); + static auto func_ptr = + LoadSymbol("cusparseDpruneDense2csrByPercentage_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csrNnzByPercentage( + cusparseHandle_t 
handle, int m, int n, const float *A, int lda, + float percentage, const cusparseMatDescr_t descrC, int *csrRowPtrC, + int *nnzTotalDevHostPtr, pruneInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, float, + const cusparseMatDescr_t, int *, int *, pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol("cusparseSpruneDense2csrNnzByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csrNnzByPercentage( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + float percentage, const cusparseMatDescr_t descrC, int *csrRowPtrC, + int *nnzTotalDevHostPtr, pruneInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, float, + const cusparseMatDescr_t, int *, int *, pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol("cusparseDpruneDense2csrNnzByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csrByPercentage( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + float percentage, const cusparseMatDescr_t descrC, float *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, pruneInfo_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, float, + const cusparseMatDescr_t, float *, const int *, int *, pruneInfo_t, + void *); + static auto func_ptr = + LoadSymbol("cusparseSpruneDense2csrByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csrByPercentage( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + float percentage, const cusparseMatDescr_t descrC, double *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, pruneInfo_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, float, + const cusparseMatDescr_t, double *, const int *, int *, pruneInfo_t, + void *); + static auto func_ptr = + LoadSymbol("cusparseDpruneDense2csrByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csrByPercentage_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, const float *csrSortedValC, + const int *csrSortedRowPtrC, const int *csrSortedColIndC, pruneInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float, const cusparseMatDescr_t, const float *, + const int *, const int *, pruneInfo_t, size_t *); + static auto func_ptr = + 
LoadSymbol("cusparseSpruneCsr2csrByPercentage_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csrByPercentage_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, const double *csrSortedValC, + const int *csrSortedRowPtrC, const int *csrSortedColIndC, pruneInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, float, const cusparseMatDescr_t, const double *, + const int *, const int *, pruneInfo_t, size_t *); + static auto func_ptr = + LoadSymbol("cusparseDpruneCsr2csrByPercentage_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csrNnzByPercentage( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, int *csrSortedRowPtrC, + int *nnzTotalDevHostPtr, pruneInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float, const cusparseMatDescr_t, int *, int *, + pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol("cusparseSpruneCsr2csrNnzByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csrNnzByPercentage( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, int *csrSortedRowPtrC, + int *nnzTotalDevHostPtr, pruneInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, float, const cusparseMatDescr_t, int *, int *, + pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol("cusparseDpruneCsr2csrNnzByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csrByPercentage( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, float *csrSortedValC, + const int *csrSortedRowPtrC, int 
*csrSortedColIndC, pruneInfo_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float, const cusparseMatDescr_t, float *, + const int *, int *, pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol("cusparseSpruneCsr2csrByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csrByPercentage( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, double *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, pruneInfo_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, float, const cusparseMatDescr_t, double *, + const int *, int *, pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol("cusparseDpruneCsr2csrByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsr2cscEx2( + cusparseHandle_t handle, int m, int n, int nnz, const void *csrVal, + const int *csrRowPtr, const int *csrColInd, void *cscVal, int *cscColPtr, + int *cscRowInd, cudaDataType valType, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase, cusparseCsr2CscAlg_t alg, void *buffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const void *, const int *, const int *, + void *, int *, int *, cudaDataType, cusparseAction_t, cusparseIndexBase_t, + cusparseCsr2CscAlg_t, void *); + static auto func_ptr = LoadSymbol("cusparseCsr2cscEx2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal, + cscColPtr, cscRowInd, valType, copyValues, idxBase, alg, + buffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsr2cscEx2_bufferSize( + cusparseHandle_t handle, int m, int n, int nnz, const void *csrVal, + const int *csrRowPtr, const int *csrColInd, void *cscVal, int *cscColPtr, + int *cscRowInd, cudaDataType valType, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase, cusparseCsr2CscAlg_t alg, size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const void *, const int *, const int *, + void *, int *, int *, cudaDataType, cusparseAction_t, cusparseIndexBase_t, + cusparseCsr2CscAlg_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseCsr2cscEx2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal, + cscColPtr, cscRowInd, valType, copyValues, idxBase, alg, + bufferSize); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateSpVec(cusparseSpVecDescr_t *spVecDescr, int64_t size, int64_t nnz, + void *indices, void *values, cusparseIndexType_t idxType, + cusparseIndexBase_t idxBase, cudaDataType valueType) { + using FuncPtr = 
cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpVecDescr_t *, int64_t, int64_t, void *, void *, + cusparseIndexType_t, cusparseIndexBase_t, cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCreateSpVec"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, size, nnz, indices, values, idxType, idxBase, + valueType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroySpVec(cusparseSpVecDescr_t spVecDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpVecDescr_t); + static auto func_ptr = LoadSymbol("cusparseDestroySpVec"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpVecGet( + const cusparseSpVecDescr_t spVecDescr, int64_t *size, int64_t *nnz, + void **indices, void **values, cusparseIndexType_t *idxType, + cusparseIndexBase_t *idxBase, cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + const cusparseSpVecDescr_t, int64_t *, int64_t *, void **, void **, + cusparseIndexType_t *, cusparseIndexBase_t *, cudaDataType *); + static auto func_ptr = LoadSymbol("cusparseSpVecGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, size, nnz, indices, values, idxType, idxBase, + valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpVecGetIndexBase( + const cusparseSpVecDescr_t spVecDescr, cusparseIndexBase_t *idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(const cusparseSpVecDescr_t, + cusparseIndexBase_t *); + static auto func_ptr = LoadSymbol("cusparseSpVecGetIndexBase"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpVecGetValues(const cusparseSpVecDescr_t spVecDescr, void **values) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(const cusparseSpVecDescr_t, void **); + static auto func_ptr = LoadSymbol("cusparseSpVecGetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpVecSetValues(cusparseSpVecDescr_t spVecDescr, void *values) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpVecDescr_t, void *); + static auto func_ptr = LoadSymbol("cusparseSpVecSetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateDnVec(cusparseDnVecDescr_t *dnVecDescr, int64_t size, + void *values, cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseDnVecDescr_t *, int64_t, void *, cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCreateDnVec"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr, size, values, valueType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroyDnVec(cusparseDnVecDescr_t dnVecDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseDnVecDescr_t); + static auto func_ptr = LoadSymbol("cusparseDestroyDnVec"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnVecGet(const cusparseDnVecDescr_t dnVecDescr, int64_t *size, + void **values, cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + const cusparseDnVecDescr_t, int64_t *, void **, cudaDataType *); + static auto func_ptr = LoadSymbol("cusparseDnVecGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr, size, values, valueType); +} + 
+cusparseStatus_t CUSPARSEAPI +cusparseDnVecGetValues(const cusparseDnVecDescr_t dnVecDescr, void **values) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(const cusparseDnVecDescr_t, void **); + static auto func_ptr = LoadSymbol("cusparseDnVecGetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnVecSetValues(cusparseDnVecDescr_t dnVecDescr, void *values) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseDnVecDescr_t, void *); + static auto func_ptr = LoadSymbol("cusparseDnVecSetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr, values); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCoo(cusparseSpMatDescr_t *spMatDescr, + int64_t rows, int64_t cols, + int64_t nnz, void *cooRowInd, + void *cooColInd, void *cooValues, + cusparseIndexType_t cooIdxType, + cusparseIndexBase_t idxBase, + cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t *, int64_t, int64_t, int64_t, void *, void *, void *, + cusparseIndexType_t, cusparseIndexBase_t, cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCreateCoo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, cooRowInd, cooColInd, cooValues, + cooIdxType, idxBase, valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsr( + cusparseSpMatDescr_t *spMatDescr, int64_t rows, int64_t cols, int64_t nnz, + void *csrRowOffsets, void *csrColInd, void *csrValues, + cusparseIndexType_t csrRowOffsetsType, cusparseIndexType_t csrColIndType, + cusparseIndexBase_t idxBase, cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t *, int64_t, int64_t, int64_t, void *, void *, void *, + cusparseIndexType_t, cusparseIndexType_t, cusparseIndexBase_t, + cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCreateCsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, csrRowOffsets, csrColInd, + csrValues, csrRowOffsetsType, csrColIndType, idxBase, + valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCooAoS( + cusparseSpMatDescr_t *spMatDescr, int64_t rows, int64_t cols, int64_t nnz, + void *cooInd, void *cooValues, cusparseIndexType_t cooIdxType, + cusparseIndexBase_t idxBase, cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t *, int64_t, int64_t, int64_t, void *, void *, + cusparseIndexType_t, cusparseIndexBase_t, cudaDataType); + static auto func_ptr = LoadSymbol("cusparseCreateCooAoS"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, cooInd, cooValues, cooIdxType, + idxBase, valueType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroySpMat(cusparseSpMatDescr_t spMatDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t); + static auto func_ptr = LoadSymbol("cusparseDestroySpMat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCooGet(const cusparseSpMatDescr_t spMatDescr, int64_t *rows, + int64_t *cols, int64_t *nnz, + void **cooRowInd, // COO row indices + void **cooColInd, // COO column indices + void **cooValues, // COO values + cusparseIndexType_t *idxType, cusparseIndexBase_t *idxBase, + cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + const cusparseSpMatDescr_t, int64_t *, 
int64_t *, int64_t *, void **, + void **, void **, cusparseIndexType_t *, cusparseIndexBase_t *, + cudaDataType *); + static auto func_ptr = LoadSymbol("cusparseCooGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, cooRowInd, cooColInd, cooValues, + idxType, idxBase, valueType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCooAoSGet(const cusparseSpMatDescr_t spMatDescr, int64_t *rows, + int64_t *cols, int64_t *nnz, + void **cooInd, // COO indices + void **cooValues, // COO values + cusparseIndexType_t *idxType, cusparseIndexBase_t *idxBase, + cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + const cusparseSpMatDescr_t, int64_t *, int64_t *, int64_t *, void **, + void **, cusparseIndexType_t *, cusparseIndexBase_t *, cudaDataType *); + static auto func_ptr = LoadSymbol("cusparseCooAoSGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, cooInd, cooValues, idxType, + idxBase, valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrGet( + const cusparseSpMatDescr_t spMatDescr, int64_t *rows, int64_t *cols, + int64_t *nnz, void **csrRowOffsets, void **csrColInd, void **csrValues, + cusparseIndexType_t *csrRowOffsetsType, cusparseIndexType_t *csrColIndType, + cusparseIndexBase_t *idxBase, cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + const cusparseSpMatDescr_t, int64_t *, int64_t *, int64_t *, void **, + void **, void **, cusparseIndexType_t *, cusparseIndexType_t *, + cusparseIndexBase_t *, cudaDataType *); + static auto func_ptr = LoadSymbol("cusparseCsrGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, csrRowOffsets, csrColInd, + csrValues, csrRowOffsetsType, csrColIndType, idxBase, + valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMatGetFormat( + const cusparseSpMatDescr_t spMatDescr, cusparseFormat_t *format) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(const cusparseSpMatDescr_t, + cusparseFormat_t *); + static auto func_ptr = LoadSymbol("cusparseSpMatGetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, format); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMatGetIndexBase( + const cusparseSpMatDescr_t spMatDescr, cusparseIndexBase_t *idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(const cusparseSpMatDescr_t, + cusparseIndexBase_t *); + static auto func_ptr = LoadSymbol("cusparseSpMatGetIndexBase"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpMatGetValues(const cusparseSpMatDescr_t spMatDescr, void **values) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(const cusparseSpMatDescr_t, void **); + static auto func_ptr = LoadSymbol("cusparseSpMatGetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpMatSetValues(cusparseSpMatDescr_t spMatDescr, void *values) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, void *); + static auto func_ptr = LoadSymbol("cusparseSpMatSetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpMatSetStridedBatch(cusparseSpMatDescr_t spMatDescr, int batchCount) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, int); + static auto func_ptr = 
LoadSymbol("cusparseSpMatSetStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, batchCount); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMatGetStridedBatch( + const cusparseSpMatDescr_t spMatDescr, int *batchCount) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(const cusparseSpMatDescr_t, int *); + static auto func_ptr = LoadSymbol("cusparseSpMatGetStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, batchCount); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateDnMat( + cusparseDnMatDescr_t *dnMatDescr, int64_t rows, int64_t cols, int64_t ld, + void *values, cudaDataType valueType, cusparseOrder_t order) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseDnMatDescr_t *, int64_t, int64_t, int64_t, void *, cudaDataType, + cusparseOrder_t); + static auto func_ptr = LoadSymbol("cusparseCreateDnMat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, rows, cols, ld, values, valueType, order); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroyDnMat(cusparseDnMatDescr_t dnMatDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseDnMatDescr_t); + static auto func_ptr = LoadSymbol("cusparseDestroyDnMat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr); +} + +cusparseStatus_t CUSPARSEAPI cusparseDnMatGet( + const cusparseDnMatDescr_t dnMatDescr, int64_t *rows, int64_t *cols, + int64_t *ld, void **values, cudaDataType *type, cusparseOrder_t *order) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + const cusparseDnMatDescr_t, int64_t *, int64_t *, int64_t *, void **, + cudaDataType *, cusparseOrder_t *); + static auto func_ptr = LoadSymbol("cusparseDnMatGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, rows, cols, ld, values, type, order); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnMatGetValues(const cusparseDnMatDescr_t dnMatDescr, void **values) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(const cusparseDnMatDescr_t, void **); + static auto func_ptr = LoadSymbol("cusparseDnMatGetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnMatSetValues(cusparseDnMatDescr_t dnMatDescr, void *values) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseDnMatDescr_t, void *); + static auto func_ptr = LoadSymbol("cusparseDnMatSetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, values); +} + +cusparseStatus_t CUSPARSEAPI cusparseDnMatSetStridedBatch( + cusparseDnMatDescr_t dnMatDescr, int batchCount, int64_t batchStride) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseDnMatDescr_t, int, int64_t); + static auto func_ptr = LoadSymbol("cusparseDnMatSetStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, batchCount, batchStride); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnMatGetStridedBatch(const cusparseDnMatDescr_t dnMatDescr, + int *batchCount, int64_t *batchStride) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(const cusparseDnMatDescr_t, + int *, int64_t *); + static auto func_ptr = LoadSymbol("cusparseDnMatGetStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, batchCount, batchStride); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpVV(cusparseHandle_t handle, cusparseOperation_t opX, + const cusparseSpVecDescr_t vecX, const 
cusparseDnVecDescr_t vecY, + void *result, cudaDataType computeType, void *externalBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cusparseSpVecDescr_t, + const cusparseDnVecDescr_t, void *, cudaDataType, void *); + static auto func_ptr = LoadSymbol("cusparseSpVV"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opX, vecX, vecY, result, computeType, externalBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpVV_bufferSize( + cusparseHandle_t handle, cusparseOperation_t opX, + const cusparseSpVecDescr_t vecX, const cusparseDnVecDescr_t vecY, + const void *result, cudaDataType computeType, size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const cusparseSpVecDescr_t, + const cusparseDnVecDescr_t, const void *, cudaDataType, size_t *); + static auto func_ptr = LoadSymbol("cusparseSpVV_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opX, vecX, vecY, result, computeType, bufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMV( + cusparseHandle_t handle, cusparseOperation_t opA, const void *alpha, + const cusparseSpMatDescr_t matA, const cusparseDnVecDescr_t vecX, + const void *beta, const cusparseDnVecDescr_t vecY, cudaDataType computeType, + cusparseSpMVAlg_t alg, void *externalBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const void *, + const cusparseSpMatDescr_t, const cusparseDnVecDescr_t, const void *, + const cusparseDnVecDescr_t, cudaDataType, cusparseSpMVAlg_t, void *); + static auto func_ptr = LoadSymbol("cusparseSpMV"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, alpha, matA, vecX, beta, vecY, computeType, alg, + externalBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMV_bufferSize( + cusparseHandle_t handle, cusparseOperation_t opA, const void *alpha, + const cusparseSpMatDescr_t matA, const cusparseDnVecDescr_t vecX, + const void *beta, const cusparseDnVecDescr_t vecY, cudaDataType computeType, + cusparseSpMVAlg_t alg, size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const void *, + const cusparseSpMatDescr_t, const cusparseDnVecDescr_t, const void *, + const cusparseDnVecDescr_t, cudaDataType, cusparseSpMVAlg_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseSpMV_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, alpha, matA, vecX, beta, vecY, computeType, alg, + bufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMM( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, const cusparseSpMatDescr_t matA, + const cusparseDnMatDescr_t matB, const void *beta, + cusparseDnMatDescr_t matC, cudaDataType computeType, cusparseSpMMAlg_t alg, + void *externalBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + const cusparseSpMatDescr_t, const cusparseDnMatDescr_t, const void *, + cusparseDnMatDescr_t, cudaDataType, cusparseSpMMAlg_t, void *); + static auto func_ptr = LoadSymbol("cusparseSpMM"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + alg, externalBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMM_bufferSize( + cusparseHandle_t handle, 
cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, const cusparseSpMatDescr_t matA, + const cusparseDnMatDescr_t matB, const void *beta, + cusparseDnMatDescr_t matC, cudaDataType computeType, cusparseSpMMAlg_t alg, + size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + const cusparseSpMatDescr_t, const cusparseDnMatDescr_t, const void *, + cusparseDnMatDescr_t, cudaDataType, cusparseSpMMAlg_t, size_t *); + static auto func_ptr = LoadSymbol("cusparseSpMM_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + alg, bufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseConstrainedGeMM( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, const void *beta, + cusparseSpMatDescr_t matC, cudaDataType computeType, void *externalBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + const cusparseDnMatDescr_t, const cusparseDnMatDescr_t, const void *, + cusparseSpMatDescr_t, cudaDataType, void *); + static auto func_ptr = LoadSymbol("cusparseConstrainedGeMM"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + externalBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseConstrainedGeMM_bufferSize( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, const void *beta, + cusparseSpMatDescr_t matC, cudaDataType computeType, size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + const cusparseDnMatDescr_t, const cusparseDnMatDescr_t, const void *, + cusparseSpMatDescr_t, cudaDataType, size_t *); + static auto func_ptr = + LoadSymbol("cusparseConstrainedGeMM_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + bufferSize); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cusparse_9_0.inc b/tensorflow/stream_executor/cuda/cusparse_9_0.inc index bb82f3ebb46..bf3037257d8 100644 --- a/tensorflow/stream_executor/cuda/cusparse_9_0.inc +++ b/tensorflow/stream_executor/cuda/cusparse_9_0.inc @@ -4887,7 +4887,7 @@ cusparseStatus_t CUSPARSEAPI cusparseDcsr2csr_compress( int m, // number of rows int n, const cusparseMatDescr_t descra, const double *csrValA, // csr values array-the elements which are below a - // certain tolerance will be removed + // certain tolerance will be remvoed const int *csrColIndA, const int *csrRowPtrA, // corresponding input noncompressed row pointer int nnzA, const int *nnzPerRow, double *csrValC, int *csrColIndC, @@ -4907,7 +4907,7 @@ cusparseStatus_t CUSPARSEAPI cusparseCcsr2csr_compress( int m, // number of rows int n, const cusparseMatDescr_t descra, const cuComplex *csrValA, // csr values array-the elements which are below - // a certain tolerance will be removed + // a certain tolerance will be remvoed const int *csrColIndA, const int *csrRowPtrA, // corresponding input noncompressed row pointer int nnzA, const int *nnzPerRow, cuComplex *csrValC, int *csrColIndC, @@ -4926,8 +4926,9 @@ 
cusparseStatus_t CUSPARSEAPI cusparseZcsr2csr_compress( cusparseHandle_t handle, int m, // number of rows int n, const cusparseMatDescr_t descra, - const cuDoubleComplex *csrValA, // csr values array-the elements which are - // below a certain tolerance will be removed + const cuDoubleComplex + *csrValA, // csr values array-the elements which are + // below a certain tolerance will be remvoed const int *csrColIndA, const int *csrRowPtrA, // corresponding input noncompressed row pointer int nnzA, const int *nnzPerRow, cuDoubleComplex *csrValC, int *csrColIndC, diff --git a/tensorflow/stream_executor/cuda/cusparse_stub.cc b/tensorflow/stream_executor/cuda/cusparse_stub.cc index 4b941bc1751..b2f76fe6d5c 100644 --- a/tensorflow/stream_executor/cuda/cusparse_stub.cc +++ b/tensorflow/stream_executor/cuda/cusparse_stub.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cusparse.h" #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/platform/dso_loader.h" @@ -52,8 +53,14 @@ cusparseStatus_t GetSymbolNotFoundError() { #if CUDA_VERSION < 9020 #include "tensorflow/stream_executor/cuda/cusparse_9_0.inc" -#elif CUDA_VERSION < 10010 +#elif CUDA_VERSION == 10000 #include "tensorflow/stream_executor/cuda/cusparse_10_0.inc" -#else +#elif CUDA_VERSION == 10010 #include "tensorflow/stream_executor/cuda/cusparse_10_1.inc" +#elif CUDA_VERSION == 10020 +#include "tensorflow/stream_executor/cuda/cusparse_10_2.inc" +#elif CUDA_VERSION == 11000 +#include "tensorflow/stream_executor/cuda/cusparse_11_0.inc" +#else +#error "We don't have a wrapper for this version." #endif From 4c0d6b7d516b659294798cd2903dc5164cb5fd2c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 13:47:56 -0700 Subject: [PATCH 256/492] Changed TPU embedding load and retrieve ops to checked-in generated code. 
PiperOrigin-RevId: 301887553 Change-Id: Ib6e042e73cd4a0214239175a4e86b090a0817f12 --- tensorflow/cc/BUILD | 1 + tensorflow/core/BUILD | 3 + ...dientDescentParametersGradAccumDebug.pbtxt | 24 + ...dientDescentParametersGradAccumDebug.pbtxt | 23 + .../ops/tpu_embedding_load_retrieve_ops.cc | 575 ++++++++++++++++++ tensorflow/core/ops/tpu_embedding_ops.cc | 63 -- ...embedding_optimization_parameters_utils.cc | 302 +++------ ..._embedding_optimization_parameters_utils.h | 49 +- tensorflow/python/BUILD | 1 + tensorflow/python/__init__.py | 1 + .../tools/api/golden/v1/tensorflow.pbtxt | 4 - .../api/golden/v1/tensorflow.raw_ops.pbtxt | 8 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 8 + 13 files changed, 762 insertions(+), 300 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt create mode 100644 tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 5251ccdf1c0..022989bfbf2 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -632,6 +632,7 @@ tf_gen_op_wrappers_cc( "tpu_configuration_ops", "tpu_cross_replica_ops", "tpu_embedding_ops", + "tpu_embedding_load_retrieve_ops", "tpu_functional_ops", "tpu_heartbeat_ops", "tpu_host_compute_ops", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index df502b675b0..7b995af7656 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -723,6 +723,7 @@ tf_gen_op_libs( "tpu_configuration_ops", "tpu_cross_replica_ops", "tpu_embedding_ops", + "tpu_embedding_load_retrieve_ops", "tpu_functional_ops", "tpu_heartbeat_ops", "tpu_host_compute_ops", @@ -735,6 +736,7 @@ tf_gen_op_libs( ":lib", ":lib_proto_parsing", ":protos_all_cc", + "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc", "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", @@ -894,6 +896,7 @@ cc_library( ":tpu_configuration_ops_op_lib", ":tpu_cross_replica_ops_op_lib", ":tpu_embedding_ops_op_lib", + ":tpu_embedding_load_retrieve_ops_op_lib", ":tpu_functional_ops_op_lib", ":tpu_heartbeat_ops_op_lib", ":tpu_host_compute_ops_op_lib", diff --git a/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt b/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt new file mode 100644 index 00000000000..1cd84cff202 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt @@ -0,0 +1,24 @@ +op { + graph_op_name: "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug" + visibility: HIDDEN + in_arg { + name: "parameters" + description: <

-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterAdd(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterAdd", - Input: []tf.Input{ - resource, indices, updates, - }, - } - return scope.AddOperation(opspec) -} - // ConfigureDistributedTPUAttr is an optional argument to ConfigureDistributedTPU. type ConfigureDistributedTPUAttr func(optionalAttr) @@ -28087,6 +27827,70 @@ func FusedBatchNormGradV3(scope *Scope, y_backprop tf.Output, x tf.Output, scale return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } +// Returns the number of records this Reader has produced. +// +// This is the same as the number of ReaderRead executions that have +// succeeded. +// +// Arguments: +// reader_handle: Handle to a Reader. +func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_produced tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderNumRecordsProducedV2", + Input: []tf.Input{ + reader_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeRawAttr is an optional argument to DecodeRaw. +type DecodeRawAttr func(optionalAttr) + +// DecodeRawLittleEndian sets the optional little_endian attribute to value. +// +// value: Whether the input `bytes` are in little-endian order. +// Ignored for `out_type` values that are stored in a single byte like +// `uint8`. +// If not specified, defaults to true +func DecodeRawLittleEndian(value bool) DecodeRawAttr { + return func(m optionalAttr) { + m["little_endian"] = value + } +} + +// Reinterpret the bytes of a string as a vector of numbers. +// +// Arguments: +// bytes: All the elements must have the same length. +// +// +// Returns A Tensor with one more dimension than the input `bytes`. The +// added dimension will have size equal to the length of the elements +// of `bytes` divided by the number of bytes to represent `out_type`. +func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeRaw", + Input: []tf.Input{ + bytes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Gather ragged slices from `params` axis `0` according to `indices`. // // Outputs a `RaggedTensor` output composed from `output_dense_values` and @@ -30280,6 +30084,143 @@ func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } +// Returns the cardinality of `input_dataset`. +// +// Returns the cardinality of `input_dataset`. +// +// Arguments: +// input_dataset: A variant tensor representing the dataset to return cardinality for. +// +// Returns The cardinality of `input_dataset`. Named constants are used to represent +// infinite and unknown cardinality. 
+func DatasetCardinality(scope *Scope, input_dataset tf.Output) (cardinality tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DatasetCardinality", + Input: []tf.Input{ + input_dataset, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that emits each dim-0 slice of `components` once. +func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "TensorSliceDataset", + Input: []tf.Input{ + tf.OutputList(components), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to RetrieveTPUEmbeddingMDLAdagradLightParameters. +type RetrieveTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func RetrieveTPUEmbeddingMDLAdagradLightParametersTableId(value int64) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingMDLAdagradLightParametersTableName(value string) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingMDLAdagradLightParametersConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingMDLAdagradLightParametersConfig(value string) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve MDL Adagrad Light embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns: +// parameters: Parameter parameters updated by the MDL Adagrad Light optimization algorithm. +// accumulators: Parameter accumulators updated by the MDL Adagrad Light optimization algorithm. +// weights: Parameter weights updated by the MDL Adagrad Light optimization algorithm. +// benefits: Parameter benefits updated by the MDL Adagrad Light optimization algorithm. +func RetrieveTPUEmbeddingMDLAdagradLightParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMDLAdagradLightParametersAttr) (parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingMDLAdagradLightParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// Adds sparse updates to the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] += updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] 
+= updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions add. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterAdd(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterAdd", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + // This op consumes a lock created by `MutexLock`. // // This op exists to consume a tensor created by `MutexLock` (other than @@ -31866,8 +31807,6 @@ type RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) // RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -33023,8 +32962,6 @@ type LoadTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) // LoadTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingCenteredRMSPropParametersTableId(value int64) LoadTPUEmbeddingCenteredRMSPropParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -34124,121 +34061,6 @@ func SparseSliceGrad(scope *Scope, backprop_val_grad tf.Output, input_indices tf return op.Output(0) } -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. -// -// More formally, let -// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, -// -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ -// -// Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. -// -// Returns Computed Precision at `k` as a `bool Tensor`. -func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (precision tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"k": k} - opspec := tf.OpSpec{ - Type: "InTopK", - Input: []tf.Input{ - predictions, targets, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns x - y element-wise. -// -// *NOTE*: `Subtract` supports broadcasting. 
More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sub", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D. -type FusedResizeAndPadConv2DAttr func(optionalAttr) - -// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr { - return func(m optionalAttr) { - m["resize_align_corners"] = value - } -} - -// Performs a resize and padding as a preprocess during a convolution. -// -// It's often possible to do spatial transformations more efficiently as part of -// the packing stage of a convolution, so this op allows for an optimized -// implementation where these stages are fused together. This prevents the need to -// write out the intermediate results as whole tensors, reducing memory pressure, -// and we can get some latency gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and defaults to -// 'NHWC' order. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. -func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedResizeAndPadConv2D", - Input: []tf.Input{ - input, size, paddings, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns the truth value of x OR y element-wise. // // *NOTE*: `LogicalOr` supports broadcasting. More about broadcasting @@ -34604,8 +34426,6 @@ type RetrieveTPUEmbeddingProximalAdagradParametersAttr func(optionalAttr) // RetrieveTPUEmbeddingProximalAdagradParametersTableId sets the optional table_id attribute to value. 
// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func RetrieveTPUEmbeddingProximalAdagradParametersTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -36735,6 +36555,121 @@ func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Outpu return op.Output(0) } +// Says whether the targets are in the top `K` predictions. +// +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. +// +// More formally, let +// +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// +// Arguments: +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. +// +// Returns Computed Precision at `k` as a `bool Tensor`. +func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (precision tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"k": k} + opspec := tf.OpSpec{ + Type: "InTopK", + Input: []tf.Input{ + predictions, targets, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x - y element-wise. +// +// *NOTE*: `Subtract` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sub", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D. +type FusedResizeAndPadConv2DAttr func(optionalAttr) + +// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr { + return func(m optionalAttr) { + m["resize_align_corners"] = value + } +} + +// Performs a resize and padding as a preprocess during a convolution. +// +// It's often possible to do spatial transformations more efficiently as part of +// the packing stage of a convolution, so this op allows for an optimized +// implementation where these stages are fused together. This prevents the need to +// write out the intermediate results as whole tensors, reducing memory pressure, +// and we can get some latency gains by merging the transformation calculations. +// The data_format attribute for Conv2D isn't supported by this op, and defaults to +// 'NHWC' order. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. 
This is because this +// operator is primarily an optimization to minimize memory usage. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. +func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedResizeAndPadConv2D", + Input: []tf.Input{ + input, size, paddings, filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the product along segments of a tensor. // // Read @@ -37531,112 +37466,6 @@ func IFFT(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// Returns the next record (key, value pair) produced by a Reader. -// -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). -// -// Arguments: -// reader_handle: Handle to a Reader. -// queue_handle: Handle to a Queue, with string work items. -// -// Returns: -// key: A scalar. -// value: A scalar. -func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderReadV2", - Input: []tf.Input{ - reader_handle, queue_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// CumprodAttr is an optional argument to Cumprod. -type CumprodAttr func(optionalAttr) - -// CumprodExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumprod. -// If not specified, defaults to false -func CumprodExclusive(value bool) CumprodAttr { - return func(m optionalAttr) { - m["exclusive"] = value - } -} - -// CumprodReverse sets the optional reverse attribute to value. -// -// value: A `bool` (default: False). -// If not specified, defaults to false -func CumprodReverse(value bool) CumprodAttr { - return func(m optionalAttr) { - m["reverse"] = value - } -} - -// Compute the cumulative product of the tensor `x` along `axis`. 
-// -// By default, this op performs an inclusive cumprod, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is -// performed instead: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] -// ``` -// -// By setting the `reverse` kwarg to `True`, the cumprod is performed in the -// opposite direction: -// -// ```python -// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] -// ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] -// ``` -// -// Arguments: -// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Cumprod", - Input: []tf.Input{ - x, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // CollectiveGatherAttr is an optional argument to CollectiveGather. type CollectiveGatherAttr func(optionalAttr) @@ -38226,8 +38055,6 @@ type RetrieveTPUEmbeddingADAMParametersAttr func(optionalAttr) // RetrieveTPUEmbeddingADAMParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func RetrieveTPUEmbeddingADAMParametersTableId(value int64) RetrieveTPUEmbeddingADAMParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -38330,8 +38157,6 @@ type RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr func(optionalAttr) // RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -38890,8 +38715,6 @@ type RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) // RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -39435,8 +39258,6 @@ type LoadTPUEmbeddingMomentumParametersAttr func(optionalAttr) // LoadTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. 
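// ---------------------------------------------------------------------------
// Editor's note: illustrative sketch, not part of the generated wrappers or of
// this patch. It mirrors the Python examples in the Cumprod doc comment above
// using the Go bindings; assumes import
//   "github.com/tensorflow/tensorflow/tensorflow/go/op"
func ExampleCumprod() {
	s := op.NewScope()
	x := op.Const(s, []float32{2, 3, 4})
	axis := op.Const(s, int32(0))
	inclusive := op.Cumprod(s, x, axis)                            // => [2 6 24]
	exclusive := op.Cumprod(s, x, axis, op.CumprodExclusive(true)) // => [1 2 6]
	reversed := op.Cumprod(s, x, axis, op.CumprodReverse(true))    // => [24 12 4]
	_, _, _ = inclusive, exclusive, reversed
}
// ---------------------------------------------------------------------------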
// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingMomentumParametersTableId(value int64) LoadTPUEmbeddingMomentumParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -39654,8 +39475,6 @@ type RetrieveTPUEmbeddingAdadeltaParametersAttr func(optionalAttr) // RetrieveTPUEmbeddingAdadeltaParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func RetrieveTPUEmbeddingAdadeltaParametersTableId(value int64) RetrieveTPUEmbeddingAdadeltaParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -40157,8 +39976,6 @@ type RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) // RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -40384,6 +40201,159 @@ func UnicodeScript(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// CropAndResizeAttr is an optional argument to CropAndResize. +type CropAndResizeAttr func(optionalAttr) + +// CropAndResizeMethod sets the optional method attribute to value. +// +// value: A string specifying the sampling method for resizing. It can be either +// `"bilinear"` or `"nearest"` and default to `"bilinear"`. Currently two sampling +// methods are supported: Bilinear and Nearest Neighbor. +// If not specified, defaults to "bilinear" +func CropAndResizeMethod(value string) CropAndResizeAttr { + return func(m optionalAttr) { + m["method"] = value + } +} + +// CropAndResizeExtrapolationValue sets the optional extrapolation_value attribute to value. +// +// value: Value used for extrapolation, when applicable. +// If not specified, defaults to 0 +func CropAndResizeExtrapolationValue(value float32) CropAndResizeAttr { + return func(m optionalAttr) { + m["extrapolation_value"] = value + } +} + +// Extracts crops from the input image tensor and resizes them. +// +// Extracts crops from the input image tensor and resizes them using bilinear +// sampling or nearest neighbor sampling (possibly with aspect ratio change) to a +// common output size specified by `crop_size`. This is more general than the +// `crop_to_bounding_box` op which extracts a fixed size slice from the input image +// and does not allow resizing or aspect ratio change. +// +// Returns a tensor with `crops` from the input `image` at positions defined at the +// bounding box locations in `boxes`. The cropped boxes are all resized (with +// bilinear or nearest neighbor interpolation) to a fixed +// `size = [crop_height, crop_width]`. The result is a 4-D tensor +// `[num_boxes, crop_height, crop_width, depth]`. The resizing is corner aligned. +// In particular, if `boxes = [[0, 0, 1, 1]]`, the method will give identical +// results to using `tf.image.resize_bilinear()` or +// `tf.image.resize_nearest_neighbor()`(depends on the `method` argument) with +// `align_corners=True`. +// +// Arguments: +// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +// Both `image_height` and `image_width` need to be positive. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. 
The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1]` in image height coordinates. We do allow `y1` > `y2`, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// crop_size: A 1-D tensor of 2 elements, `size = [crop_height, crop_width]`. All +// cropped image patches are resized to this size. The aspect ratio of the image +// content is not preserved. Both `crop_height` and `crop_width` need to be +// positive. +// +// Returns A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +func CropAndResize(scope *Scope, image tf.Output, boxes tf.Output, box_ind tf.Output, crop_size tf.Output, optional ...CropAndResizeAttr) (crops tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CropAndResize", + Input: []tf.Input{ + image, boxes, box_ind, crop_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. +type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) + +// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to {i:1 i:1 i:1 i:1} +func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of depthwise convolution with respect to the filter. +// +// Arguments: +// input: 4-D with shape based on `data_format`. For example, if +// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, +// in_width, in_channels]` tensor. 
+// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. +// out_backprop: 4-D with shape based on `data_format`. +// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DepthwiseConv2dNativeBackpropFilter", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Creates a dataset that zips together `input_datasets`. // // The elements of the resulting dataset are created by zipping corresponding @@ -40712,158 +40682,105 @@ func SendTPUEmbeddingGradients(scope *Scope, inputs []tf.Output, learning_rates return scope.AddOperation(opspec) } -// Computes softmax cross entropy cost and gradients to backpropagate. -// -// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept -// a matrix of label probabilities, but rather a single label per row -// of features. This label is considered to have probability 1.0 for the -// given row. -// -// Inputs are the logits, not probabilities. -// -// Arguments: -// features: batch_size x num_classes matrix -// labels: batch_size vector with values in [0, num_classes). -// This is the label for the given minibatch entry. -// -// Returns: -// loss: Per example loss (batch_size vector). -// backprop: backpropagated gradients (batch_size x num_classes matrix). -func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSoftmaxCrossEntropyWithLogits", - Input: []tf.Input{ - features, labels, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} +// MapPeekAttr is an optional argument to MapPeek. +type MapPeekAttr func(optionalAttr) -// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. -type ResourceApplyProximalGradientDescentAttr func(optionalAttr) - -// ResourceApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. +// MapPeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. 
-// If not specified, defaults to false -func ResourceApplyProximalGradientDescentUseLocking(value bool) ResourceApplyProximalGradientDescentAttr { +// REQUIRES: value >= 0 +func MapPeekCapacity(value int64) MapPeekAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["capacity"] = value } } -// Update '*var' as FOBOS algorithm with fixed learning rate. +// MapPeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// prox_v = var - alpha * delta -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// REQUIRES: value >= 0 +func MapPeekMemoryLimit(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapPeekContainer(value string) MapPeekAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapPeekSharedName(value string) MapPeekAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op peeks at the values at the specified key. If the // -// Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// delta: The change. -// -// Returns the created operation. -func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, delta tf.Output, optional ...ResourceApplyProximalGradientDescentAttr) (o *tf.Operation) { +// underlying container does not contain this key +// this op will block until it does. +func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyProximalGradientDescent", + Type: "MapPeek", Input: []tf.Input{ - var_, alpha, l1, l2, delta, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Worker heartbeat op. -// -// Heartbeats may be sent periodically to indicate the coordinator is still active, -// to retrieve the current worker status and to expedite shutdown when necessary. -// -// Arguments: -// request: A string tensor containing a serialized WorkerHeartbeatRequest -// -// Returns A string tensor containing a serialized WorkerHeartbeatResponse -func WorkerHeartbeat(scope *Scope, request tf.Output) (response tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "WorkerHeartbeat", - Input: []tf.Input{ - request, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the item in the list with the given index. 
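// ---------------------------------------------------------------------------
// Editor's note: illustrative worked instance, not part of the generated
// wrappers or of this patch. It applies the FOBOS update documented for
// ResourceApplyProximalGradientDescent above to made-up numbers.
// With var = 1.0, alpha = 0.1, delta = 2.0, l1 = 0.05, l2 = 0.1:
//   prox_v = var - alpha*delta                    = 1.0 - 0.2           = 0.8
//   shrunk = max(|prox_v| - alpha*l1, 0)          = max(0.8 - 0.005, 0) = 0.795
//   var    = sign(prox_v)/(1 + alpha*l2) * shrunk = 0.795 / 1.01        ≈ 0.787
// ---------------------------------------------------------------------------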
-// -// input_handle: the list -// index: the position in the list from which an element will be retrieved -// item: the element at that position -// -// -func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, element_shape tf.Output, element_dtype tf.DataType) (item tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "TensorListGetItem", - Input: []tf.Input{ - input_handle, index, element_shape, + key, indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapPeek", err) + return + } + return values } -// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug. -type RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAttr) +// RetrieveTPUEmbeddingCenteredRMSPropParametersAttr is an optional argument to RetrieveTPUEmbeddingCenteredRMSPropParameters. +type RetrieveTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) -// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// RetrieveTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { +func RetrieveTPUEmbeddingCenteredRMSPropParametersTableId(value int64) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { return func(m optionalAttr) { m["table_id"] = value } } -// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// RetrieveTPUEmbeddingCenteredRMSPropParametersTableName sets the optional table_name attribute to value. // If not specified, defaults to "" -func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { +func RetrieveTPUEmbeddingCenteredRMSPropParametersTableName(value string) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { return func(m optionalAttr) { m["table_name"] = value } } -// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugConfig sets the optional config attribute to value. +// RetrieveTPUEmbeddingCenteredRMSPropParametersConfig sets the optional config attribute to value. // If not specified, defaults to "" -func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { +func RetrieveTPUEmbeddingCenteredRMSPropParametersConfig(value string) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { return func(m optionalAttr) { m["config"] = value } } -// Retrieve proximal Adagrad embedding parameters with debug support. +// Retrieve centered RMSProp embedding parameters. // // An op that retrieves optimization parameters from embedding to host // memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up @@ -40871,10 +40788,11 @@ func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugConfig(value str // used to retrieve updated parameters before saving a checkpoint. 
// // Returns: -// parameters: Parameter parameters updated by the proximal Adagrad optimization algorithm. -// accumulators: Parameter accumulators updated by the proximal Adagrad optimization algorithm. -// gradient_accumulators: Parameter gradient_accumulators updated by the proximal Adagrad optimization algorithm. -func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output) { +// parameters: Parameter parameters updated by the centered RMSProp optimization algorithm. +// ms: Parameter ms updated by the centered RMSProp optimization algorithm. +// mom: Parameter mom updated by the centered RMSProp optimization algorithm. +// mg: Parameter mg updated by the centered RMSProp optimization algorithm. +func RetrieveTPUEmbeddingCenteredRMSPropParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingCenteredRMSPropParametersAttr) (parameters tf.Output, ms tf.Output, mom tf.Output, mg tf.Output) { if scope.Err() != nil { return } @@ -40883,144 +40801,12 @@ func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, n a(attrs) } opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", + Type: "RetrieveTPUEmbeddingCenteredRMSPropParameters", Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Records the latency of producing `input_dataset` elements in a StatsAggregator. -func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "LatencyStatsDataset", - Input: []tf.Input{ - input_dataset, tag, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the power of one value to another. -// -// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for -// corresponding elements in `x` and `y`. For example: -// -// ``` -// # tensor 'x' is [[2, 2]], [3, 3]] -// # tensor 'y' is [[8, 16], [2, 3]] -// tf.pow(x, y) ==> [[256, 65536], [9, 27]] -// ``` -func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Pow", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Element-wise multiplication of a sparse matrix with a dense tensor. -// -// Returns a sparse matrix. -// -// The dense tensor `b` may be either a scalar; otherwise `a` must be a rank-3 -// `SparseMatrix`; in this case `b` must be shaped `[batch_size, 1, 1]` and the -// multiply operation broadcasts. -// -// **NOTE** even if `b` is zero, the sparsity structure of the output does not -// change. -// -// Arguments: -// a: A CSRSparseMatrix. -// b: A dense tensor. -// -// Returns A dense output tensor. -func SparseMatrixMul(scope *Scope, a tf.Output, b tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseMatrixMul", - Input: []tf.Input{ - a, b, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the element-wise sum of a list of tensors. 
-// -// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not -// wait for all of its inputs to be ready before beginning to sum. This can -// save memory if inputs are ready at different times, since minimum temporary -// storage is proportional to the output size rather than the inputs size. -// -// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. -// -// Returns a `Tensor` of same shape and type as the elements of `inputs`. -// -// Arguments: -// inputs: A list of `Tensor` objects, each with same shape and type. -// shape: Shape of elements of `inputs`. -func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shape": shape} - opspec := tf.OpSpec{ - Type: "AccumulateNV2", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// An op enabling differentiation of TPU Embeddings. -// -// This op simply returns its first input, which is assumed to have been sliced -// from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of -// this op, and its first argument being a trainable Variable, enables automatic -// differentiation of graphs containing embeddings via the TPU Embedding Python -// libraries. -// -// Arguments: -// embedding_variable: A trainable variable, enabling optimizers to find this op. -// sliced_activations: The embedding activations Tensor to return. -// table_id: The id of the table in the embedding layer configuration from which -// these activations were computed. -// lookup_id: Identifier of the set of embedding indices which produced these -// activations. -func TPUEmbeddingActivations(scope *Scope, embedding_variable tf.Output, sliced_activations tf.Output, table_id int64, lookup_id int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"table_id": table_id, "lookup_id": lookup_id} - opspec := tf.OpSpec{ - Type: "TPUEmbeddingActivations", - Input: []tf.Input{ - embedding_variable, sliced_activations, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } // Transforms a vector of brain.Example protos (as strings) into typed tensors. @@ -41102,824 +40888,6 @@ func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_ke return sparse_indices, sparse_values, sparse_shapes, dense_values } -// MapPeekAttr is an optional argument to MapPeek. -type MapPeekAttr func(optionalAttr) - -// MapPeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapPeekCapacity(value int64) MapPeekAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapPeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapPeekMemoryLimit(value int64) MapPeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapPeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapPeekContainer(value string) MapPeekAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapPeekSharedName sets the optional shared_name attribute to value. 
-// If not specified, defaults to "" -func MapPeekSharedName(value string) MapPeekAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op peeks at the values at the specified key. If the -// -// underlying container does not contain this key -// this op will block until it does. -func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapPeek", - Input: []tf.Input{ - key, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapPeek", err) - return - } - return values -} - -// RetrieveTPUEmbeddingCenteredRMSPropParametersAttr is an optional argument to RetrieveTPUEmbeddingCenteredRMSPropParameters. -type RetrieveTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingCenteredRMSPropParametersTableId(value int64) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingCenteredRMSPropParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingCenteredRMSPropParametersTableName(value string) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// RetrieveTPUEmbeddingCenteredRMSPropParametersConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingCenteredRMSPropParametersConfig(value string) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Retrieve centered RMSProp embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns: -// parameters: Parameter parameters updated by the centered RMSProp optimization algorithm. -// ms: Parameter ms updated by the centered RMSProp optimization algorithm. -// mom: Parameter mom updated by the centered RMSProp optimization algorithm. -// mg: Parameter mg updated by the centered RMSProp optimization algorithm. -func RetrieveTPUEmbeddingCenteredRMSPropParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingCenteredRMSPropParametersAttr) (parameters tf.Output, ms tf.Output, mom tf.Output, mg tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingCenteredRMSPropParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// Computes the LSTM cell backward propagation for the entire time sequence. 
-// -// This implementation is to be used in conjunction of BlockLSTMV2. -// -// Arguments: -// seq_len_max: Maximum time length actually used by this input. Outputs are padded -// with zeros beyond this length. -// x: The sequence input to the LSTM, shape (timelen, batch_size, num_inputs). -// cs_prev: Value of the initial cell state. -// h_prev: Initial output of cell (to be used for peephole). -// w: The weight matrix. -// wci: The weight matrix for input gate peephole connection. -// wcf: The weight matrix for forget gate peephole connection. -// wco: The weight matrix for output gate peephole connection. -// b: The bias vector. -// i: The input gate over the whole time sequence. -// cs: The cell state before the tanh over the whole time sequence. -// f: The forget gate over the whole time sequence. -// o: The output gate over the whole time sequence. -// ci: The cell input over the whole time sequence. -// co: The cell after the tanh over the whole time sequence. -// h: The output h vector over the whole time sequence. -// cs_grad: The current gradient of cs. -// h_grad: The gradient of h vector. -// use_peephole: Whether to use peephole weights. -// -// Returns: -// x_grad: The gradient of x to be back-propped. -// cs_prev_grad: The gradient of cs_prev to be back-propped. -// h_prev_grad: The gradient of h_prev to be back-propped. -// w_grad: The gradient for w to be back-propped. -// wci_grad: The gradient for wci to be back-propped. -// wcf_grad: The gradient for wcf to be back-propped. -// wco_grad: The gradient for wco to be back-propped. -// b_grad: The gradient for w to be back-propped. -func BlockLSTMGradV2(scope *Scope, seq_len_max tf.Output, x tf.Output, cs_prev tf.Output, h_prev tf.Output, w tf.Output, wci tf.Output, wcf tf.Output, wco tf.Output, b tf.Output, i tf.Output, cs tf.Output, f tf.Output, o tf.Output, ci tf.Output, co tf.Output, h tf.Output, cs_grad tf.Output, h_grad tf.Output, use_peephole bool) (x_grad tf.Output, cs_prev_grad tf.Output, h_prev_grad tf.Output, w_grad tf.Output, wci_grad tf.Output, wcf_grad tf.Output, wco_grad tf.Output, b_grad tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"use_peephole": use_peephole} - opspec := tf.OpSpec{ - Type: "BlockLSTMGradV2", - Input: []tf.Input{ - seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, h, cs_grad, h_grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6), op.Output(7) -} - -// Returns the element-wise max of two SparseTensors. -// -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. -// -// Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. -// -// Returns: -// output_indices: 2-D. The indices of the output SparseTensor. -// output_values: 1-D. The values of the output SparseTensor. 
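// ---------------------------------------------------------------------------
// Editor's note: illustrative sketch, not part of the generated wrappers or of
// this patch. It shows the element-wise max of two 2x2 SparseTensors with
// matching shapes, per the SparseSparseMaximum doc comment above; assumes import
//   "github.com/tensorflow/tensorflow/tensorflow/go/op"
func ExampleSparseSparseMaximum() {
	s := op.NewScope()
	// a has 1 at (0,0) and 3 at (1,1); b has 2 at (0,0) and 1 at (1,1).
	aIdx, aVal, aShape := op.Const(s, [][]int64{{0, 0}, {1, 1}}), op.Const(s, []float32{1, 3}), op.Const(s, []int64{2, 2})
	bIdx, bVal, bShape := op.Const(s, [][]int64{{0, 0}, {1, 1}}), op.Const(s, []float32{2, 1}), op.Const(s, []int64{2, 2})
	outIdx, outVal := op.SparseSparseMaximum(s, aIdx, aVal, aShape, bIdx, bVal, bShape)
	// Expected: values [2 3] at indices [[0 0] [1 1]].
	_, _ = outIdx, outVal
}
// ---------------------------------------------------------------------------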
-func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSparseMaximum", - Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Computes the Bessel i1e function of `x` element-wise. -// -// Exponentially scaled modified Bessel function of order 0 defined as -// `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. -// -// This function is faster and numerically stabler than `bessel_i1(x)`. -func BesselI1e(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BesselI1e", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns a batched diagonal tensor with a given batched diagonal values. -// -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: -// -// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a -// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: -// -// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. -// -// For example: -// -// ``` -// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] -// -// and diagonal.shape = (2, 4) -// -// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] -// -// which has shape (2, 4, 4) -// ``` -// -// Arguments: -// diagonal: Rank `k`, where `k >= 1`. -// -// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. -func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixDiag", - Input: []tf.Input{ - diagonal, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. -type StatelessTruncatedNormalAttr func(optionalAttr) - -// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom values from a truncated normal distribution. -// -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// -// Returns Random values with specified shape. 
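// ---------------------------------------------------------------------------
// Editor's note: illustrative sketch, not part of the generated wrappers or of
// this patch. It reproduces the example from the MatrixDiag doc comment above
// using the Go bindings; assumes import
//   "github.com/tensorflow/tensorflow/tensorflow/go/op"
func ExampleMatrixDiag() {
	s := op.NewScope()
	diagonal := op.Const(s, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}}) // shape (2, 4)
	out := op.MatrixDiag(s, diagonal)                              // shape (2, 4, 4)
	_ = out
}
// ---------------------------------------------------------------------------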
-func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatelessTruncatedNormal", - Input: []tf.Input{ - shape, seed, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. -type StatelessRandomUniformAttr func(optionalAttr) - -// StatelessRandomUniformDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// -// Returns Random values with specified shape. -func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatelessRandomUniform", - Input: []tf.Input{ - shape, seed, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolGradGradV2Attr is an optional argument to MaxPoolGradGradV2. -type MaxPoolGradGradV2Attr func(optionalAttr) - -// MaxPoolGradGradV2DataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradGradV2DataFormat(value string) MaxPoolGradGradV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes second-order gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients of gradients w.r.t. the input to `max_pool`. 
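// ---------------------------------------------------------------------------
// Editor's note: illustrative sketch, not part of the generated wrappers or of
// this patch. It builds the StatelessRandomUniform op documented above, whose
// output in [0, 1) is a deterministic function of `shape` and `seed`; assumes
// imports
//   tf "github.com/tensorflow/tensorflow/tensorflow/go"
//   "github.com/tensorflow/tensorflow/tensorflow/go/op"
func ExampleStatelessRandomUniform() {
	s := op.NewScope()
	shape := op.Const(s, []int32{2, 3}) // output shape
	seed := op.Const(s, []int64{7, 11}) // 2 seeds (shape [2])
	u := op.StatelessRandomUniform(s, shape, seed, op.StatelessRandomUniformDtype(tf.Float))
	_ = u // the same shape and seed always yield the same values
}
// ---------------------------------------------------------------------------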
-func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradGradV2Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradGradV2", - Input: []tf.Input{ - orig_input, orig_output, grad, ksize, strides, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingMomentumParametersAttr is an optional argument to RetrieveTPUEmbeddingMomentumParameters. -type RetrieveTPUEmbeddingMomentumParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingMomentumParametersTableId(value int64) RetrieveTPUEmbeddingMomentumParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingMomentumParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingMomentumParametersTableName(value string) RetrieveTPUEmbeddingMomentumParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// RetrieveTPUEmbeddingMomentumParametersConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingMomentumParametersConfig(value string) RetrieveTPUEmbeddingMomentumParametersAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Retrieve Momentum embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns: -// parameters: Parameter parameters updated by the Momentum optimization algorithm. -// momenta: Parameter momenta updated by the Momentum optimization algorithm. -func RetrieveTPUEmbeddingMomentumParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersAttr) (parameters tf.Output, momenta tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingMomentumParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// RecvAttr is an optional argument to Recv. -type RecvAttr func(optionalAttr) - -// RecvClientTerminated sets the optional client_terminated attribute to value. -// -// value: If set to true, this indicates that the node was added -// to the graph as a result of a client-side feed or fetch of Tensor data, -// in which case the corresponding send or recv is expected to be managed -// locally by the caller. -// If not specified, defaults to false -func RecvClientTerminated(value bool) RecvAttr { - return func(m optionalAttr) { - m["client_terminated"] = value - } -} - -// Receives the named tensor from send_device on recv_device. -// -// Arguments: -// -// tensor_name: The name of the tensor to receive. -// send_device: The name of the device sending the tensor. 
-// send_device_incarnation: The current incarnation of send_device. -// recv_device: The name of the device receiving the tensor. -// -// Returns The tensor to receive. -func Recv(scope *Scope, tensor_type tf.DataType, tensor_name string, send_device string, send_device_incarnation int64, recv_device string, optional ...RecvAttr) (tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"tensor_type": tensor_type, "tensor_name": tensor_name, "send_device": send_device, "send_device_incarnation": send_device_incarnation, "recv_device": recv_device} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Recv", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// OrderedMapStageAttr is an optional argument to OrderedMapStage. -type OrderedMapStageAttr func(optionalAttr) - -// OrderedMapStageCapacity sets the optional capacity attribute to value. -// -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapStageCapacity(value int64) OrderedMapStageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapStageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapStageMemoryLimit(value int64) OrderedMapStageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapStageContainer sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. -// If not specified, defaults to "" -func OrderedMapStageContainer(value string) OrderedMapStageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapStageSharedName sets the optional shared_name attribute to value. -// -// value: It is necessary to match this name to the matching Unstage Op. -// If not specified, defaults to "" -func OrderedMapStageSharedName(value string) OrderedMapStageAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Stage (key, values) in the underlying container which behaves like a ordered -// -// associative container. Elements are ordered by key. -// -// Arguments: -// key: int64 -// -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. -// -// -// Returns the created operation. -func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...OrderedMapStageAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OrderedMapStage", - Input: []tf.Input{ - key, indices, tf.OutputList(values), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// TPUReplicateMetadataAttr is an optional argument to TPUReplicateMetadata. -type TPUReplicateMetadataAttr func(optionalAttr) - -// TPUReplicateMetadataNumCoresPerReplica sets the optional num_cores_per_replica attribute to value. -// -// value: Number of cores per replica. Used for model parallelism. 
-// If not specified, defaults to 1 -func TPUReplicateMetadataNumCoresPerReplica(value int64) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["num_cores_per_replica"] = value - } -} - -// TPUReplicateMetadataTopology sets the optional topology attribute to value. -// -// value: TopologyProto indicating the topology of the TPU pod slice. -// If not specified, defaults to "" -func TPUReplicateMetadataTopology(value string) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["topology"] = value - } -} - -// TPUReplicateMetadataUseTpu sets the optional use_tpu attribute to value. -// -// value: Whether to place the computation on the TPU. -// If not specified, defaults to true -func TPUReplicateMetadataUseTpu(value bool) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["use_tpu"] = value - } -} - -// TPUReplicateMetadataDeviceAssignment sets the optional device_assignment attribute to value. -// -// value: The assignment of devices for the computation. -// If not specified, defaults to {} -func TPUReplicateMetadataDeviceAssignment(value []int64) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["device_assignment"] = value - } -} - -// TPUReplicateMetadataComputationShape sets the optional computation_shape attribute to value. -// -// value: DEPRECATED. Use num_cores_per_replica instead. -// If not specified, defaults to {} -func TPUReplicateMetadataComputationShape(value []int64) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["computation_shape"] = value - } -} - -// TPUReplicateMetadataHostComputeCore sets the optional host_compute_core attribute to value. -// If not specified, defaults to {} -func TPUReplicateMetadataHostComputeCore(value []string) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["host_compute_core"] = value - } -} - -// TPUReplicateMetadataPaddingMap sets the optional padding_map attribute to value. -// If not specified, defaults to {} -func TPUReplicateMetadataPaddingMap(value []string) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["padding_map"] = value - } -} - -// TPUReplicateMetadataStepMarkerLocation sets the optional step_marker_location attribute to value. -// If not specified, defaults to "STEP_MARK_AT_ENTRY" -func TPUReplicateMetadataStepMarkerLocation(value string) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["step_marker_location"] = value - } -} - -// TPUReplicateMetadataAllowSoftPlacement sets the optional allow_soft_placement attribute to value. -// If not specified, defaults to false -func TPUReplicateMetadataAllowSoftPlacement(value bool) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["allow_soft_placement"] = value - } -} - -// Metadata indicating how the TPU computation should be replicated. -// -// This operation holds the metadata common to operations of a `tpu.replicate()` computation subgraph. -// -// Arguments: -// num_replicas: Number of replicas of the computation -// -// Returns the created operation. -func TPUReplicateMetadata(scope *Scope, num_replicas int64, optional ...TPUReplicateMetadataAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_replicas": num_replicas} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TPUReplicateMetadata", - - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. 
-type ExtractGlimpseAttr func(optionalAttr) - -// ExtractGlimpseCentered sets the optional centered attribute to value. -// -// value: indicates if the offset coordinates are centered relative to -// the image, in which case the (0, 0) offset is relative to the center -// of the input images. If false, the (0,0) offset corresponds to the -// upper left corner of the input images. -// If not specified, defaults to true -func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["centered"] = value - } -} - -// ExtractGlimpseNormalized sets the optional normalized attribute to value. -// -// value: indicates if the offset coordinates are normalized. -// If not specified, defaults to true -func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["normalized"] = value - } -} - -// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. -// -// value: indicates if the noise should be generated using a -// uniform distribution or a Gaussian distribution. -// If not specified, defaults to true -func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["uniform_noise"] = value - } -} - -// ExtractGlimpseNoise sets the optional noise attribute to value. -// -// value: indicates if the noise should `uniform`, `gaussian`, or -// `zero`. The default is `uniform` which means the the noise type -// will be decided by `uniform_noise`. -// If not specified, defaults to "uniform" -func ExtractGlimpseNoise(value string) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["noise"] = value - } -} - -// Extracts a glimpse from the input tensor. -// -// Returns a set of windows called glimpses extracted at location -// `offsets` from the input tensor. If the windows only partially -// overlaps the inputs, the non overlapping areas will be filled with -// random noise. -// -// The result is a 4-D tensor of shape `[batch_size, glimpse_height, -// glimpse_width, channels]`. The channels and batch dimensions are the -// same as that of the input tensor. The height and width of the output -// windows are specified in the `size` parameter. -// -// The argument `normalized` and `centered` controls how the windows are built: -// -// * If the coordinates are normalized but not centered, 0.0 and 1.0 -// correspond to the minimum and maximum of each height and width -// dimension. -// * If the coordinates are both normalized and centered, they range from -// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper -// left corner, the lower right corner is located at (1.0, 1.0) and the -// center is at (0, 0). -// * If the coordinates are not normalized they are interpreted as -// numbers of pixels. -// -// Arguments: -// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. -// size: A 1-D tensor of 2 elements containing the size of the glimpses -// to extract. The glimpse height must be specified first, following -// by the glimpse width. -// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing -// the y, x locations of the center of each window. -// -// Returns A tensor representing the glimpses `[batch_size, -// glimpse_height, glimpse_width, channels]`. 
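// ---------------------------------------------------------------------------
// Editor's note: illustrative sketch, not part of the generated wrappers or of
// this patch. It extracts one centered, normalized 8x8 glimpse per image, per
// the ExtractGlimpse doc comment above; assumes imports
//   tf "github.com/tensorflow/tensorflow/tensorflow/go"
//   "github.com/tensorflow/tensorflow/tensorflow/go/op"
func ExampleExtractGlimpse() {
	s := op.NewScope()
	images := op.Placeholder(s, tf.Float, op.PlaceholderShape(tf.MakeShape(1, 64, 64, 3)))
	size := op.Const(s, []int32{8, 8})          // glimpse height, then width
	offsets := op.Const(s, [][]float32{{0, 0}}) // (y, x) center of each window
	glimpse := op.ExtractGlimpse(s, images, size, offsets,
		op.ExtractGlimpseCentered(true), op.ExtractGlimpseNormalized(true))
	_ = glimpse // shape [1, 8, 8, 3]
}
// ---------------------------------------------------------------------------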
-func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ExtractGlimpse", - Input: []tf.Input{ - input, size, offsets, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes sigmoid of `x` element-wise. -// -// Specifically, `y = 1 / (1 + exp(-x))`. -func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sigmoid", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. type ResourceSparseApplyAdadeltaAttr func(optionalAttr) @@ -41965,6 +40933,79 @@ func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, return scope.AddOperation(opspec) } +// Computes sigmoid of `x` element-wise. +// +// Specifically, `y = 1 / (1 + exp(-x))`. +func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sigmoid", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingADAMParametersGradAccumDebug. +type RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve ADAM embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns: +// parameters: Parameter parameters updated by the ADAM optimization algorithm. +// momenta: Parameter momenta updated by the ADAM optimization algorithm. +// velocities: Parameter velocities updated by the ADAM optimization algorithm. +// gradient_accumulators: Parameter gradient_accumulators updated by the ADAM optimization algorithm. 
+func RetrieveTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + // ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. type ResourceApplyAdamAttr func(optionalAttr) @@ -42028,347 +41069,6 @@ func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, b return scope.AddOperation(opspec) } -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingADAMParametersGradAccumDebug. -type RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingADAMParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Retrieve ADAM embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns: -// parameters: Parameter parameters updated by the ADAM optimization algorithm. -// momenta: Parameter momenta updated by the ADAM optimization algorithm. -// velocities: Parameter velocities updated by the ADAM optimization algorithm. -// gradient_accumulators: Parameter gradient_accumulators updated by the ADAM optimization algorithm. 
-func RetrieveTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// RequantizePerChannelAttr is an optional argument to RequantizePerChannel. -type RequantizePerChannelAttr func(optionalAttr) - -// RequantizePerChannelOutType sets the optional out_type attribute to value. -// -// value: The quantized type of output tensor that needs to be converted. -// If not specified, defaults to DT_QUINT8 -func RequantizePerChannelOutType(value tf.DataType) RequantizePerChannelAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Requantizes input with min and max values known per channel. -// -// Arguments: -// input: The original input tensor. -// input_min: The minimum value of the input tensor -// input_max: The maximum value of the input tensor. -// requested_output_min: The minimum value of the output tensor requested. -// requested_output_max: The maximum value of the output tensor requested. -// -// Returns: -// output: Output tensor. -// output_min: The minimum value of the final output tensor -// output_max: The maximum value of the final output tensor. -func RequantizePerChannel(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, optional ...RequantizePerChannelAttr) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RequantizePerChannel", - Input: []tf.Input{ - input, input_min, input_max, requested_output_min, requested_output_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// LeakyReluAttr is an optional argument to LeakyRelu. -type LeakyReluAttr func(optionalAttr) - -// LeakyReluAlpha sets the optional alpha attribute to value. -// If not specified, defaults to 0.2 -func LeakyReluAlpha(value float32) LeakyReluAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// Computes rectified linear: `max(features, features * alpha)`. -func LeakyRelu(scope *Scope, features tf.Output, optional ...LeakyReluAttr) (activations tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LeakyRelu", - Input: []tf.Input{ - features, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingFTRLParametersAttr is an optional argument to RetrieveTPUEmbeddingFTRLParameters. -type RetrieveTPUEmbeddingFTRLParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingFTRLParametersTableId sets the optional table_id attribute to value. 
-// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingFTRLParametersTableId(value int64) RetrieveTPUEmbeddingFTRLParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingFTRLParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingFTRLParametersTableName(value string) RetrieveTPUEmbeddingFTRLParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// RetrieveTPUEmbeddingFTRLParametersConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingFTRLParametersConfig(value string) RetrieveTPUEmbeddingFTRLParametersAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Retrieve FTRL embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns: -// parameters: Parameter parameters updated by the FTRL optimization algorithm. -// accumulators: Parameter accumulators updated by the FTRL optimization algorithm. -// linears: Parameter linears updated by the FTRL optimization algorithm. -func RetrieveTPUEmbeddingFTRLParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingFTRLParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// SerializeManySparseAttr is an optional argument to SerializeManySparse. -type SerializeManySparseAttr func(optionalAttr) - -// SerializeManySparseOutType sets the optional out_type attribute to value. -// -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeManySparseOutType(value tf.DataType) SerializeManySparseAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object. -// -// The `SparseTensor` must have rank `R` greater than 1, and the first dimension -// is treated as the minibatch dimension. Elements of the `SparseTensor` -// must be sorted in increasing order of this first dimension. The serialized -// `SparseTensor` objects going into each row of `serialized_sparse` will have -// rank `R-1`. -// -// The minibatch size `N` is extracted from `sparse_shape[0]`. -// -// Arguments: -// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. -// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. 
-func SerializeManySparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeManySparseAttr) (serialized_sparse tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SerializeManySparse", - Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. -// -// More formally, let -// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, -// -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ -// -// Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. -// -// Returns Computed precision at `k` as a `bool Tensor`. -func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InTopKV2", - Input: []tf.Input{ - predictions, targets, k, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates an Optional variant with no value. -func OptionalNone(scope *Scope) (optional tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "OptionalNone", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr is an optional argument to RetrieveTPUEmbeddingStochasticGradientDescentParameters. -type RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// RetrieveTPUEmbeddingStochasticGradientDescentParametersConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingStochasticGradientDescentParametersConfig(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Retrieve SGD embedding parameters. 
-// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the stochastic gradient descent optimization algorithm. -func RetrieveTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr) (parameters tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingStochasticGradientDescentParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // DatasetToGraphAttr is an optional argument to DatasetToGraph. type DatasetToGraphAttr func(optionalAttr) @@ -42522,6 +41222,183 @@ func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backpr return op.Output(0) } +// Computes softmax cross entropy cost and gradients to backpropagate. +// +// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept +// a matrix of label probabilities, but rather a single label per row +// of features. This label is considered to have probability 1.0 for the +// given row. +// +// Inputs are the logits, not probabilities. +// +// Arguments: +// features: batch_size x num_classes matrix +// labels: batch_size vector with values in [0, num_classes). +// This is the label for the given minibatch entry. +// +// Returns: +// loss: Per example loss (batch_size vector). +// backprop: backpropagated gradients (batch_size x num_classes matrix). +func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSoftmaxCrossEntropyWithLogits", + Input: []tf.Input{ + features, labels, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. +type ResourceApplyProximalGradientDescentAttr func(optionalAttr) + +// ResourceApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalGradientDescentUseLocking(value bool) ResourceApplyProximalGradientDescentAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' as FOBOS algorithm with fixed learning rate. +// +// prox_v = var - alpha * delta +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// delta: The change. +// +// Returns the created operation. 
+func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, delta tf.Output, optional ...ResourceApplyProximalGradientDescentAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyProximalGradientDescent", + Input: []tf.Input{ + var_, alpha, l1, l2, delta, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Worker heartbeat op. +// +// Heartbeats may be sent periodically to indicate the coordinator is still active, +// to retrieve the current worker status and to expedite shutdown when necessary. +// +// Arguments: +// request: A string tensor containing a serialized WorkerHeartbeatRequest +// +// Returns A string tensor containing a serialized WorkerHeartbeatResponse +func WorkerHeartbeat(scope *Scope, request tf.Output) (response tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "WorkerHeartbeat", + Input: []tf.Input{ + request, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the item in the list with the given index. +// +// input_handle: the list +// index: the position in the list from which an element will be retrieved +// item: the element at that position +// +// +func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, element_shape tf.Output, element_dtype tf.DataType) (item tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + opspec := tf.OpSpec{ + Type: "TensorListGetItem", + Input: []tf.Input{ + input_handle, index, element_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug. +type RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve proximal Adagrad embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. 
+// +// Returns: +// parameters: Parameter parameters updated by the proximal Adagrad optimization algorithm. +// accumulators: Parameter accumulators updated by the proximal Adagrad optimization algorithm. +// gradient_accumulators: Parameter gradient_accumulators updated by the proximal Adagrad optimization algorithm. +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // Returns x / y element-wise. // // *NOTE*: `Div` supports broadcasting. More about broadcasting @@ -42828,8 +41705,6 @@ type LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAt // LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -43072,8 +41947,6 @@ type LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr func(optionalAttr) // LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -43212,8 +42085,6 @@ type RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr func(optionalAttr) // RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -43327,8 +42198,6 @@ type LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) // LoadTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -43391,8 +42260,6 @@ type LoadTPUEmbeddingAdadeltaParametersAttr func(optionalAttr) // LoadTPUEmbeddingAdadeltaParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingAdadeltaParametersTableId(value int64) LoadTPUEmbeddingAdadeltaParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -43449,6 +42316,614 @@ func LoadTPUEmbeddingAdadeltaParameters(scope *Scope, parameters tf.Output, accu return scope.AddOperation(opspec) } +// Records the latency of producing `input_dataset` elements in a StatsAggregator. 
+func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+	opspec := tf.OpSpec{
+		Type: "LatencyStatsDataset",
+		Input: []tf.Input{
+			input_dataset, tag,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes the power of one value to another.
+//
+// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+// corresponding elements in `x` and `y`. For example:
+//
+// ```
+// # tensor 'x' is [[2, 2], [3, 3]]
+// # tensor 'y' is [[8, 16], [2, 3]]
+// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+// ```
+func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Pow",
+		Input: []tf.Input{
+			x, y,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Element-wise multiplication of a sparse matrix with a dense tensor.
+//
+// Returns a sparse matrix.
+//
+// The dense tensor `b` may be either a scalar; otherwise `a` must be a rank-3
+// `SparseMatrix`; in this case `b` must be shaped `[batch_size, 1, 1]` and the
+// multiply operation broadcasts.
+//
+// **NOTE** even if `b` is zero, the sparsity structure of the output does not
+// change.
+//
+// Arguments:
+//	a: A CSRSparseMatrix.
+//	b: A dense tensor.
+//
+// Returns A dense output tensor.
+func SparseMatrixMul(scope *Scope, a tf.Output, b tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SparseMatrixMul",
+		Input: []tf.Input{
+			a, b,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Returns the element-wise sum of a list of tensors.
+//
+// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
+// wait for all of its inputs to be ready before beginning to sum. This can
+// save memory if inputs are ready at different times, since minimum temporary
+// storage is proportional to the output size rather than the inputs size.
+//
+// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
+//
+// Returns a `Tensor` of same shape and type as the elements of `inputs`.
+//
+// Arguments:
+//	inputs: A list of `Tensor` objects, each with same shape and type.
+//	shape: Shape of elements of `inputs`.
+func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"shape": shape}
+	opspec := tf.OpSpec{
+		Type: "AccumulateNV2",
+		Input: []tf.Input{
+			tf.OutputList(inputs),
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// An op enabling differentiation of TPU Embeddings.
+//
+// This op simply returns its first input, which is assumed to have been sliced
+// from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of
+// this op, and its first argument being a trainable Variable, enables automatic
+// differentiation of graphs containing embeddings via the TPU Embedding Python
+// libraries.
+//
+// Arguments:
+//	embedding_variable: A trainable variable, enabling optimizers to find this op.
+//	sliced_activations: The embedding activations Tensor to return.
+//	table_id: The id of the table in the embedding layer configuration from which
+// these activations were computed.
+// lookup_id: Identifier of the set of embedding indices which produced these +// activations. +func TPUEmbeddingActivations(scope *Scope, embedding_variable tf.Output, sliced_activations tf.Output, table_id int64, lookup_id int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"table_id": table_id, "lookup_id": lookup_id} + opspec := tf.OpSpec{ + Type: "TPUEmbeddingActivations", + Input: []tf.Input{ + embedding_variable, sliced_activations, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the LSTM cell backward propagation for the entire time sequence. +// +// This implementation is to be used in conjunction of BlockLSTMV2. +// +// Arguments: +// seq_len_max: Maximum time length actually used by this input. Outputs are padded +// with zeros beyond this length. +// x: The sequence input to the LSTM, shape (timelen, batch_size, num_inputs). +// cs_prev: Value of the initial cell state. +// h_prev: Initial output of cell (to be used for peephole). +// w: The weight matrix. +// wci: The weight matrix for input gate peephole connection. +// wcf: The weight matrix for forget gate peephole connection. +// wco: The weight matrix for output gate peephole connection. +// b: The bias vector. +// i: The input gate over the whole time sequence. +// cs: The cell state before the tanh over the whole time sequence. +// f: The forget gate over the whole time sequence. +// o: The output gate over the whole time sequence. +// ci: The cell input over the whole time sequence. +// co: The cell after the tanh over the whole time sequence. +// h: The output h vector over the whole time sequence. +// cs_grad: The current gradient of cs. +// h_grad: The gradient of h vector. +// use_peephole: Whether to use peephole weights. +// +// Returns: +// x_grad: The gradient of x to be back-propped. +// cs_prev_grad: The gradient of cs_prev to be back-propped. +// h_prev_grad: The gradient of h_prev to be back-propped. +// w_grad: The gradient for w to be back-propped. +// wci_grad: The gradient for wci to be back-propped. +// wcf_grad: The gradient for wcf to be back-propped. +// wco_grad: The gradient for wco to be back-propped. +// b_grad: The gradient for w to be back-propped. +func BlockLSTMGradV2(scope *Scope, seq_len_max tf.Output, x tf.Output, cs_prev tf.Output, h_prev tf.Output, w tf.Output, wci tf.Output, wcf tf.Output, wco tf.Output, b tf.Output, i tf.Output, cs tf.Output, f tf.Output, o tf.Output, ci tf.Output, co tf.Output, h tf.Output, cs_grad tf.Output, h_grad tf.Output, use_peephole bool) (x_grad tf.Output, cs_prev_grad tf.Output, h_prev_grad tf.Output, w_grad tf.Output, wci_grad tf.Output, wcf_grad tf.Output, wco_grad tf.Output, b_grad tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"use_peephole": use_peephole} + opspec := tf.OpSpec{ + Type: "BlockLSTMGradV2", + Input: []tf.Input{ + seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, h, cs_grad, h_grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6), op.Output(7) +} + +// Returns the element-wise max of two SparseTensors. +// +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// +// Arguments: +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. 
`N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. +// +// Returns: +// output_indices: 2-D. The indices of the output SparseTensor. +// output_values: 1-D. The values of the output SparseTensor. +func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSparseMaximum", + Input: []tf.Input{ + a_indices, a_values, a_shape, b_indices, b_values, b_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Computes the Bessel i1e function of `x` element-wise. +// +// Exponentially scaled modified Bessel function of order 0 defined as +// `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. +// +// This function is faster and numerically stabler than `bessel_i1(x)`. +func BesselI1e(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BesselI1e", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes rectified linear gradients for a Relu operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Relu operation. +// features: The features passed as input to the corresponding Relu operation, OR +// the outputs of that operation (both work equivalently). +// +// Returns `gradients * (features > 0)`. +func ReluGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReluGrad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingRMSPropParameters. +type LoadTPUEmbeddingRMSPropParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func LoadTPUEmbeddingRMSPropParametersTableId(value int64) LoadTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingRMSPropParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingRMSPropParametersTableName(value string) LoadTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// LoadTPUEmbeddingRMSPropParametersConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingRMSPropParametersConfig(value string) LoadTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Load RMSProp embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. 
+// +// Arguments: +// parameters: Value of parameters used in the RMSProp optimization algorithm. +// ms: Value of ms used in the RMSProp optimization algorithm. +// mom: Value of mom used in the RMSProp optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingRMSPropParameters(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingRMSPropParameters", + Input: []tf.Input{ + parameters, ms, mom, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Returns a batched diagonal tensor with a given batched diagonal values. +// +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: +// +// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a +// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: +// +// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. +// +// For example: +// +// ``` +// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] +// +// and diagonal.shape = (2, 4) +// +// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] +// +// which has shape (2, 4, 4) +// ``` +// +// Arguments: +// diagonal: Rank `k`, where `k >= 1`. +// +// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. +func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixDiag", + Input: []tf.Input{ + diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. +type StatelessTruncatedNormalAttr func(optionalAttr) + +// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom values from a truncated normal distribution. +// +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. +// +// The outputs are a deterministic function of `shape` and `seed`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. 
+func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "StatelessTruncatedNormal",
+		Input: []tf.Input{
+			shape, seed,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.
+type RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr func(optionalAttr)
+
+// RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugTableId sets the optional table_id attribute to value.
+// If not specified, defaults to -1
+func RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr {
+	return func(m optionalAttr) {
+		m["table_id"] = value
+	}
+}
+
+// RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugTableName sets the optional table_name attribute to value.
+// If not specified, defaults to ""
+func RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr {
+	return func(m optionalAttr) {
+		m["table_name"] = value
+	}
+}
+
+// RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugConfig sets the optional config attribute to value.
+// If not specified, defaults to ""
+func RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr {
+	return func(m optionalAttr) {
+		m["config"] = value
+	}
+}
+
+// Retrieve SGD embedding parameters with debug support.
+//
+// An op that retrieves optimization parameters from embedding to host
+// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
+// the correct embedding table configuration. For example, this op is
+// used to retrieve updated parameters before saving a checkpoint.
+//
+// Returns:
+//	parameters: Parameter parameters updated by the stochastic gradient descent optimization algorithm.
+//	gradient_accumulators: Parameter gradient_accumulators updated by the stochastic gradient descent optimization algorithm.
+func RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr) (parameters tf.Output, gradient_accumulators tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug",
+
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1)
+}
+
+// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform.
+type StatelessRandomUniformAttr func(optionalAttr)
+
+// StatelessRandomUniformDtype sets the optional dtype attribute to value.
+//
+// value: The type of the output.
+// If not specified, defaults to DT_FLOAT +func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// The outputs are a deterministic function of `shape` and `seed`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. +func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessRandomUniform", + Input: []tf.Input{ + shape, seed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolGradGradV2Attr is an optional argument to MaxPoolGradGradV2. +type MaxPoolGradGradV2Attr func(optionalAttr) + +// MaxPoolGradGradV2DataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradGradV2DataFormat(value string) MaxPoolGradGradV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes second-order gradients of the maxpooling function. +// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradGradV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolGradGradV2", + Input: []tf.Input{ + orig_input, orig_output, grad, ksize, strides, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingMomentumParametersAttr is an optional argument to RetrieveTPUEmbeddingMomentumParameters. +type RetrieveTPUEmbeddingMomentumParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func RetrieveTPUEmbeddingMomentumParametersTableId(value int64) RetrieveTPUEmbeddingMomentumParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingMomentumParametersTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func RetrieveTPUEmbeddingMomentumParametersTableName(value string) RetrieveTPUEmbeddingMomentumParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingMomentumParametersConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingMomentumParametersConfig(value string) RetrieveTPUEmbeddingMomentumParametersAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve Momentum embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns: +// parameters: Parameter parameters updated by the Momentum optimization algorithm. +// momenta: Parameter momenta updated by the Momentum optimization algorithm. +func RetrieveTPUEmbeddingMomentumParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersAttr) (parameters tf.Output, momenta tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingMomentumParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // PaddingFIFOQueueV2Attr is an optional argument to PaddingFIFOQueueV2. type PaddingFIFOQueueV2Attr func(optionalAttr) @@ -43536,8 +43011,6 @@ type LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) // LoadTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -43594,6 +43067,736 @@ func LoadTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, parameters t return scope.AddOperation(opspec) } +// RecvAttr is an optional argument to Recv. +type RecvAttr func(optionalAttr) + +// RecvClientTerminated sets the optional client_terminated attribute to value. +// +// value: If set to true, this indicates that the node was added +// to the graph as a result of a client-side feed or fetch of Tensor data, +// in which case the corresponding send or recv is expected to be managed +// locally by the caller. +// If not specified, defaults to false +func RecvClientTerminated(value bool) RecvAttr { + return func(m optionalAttr) { + m["client_terminated"] = value + } +} + +// Receives the named tensor from send_device on recv_device. +// +// Arguments: +// +// tensor_name: The name of the tensor to receive. +// send_device: The name of the device sending the tensor. +// send_device_incarnation: The current incarnation of send_device. +// recv_device: The name of the device receiving the tensor. +// +// Returns The tensor to receive. 
+func Recv(scope *Scope, tensor_type tf.DataType, tensor_name string, send_device string, send_device_incarnation int64, recv_device string, optional ...RecvAttr) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"tensor_type": tensor_type, "tensor_name": tensor_name, "send_device": send_device, "send_device_incarnation": send_device_incarnation, "recv_device": recv_device} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Recv", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// OrderedMapStageAttr is an optional argument to OrderedMapStage. +type OrderedMapStageAttr func(optionalAttr) + +// OrderedMapStageCapacity sets the optional capacity attribute to value. +// +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapStageCapacity(value int64) OrderedMapStageAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// OrderedMapStageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapStageMemoryLimit(value int64) OrderedMapStageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// OrderedMapStageContainer sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func OrderedMapStageContainer(value string) OrderedMapStageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapStageSharedName sets the optional shared_name attribute to value. +// +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func OrderedMapStageSharedName(value string) OrderedMapStageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Stage (key, values) in the underlying container which behaves like a ordered +// +// associative container. Elements are ordered by key. +// +// Arguments: +// key: int64 +// +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. +// +// +// Returns the created operation. +func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...OrderedMapStageAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OrderedMapStage", + Input: []tf.Input{ + key, indices, tf.OutputList(values), + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// TPUReplicateMetadataAttr is an optional argument to TPUReplicateMetadata. +type TPUReplicateMetadataAttr func(optionalAttr) + +// TPUReplicateMetadataNumCoresPerReplica sets the optional num_cores_per_replica attribute to value. +// +// value: Number of cores per replica. Used for model parallelism. +// If not specified, defaults to 1 +func TPUReplicateMetadataNumCoresPerReplica(value int64) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["num_cores_per_replica"] = value + } +} + +// TPUReplicateMetadataTopology sets the optional topology attribute to value. 
+// +// value: TopologyProto indicating the topology of the TPU pod slice. +// If not specified, defaults to "" +func TPUReplicateMetadataTopology(value string) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["topology"] = value + } +} + +// TPUReplicateMetadataUseTpu sets the optional use_tpu attribute to value. +// +// value: Whether to place the computation on the TPU. +// If not specified, defaults to true +func TPUReplicateMetadataUseTpu(value bool) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["use_tpu"] = value + } +} + +// TPUReplicateMetadataDeviceAssignment sets the optional device_assignment attribute to value. +// +// value: The assignment of devices for the computation. +// If not specified, defaults to {} +func TPUReplicateMetadataDeviceAssignment(value []int64) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["device_assignment"] = value + } +} + +// TPUReplicateMetadataComputationShape sets the optional computation_shape attribute to value. +// +// value: DEPRECATED. Use num_cores_per_replica instead. +// If not specified, defaults to {} +func TPUReplicateMetadataComputationShape(value []int64) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["computation_shape"] = value + } +} + +// TPUReplicateMetadataHostComputeCore sets the optional host_compute_core attribute to value. +// If not specified, defaults to {} +func TPUReplicateMetadataHostComputeCore(value []string) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["host_compute_core"] = value + } +} + +// TPUReplicateMetadataPaddingMap sets the optional padding_map attribute to value. +// If not specified, defaults to {} +func TPUReplicateMetadataPaddingMap(value []string) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["padding_map"] = value + } +} + +// TPUReplicateMetadataStepMarkerLocation sets the optional step_marker_location attribute to value. +// If not specified, defaults to "STEP_MARK_AT_ENTRY" +func TPUReplicateMetadataStepMarkerLocation(value string) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["step_marker_location"] = value + } +} + +// TPUReplicateMetadataAllowSoftPlacement sets the optional allow_soft_placement attribute to value. +// If not specified, defaults to false +func TPUReplicateMetadataAllowSoftPlacement(value bool) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["allow_soft_placement"] = value + } +} + +// Metadata indicating how the TPU computation should be replicated. +// +// This operation holds the metadata common to operations of a `tpu.replicate()` computation subgraph. +// +// Arguments: +// num_replicas: Number of replicas of the computation +// +// Returns the created operation. +func TPUReplicateMetadata(scope *Scope, num_replicas int64, optional ...TPUReplicateMetadataAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_replicas": num_replicas} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TPUReplicateMetadata", + + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// RequantizePerChannelAttr is an optional argument to RequantizePerChannel. +type RequantizePerChannelAttr func(optionalAttr) + +// RequantizePerChannelOutType sets the optional out_type attribute to value. +// +// value: The quantized type of output tensor that needs to be converted. 
+// If not specified, defaults to DT_QUINT8 +func RequantizePerChannelOutType(value tf.DataType) RequantizePerChannelAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Requantizes input with min and max values known per channel. +// +// Arguments: +// input: The original input tensor. +// input_min: The minimum value of the input tensor +// input_max: The maximum value of the input tensor. +// requested_output_min: The minimum value of the output tensor requested. +// requested_output_max: The maximum value of the output tensor requested. +// +// Returns: +// output: Output tensor. +// output_min: The minimum value of the final output tensor +// output_max: The maximum value of the final output tensor. +func RequantizePerChannel(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, optional ...RequantizePerChannelAttr) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RequantizePerChannel", + Input: []tf.Input{ + input, input_min, input_max, requested_output_min, requested_output_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// LeakyReluAttr is an optional argument to LeakyRelu. +type LeakyReluAttr func(optionalAttr) + +// LeakyReluAlpha sets the optional alpha attribute to value. +// If not specified, defaults to 0.2 +func LeakyReluAlpha(value float32) LeakyReluAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// Computes rectified linear: `max(features, features * alpha)`. +func LeakyRelu(scope *Scope, features tf.Output, optional ...LeakyReluAttr) (activations tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LeakyRelu", + Input: []tf.Input{ + features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns an element-wise indication of the sign of a number. +// +// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`. +// +// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. +// +// Example usage: +// >>> tf.math.sign([0., 2., -3.]) +// +func Sign(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sign", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAddSignAttr is an optional argument to ResourceApplyAddSign. +type ResourceApplyAddSignAttr func(optionalAttr) + +// ResourceApplyAddSignUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and m tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAddSignUseLocking(value bool) ResourceApplyAddSignAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the AddSign update. +// +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// update <- (alpha + sign_decay * sign(g) *sign(m)) * g +// variable <- variable - lr_t * update +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// lr: Scaling factor. 
Must be a scalar. +// alpha: Must be a scalar. +// sign_decay: Must be a scalar. +// beta: Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, alpha tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyAddSignAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAddSign", + Input: []tf.Input{ + var_, m, lr, alpha, sign_decay, beta, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Concatenates tensors along one dimension. +// +// Arguments: +// values: List of `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// axis: 0-D. The dimension along which to concatenate. Must be in the +// range [-rank(values), rank(values)). +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatV2", + Input: []tf.Input{ + tf.OutputList(values), axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingRMSPropParametersGradAccumDebug. +type LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingRMSPropParametersGradAccumDebugConfig(value string) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Load RMSProp embedding parameters with debug support. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the RMSProp optimization algorithm. +// ms: Value of ms used in the RMSProp optimization algorithm. +// mom: Value of mom used in the RMSProp optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the RMSProp optimization algorithm. +// +// +// +// Returns the created operation. 
+func LoadTPUEmbeddingRMSPropParametersGradAccumDebug(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingRMSPropParametersGradAccumDebug", + Input: []tf.Input{ + parameters, ms, mom, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// TensorListConcatAttr is an optional argument to TensorListConcat. +type TensorListConcatAttr func(optionalAttr) + +// TensorListConcatElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to {unknown_rank:true} +func TensorListConcatElementShape(value tf.Shape) TensorListConcatAttr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Concats all tensors in the list along the 0th dimension. +// +// Requires that all tensors have the same shape except the first dimension. +// +// input_handle: The input list. +// tensor: The concated result. +// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. +// +func TensorListConcat(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListConcatAttr) (tensor tf.Output, lengths tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorListConcat", + Input: []tf.Input{ + input_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// LoadTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to LoadTPUEmbeddingMDLAdagradLightParameters. +type LoadTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func LoadTPUEmbeddingMDLAdagradLightParametersTableId(value int64) LoadTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingMDLAdagradLightParametersTableName(value string) LoadTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// LoadTPUEmbeddingMDLAdagradLightParametersConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingMDLAdagradLightParametersConfig(value string) LoadTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Load MDL Adagrad Light embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the MDL Adagrad Light optimization algorithm. 
+// accumulators: Value of accumulators used in the MDL Adagrad Light optimization algorithm. +// weights: Value of weights used in the MDL Adagrad Light optimization algorithm. +// benefits: Value of benefits used in the MDL Adagrad Light optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingMDLAdagradLightParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMDLAdagradLightParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingMDLAdagradLightParameters", + Input: []tf.Input{ + parameters, accumulators, weights, benefits, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Returns the next record (key, value pair) produced by a Reader. +// +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). +// +// Arguments: +// reader_handle: Handle to a Reader. +// queue_handle: Handle to a Queue, with string work items. +// +// Returns: +// key: A scalar. +// value: A scalar. +func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderReadV2", + Input: []tf.Input{ + reader_handle, queue_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// CumprodAttr is an optional argument to Cumprod. +type CumprodAttr func(optionalAttr) + +// CumprodExclusive sets the optional exclusive attribute to value. +// +// value: If `True`, perform exclusive cumprod. +// If not specified, defaults to false +func CumprodExclusive(value bool) CumprodAttr { + return func(m optionalAttr) { + m["exclusive"] = value + } +} + +// CumprodReverse sets the optional reverse attribute to value. +// +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumprodReverse(value bool) CumprodAttr { + return func(m optionalAttr) { + m["reverse"] = value + } +} + +// Compute the cumulative product of the tensor `x` along `axis`. +// +// By default, this op performs an inclusive cumprod, which means that the first +// element of the input is identical to the first element of the output: +// +// ```python +// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] +// ``` +// +// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +// performed instead: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +// ``` +// +// By setting the `reverse` kwarg to `True`, the cumprod is performed in the +// opposite direction: +// +// ```python +// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +// ``` +// +// This is more efficient than using separate `tf.reverse` ops. +// +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +// ``` +// +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). 
Must be in the range +// `[-rank(x), rank(x))`. +func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Cumprod", + Input: []tf.Input{ + x, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug. +type LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugConfig(value string) LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Load SGD embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the stochastic gradient descent optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the Adadelta optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug(scope *Scope, parameters tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", + Input: []tf.Input{ + parameters, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // Creates a Dataset that returns pseudorandom numbers. // // Creates a Dataset that returns a stream of uniformly distributed @@ -43784,8 +43987,6 @@ type LoadTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) // LoadTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. 
// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) LoadTPUEmbeddingStochasticGradientDescentParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -43840,6 +44041,83 @@ func LoadTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, parameter return scope.AddOperation(opspec) } +// Counts the number of occurrences of each value in an integer array. +// +// Outputs a vector with length `size` and the same dtype as `weights`. If +// `weights` are empty, then index `i` stores the number of times the value `i` is +// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of +// the value in `weights` at each index where the corresponding value in `arr` is +// `i`. +// +// Values in `arr` outside of the range [0, size) are ignored. +// +// Arguments: +// arr: int32 `Tensor`. +// size: non-negative int32 scalar `Tensor`. +// weights: is an int32, int64, float32, or float64 `Tensor` with the same +// shape as `arr`, or a length-0 `Tensor`, in which case it acts as all weights +// equal to 1. +// +// Returns 1D `Tensor` with length equal to `size`. The counts or summed weights for +// each value in the range [0, size). +func Bincount(scope *Scope, arr tf.Output, size tf.Output, weights tf.Output) (bins tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Bincount", + Input: []tf.Input{ + arr, size, weights, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Gradients for batch normalization. +// +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// +// This op is deprecated. See `tf.nn.batch_normalization`. +// +// Arguments: +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this Tensor will be multiplied +// with the normalized Tensor. +// backprop: 4D backprop Tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +// +// Returns: +// dx: 4D backprop tensor for input. +// dm: 1D backprop tensor for mean. +// dv: 1D backprop tensor for variance. +// db: 1D backprop tensor for beta. +// dg: 1D backprop tensor for gamma. +func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + opspec := tf.OpSpec{ + Type: "BatchNormWithGlobalNormalizationGrad", + Input: []tf.Input{ + t, m, v, gamma, backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + // Strip leading and trailing whitespaces from the Tensor. 
// // Arguments: @@ -44036,8 +44314,6 @@ type LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr func(optionalAttr) // LoadTPUEmbeddingAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingAdagradParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value @@ -44178,8 +44454,6 @@ type LoadTPUEmbeddingADAMParametersAttr func(optionalAttr) // LoadTPUEmbeddingADAMParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingADAMParametersTableId(value int64) LoadTPUEmbeddingADAMParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -44386,8 +44660,6 @@ type LoadTPUEmbeddingAdagradParametersAttr func(optionalAttr) // LoadTPUEmbeddingAdagradParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingAdagradParametersTableId(value int64) LoadTPUEmbeddingAdagradParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -44745,91 +45017,6 @@ func ImageProjectiveTransformV2(scope *Scope, images tf.Output, transforms tf.Ou return op.Output(0) } -// Computes rectified linear gradients for a Relu operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Relu operation. -// features: The features passed as input to the corresponding Relu operation, OR -// the outputs of that operation (both work equivalently). -// -// Returns `gradients * (features > 0)`. -func ReluGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReluGrad", - Input: []tf.Input{ - gradients, features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LoadTPUEmbeddingRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingRMSPropParameters. -type LoadTPUEmbeddingRMSPropParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingRMSPropParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingRMSPropParametersTableId(value int64) LoadTPUEmbeddingRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingRMSPropParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingRMSPropParametersTableName(value string) LoadTPUEmbeddingRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// LoadTPUEmbeddingRMSPropParametersConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingRMSPropParametersConfig(value string) LoadTPUEmbeddingRMSPropParametersAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Load RMSProp embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the RMSProp optimization algorithm. 
-// ms: Value of ms used in the RMSProp optimization algorithm. -// mom: Value of mom used in the RMSProp optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingRMSPropParameters(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingRMSPropParameters", - Input: []tf.Input{ - parameters, ms, mom, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // InfeedEnqueueTupleAttr is an optional argument to InfeedEnqueueTuple. type InfeedEnqueueTupleAttr func(optionalAttr) @@ -44975,6 +45162,154 @@ func InfeedEnqueue(scope *Scope, input tf.Output, optional ...InfeedEnqueueAttr) return scope.AddOperation(opspec) } +// SerializeManySparseAttr is an optional argument to SerializeManySparse. +type SerializeManySparseAttr func(optionalAttr) + +// SerializeManySparseOutType sets the optional out_type attribute to value. +// +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`. +// If not specified, defaults to DT_STRING +func SerializeManySparseOutType(value tf.DataType) SerializeManySparseAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object. +// +// The `SparseTensor` must have rank `R` greater than 1, and the first dimension +// is treated as the minibatch dimension. Elements of the `SparseTensor` +// must be sorted in increasing order of this first dimension. The serialized +// `SparseTensor` objects going into each row of `serialized_sparse` will have +// rank `R-1`. +// +// The minibatch size `N` is extracted from `sparse_shape[0]`. +// +// Arguments: +// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. +// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. +func SerializeManySparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeManySparseAttr) (serialized_sparse tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SerializeManySparse", + Input: []tf.Input{ + sparse_indices, sparse_values, sparse_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Says whether the targets are in the top `K` predictions. +// +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. 
+// +// More formally, let +// +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// +// Arguments: +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. +// +// Returns Computed precision at `k` as a `bool Tensor`. +func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InTopKV2", + Input: []tf.Input{ + predictions, targets, k, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates an Optional variant with no value. +func OptionalNone(scope *Scope) (optional tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "OptionalNone", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr is an optional argument to RetrieveTPUEmbeddingStochasticGradientDescentParameters. +type RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingStochasticGradientDescentParametersConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingStochasticGradientDescentParametersConfig(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve SGD embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the stochastic gradient descent optimization algorithm. +func RetrieveTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr) (parameters tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingStochasticGradientDescentParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Output a fact about factorials. 
func Fact(scope *Scope) (fact tf.Output) { if scope.Err() != nil { @@ -45792,205 +46127,6 @@ func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Out return scope.AddOperation(opspec) } -// Returns an element-wise indication of the sign of a number. -// -// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`. -// -// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. -// -// Example usage: -// >>> tf.math.sign([0., 2., -3.]) -// -func Sign(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sign", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyAddSignAttr is an optional argument to ResourceApplyAddSign. -type ResourceApplyAddSignAttr func(optionalAttr) - -// ResourceApplyAddSignUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and m tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAddSignUseLocking(value bool) ResourceApplyAddSignAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the AddSign update. -// -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// update <- (alpha + sign_decay * sign(g) *sign(m)) * g -// variable <- variable - lr_t * update -// -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// alpha: Must be a scalar. -// sign_decay: Must be a scalar. -// beta: Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, alpha tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyAddSignAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAddSign", - Input: []tf.Input{ - var_, m, lr, alpha, sign_decay, beta, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Concatenates tensors along one dimension. -// -// Arguments: -// values: List of `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// axis: 0-D. The dimension along which to concatenate. Must be in the -// range [-rank(values), rank(values)). -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConcatV2", - Input: []tf.Input{ - tf.OutputList(values), axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorListConcatAttr is an optional argument to TensorListConcat. -type TensorListConcatAttr func(optionalAttr) - -// TensorListConcatElementShape sets the optional element_shape attribute to value. 
-// If not specified, defaults to {unknown_rank:true} -func TensorListConcatElementShape(value tf.Shape) TensorListConcatAttr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// Concats all tensors in the list along the 0th dimension. -// -// Requires that all tensors have the same shape except the first dimension. -// -// input_handle: The input list. -// tensor: The concated result. -// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. -// -func TensorListConcat(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListConcatAttr) (tensor tf.Output, lengths tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorListConcat", - Input: []tf.Input{ - input_handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingRMSPropParametersGradAccumDebug. -type LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) - -// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// LoadTPUEmbeddingRMSPropParametersGradAccumDebugConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingRMSPropParametersGradAccumDebugConfig(value string) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Load RMSProp embedding parameters with debug support. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the RMSProp optimization algorithm. -// ms: Value of ms used in the RMSProp optimization algorithm. -// mom: Value of mom used in the RMSProp optimization algorithm. -// gradient_accumulators: Value of gradient_accumulators used in the RMSProp optimization algorithm. -// -// -// -// Returns the created operation. 
-func LoadTPUEmbeddingRMSPropParametersGradAccumDebug(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingRMSPropParametersGradAccumDebug", - Input: []tf.Input{ - parameters, ms, mom, gradient_accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // EqualAttr is an optional argument to Equal. type EqualAttr func(optionalAttr) @@ -46283,147 +46419,59 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms return scope.AddOperation(opspec) } -// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. -type Conv3DBackpropInputAttr func(optionalAttr) +// RetrieveTPUEmbeddingFTRLParametersAttr is an optional argument to RetrieveTPUEmbeddingFTRLParameters. +type RetrieveTPUEmbeddingFTRLParametersAttr func(optionalAttr) -// Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} -func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { +// RetrieveTPUEmbeddingFTRLParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func RetrieveTPUEmbeddingFTRLParametersTableId(value int64) RetrieveTPUEmbeddingFTRLParametersAttr { return func(m optionalAttr) { - m["dilations"] = value + m["table_id"] = value } } -// Computes the gradients of 3-D convolution with respect to the input. +// RetrieveTPUEmbeddingFTRLParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingFTRLParametersTableName(value string) RetrieveTPUEmbeddingFTRLParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingFTRLParametersConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingFTRLParametersConfig(value string) RetrieveTPUEmbeddingFTRLParametersAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve FTRL embedding parameters. // -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. 
-func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { +// Returns: +// parameters: Parameter parameters updated by the FTRL optimization algorithm. +// accumulators: Parameter accumulators updated by the FTRL optimization algorithm. +// linears: Parameter linears updated by the FTRL optimization algorithm. +func RetrieveTPUEmbeddingFTRLParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropInput", - Input: []tf.Input{ - input, filter, out_backprop, - }, + Type: "RetrieveTPUEmbeddingFTRLParameters", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. -type DepthwiseConv2dNativeAttr func(optionalAttr) - -// DepthwiseConv2dNativeDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthwiseConv2dNativeDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} -func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. -// -// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -// and a filter / kernel tensor of shape -// `[filter_height, filter_width, in_channels, channel_multiplier]`, containing -// `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies -// a different filter to each input channel (expanding from 1 channel to -// `channel_multiplier` channels for each), then concatenates the results -// together. Thus, the output has `in_channels * channel_multiplier` channels. -// -// ``` -// for k in 0..in_channels-1 -// for q in 0..channel_multiplier-1 -// output[b, i, j, k * channel_multiplier + q] = -// sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * -// filter[di, dj, k, q] -// ``` -// -// Must have `strides[0] = strides[3] = 1`. For the most common case of the same -// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. -// -// Arguments: -// -// -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. 
-// padding: The type of padding algorithm to use. -func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNative", - Input: []tf.Input{ - input, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates an all-zeros CSRSparseMatrix with shape `dense_shape`. -// -// Arguments: -// dense_shape: The desired matrix shape. -// -// -// Returns An empty CSR matrix with shape `dense_shape`. -func SparseMatrixZeros(scope *Scope, dense_shape tf.Output, type_ tf.DataType) (sparse_matrix tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"type": type_} - opspec := tf.OpSpec{ - Type: "SparseMatrixZeros", - Input: []tf.Input{ - dense_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } // EnqueueTPUEmbeddingIntegerBatchAttr is an optional argument to EnqueueTPUEmbeddingIntegerBatch. @@ -46578,8 +46626,6 @@ type LoadTPUEmbeddingFTRLParametersAttr func(optionalAttr) // LoadTPUEmbeddingFTRLParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingFTRLParametersTableId(value int64) LoadTPUEmbeddingFTRLParametersAttr { return func(m optionalAttr) { m["table_id"] = value @@ -46636,6 +46682,149 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula return scope.AddOperation(opspec) } +// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. +type Conv3DBackpropInputAttr func(optionalAttr) + +// Conv3DBackpropInputDilations sets the optional dilations attribute to value. +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of 3-D convolution with respect to the input. +// +// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv3DBackpropInput", + Input: []tf.Input{ + input, filter, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. 
+type DepthwiseConv2dNativeAttr func(optionalAttr) + +// DepthwiseConv2dNativeDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthwiseConv2dNativeDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to {i:1 i:1 i:1 i:1} +func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. +// +// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +// and a filter / kernel tensor of shape +// `[filter_height, filter_width, in_channels, channel_multiplier]`, containing +// `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies +// a different filter to each input channel (expanding from 1 channel to +// `channel_multiplier` channels for each), then concatenates the results +// together. Thus, the output has `in_channels * channel_multiplier` channels. +// +// ``` +// for k in 0..in_channels-1 +// for q in 0..channel_multiplier-1 +// output[b, i, j, k * channel_multiplier + q] = +// sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * +// filter[di, dj, k, q] +// ``` +// +// Must have `strides[0] = strides[3] = 1`. For the most common case of the same +// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// +// Arguments: +// +// +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. +// padding: The type of padding algorithm to use. +func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DepthwiseConv2dNative", + Input: []tf.Input{ + input, filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates an all-zeros CSRSparseMatrix with shape `dense_shape`. +// +// Arguments: +// dense_shape: The desired matrix shape. +// +// +// Returns An empty CSR matrix with shape `dense_shape`. +func SparseMatrixZeros(scope *Scope, dense_shape tf.Output, type_ tf.DataType) (sparse_matrix tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"type": type_} + opspec := tf.OpSpec{ + Type: "SparseMatrixZeros", + Input: []tf.Input{ + dense_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the minimum along segments of a tensor. 
// // Read @@ -46778,83 +46967,6 @@ func ResourceApplyAdagradV2(scope *Scope, var_ tf.Output, accum tf.Output, lr tf return scope.AddOperation(opspec) } -// Counts the number of occurrences of each value in an integer array. -// -// Outputs a vector with length `size` and the same dtype as `weights`. If -// `weights` are empty, then index `i` stores the number of times the value `i` is -// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of -// the value in `weights` at each index where the corresponding value in `arr` is -// `i`. -// -// Values in `arr` outside of the range [0, size) are ignored. -// -// Arguments: -// arr: int32 `Tensor`. -// size: non-negative int32 scalar `Tensor`. -// weights: is an int32, int64, float32, or float64 `Tensor` with the same -// shape as `arr`, or a length-0 `Tensor`, in which case it acts as all weights -// equal to 1. -// -// Returns 1D `Tensor` with length equal to `size`. The counts or summed weights for -// each value in the range [0, size). -func Bincount(scope *Scope, arr tf.Output, size tf.Output, weights tf.Output) (bins tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Bincount", - Input: []tf.Input{ - arr, size, weights, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Gradients for batch normalization. -// -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() -// -// This op is deprecated. See `tf.nn.batch_normalization`. -// -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this Tensor will be multiplied -// with the normalized Tensor. -// backprop: 4D backprop Tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -// -// Returns: -// dx: 4D backprop tensor for input. -// dm: 1D backprop tensor for mean. -// dv: 1D backprop tensor for variance. -// db: 1D backprop tensor for beta. -// dg: 1D backprop tensor for gamma. -func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} - opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalizationGrad", - Input: []tf.Input{ - t, m, v, gamma, backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - // Returns the number of gradients aggregated in the given accumulators. // // Arguments: @@ -46911,70 +47023,6 @@ func TPUReplicatedOutput(scope *Scope, input tf.Output, num_replicas int64) (out return outputs } -// LoadTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to LoadTPUEmbeddingMDLAdagradLightParameters. 
-type LoadTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingMDLAdagradLightParametersTableId(value int64) LoadTPUEmbeddingMDLAdagradLightParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingMDLAdagradLightParametersTableName(value string) LoadTPUEmbeddingMDLAdagradLightParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// LoadTPUEmbeddingMDLAdagradLightParametersConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingMDLAdagradLightParametersConfig(value string) LoadTPUEmbeddingMDLAdagradLightParametersAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Load MDL Adagrad Light embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the MDL Adagrad Light optimization algorithm. -// accumulators: Value of accumulators used in the MDL Adagrad Light optimization algorithm. -// weights: Value of weights used in the MDL Adagrad Light optimization algorithm. -// benefits: Value of benefits used in the MDL Adagrad Light optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingMDLAdagradLightParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMDLAdagradLightParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingMDLAdagradLightParameters", - Input: []tf.Input{ - parameters, accumulators, weights, benefits, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // Converts a tensor to a scalar predicate. // // Converts a tensor to a scalar predicate with the following rules: @@ -47533,8 +47581,6 @@ type LoadTPUEmbeddingProximalAdagradParametersAttr func(optionalAttr) // LoadTPUEmbeddingProximalAdagradParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -// -// REQUIRES: value >= -1 func LoadTPUEmbeddingProximalAdagradParametersTableId(value int64) LoadTPUEmbeddingProximalAdagradParametersAttr { return func(m optionalAttr) { m["table_id"] = value From 66e8fea58bdd2d142daa26e28b3907a1dd31fa99 Mon Sep 17 00:00:00 2001 From: Robert David Date: Thu, 19 Mar 2020 15:51:12 -0700 Subject: [PATCH 267/492] Remove obsolete comment. 
PiperOrigin-RevId: 301912244 Change-Id: I56f4ec4e208d89ecc8502ecd3627018884269460 --- tensorflow/lite/kernels/activations.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index 7e65fbb5306..bf3dec21099 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -128,7 +128,6 @@ void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input, #if __aarch64__ && __clang__ namespace { // Looks up each element of in , returns them in a vector. -// idx_offset must be a int8x16_t vector containing 64 in each lane. inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4], uint8x16_t indices) { // Look up in 1st quarter of the table: top 2 bits of indices == 00 From 9426596938d98e0ce8941863db815ffd0d5db5c8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 19 Mar 2020 16:20:42 -0700 Subject: [PATCH 268/492] [TF:MLIR] Implement optimal layout assignment for Conv2DBackpropFilter PiperOrigin-RevId: 301917898 Change-Id: I0925b1e99d20bd725bb863bb4c7498e8d0565e9c --- .../mlir/tensorflow/ir/tf_generated_ops.td | 10 +++- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 48 +++++++++++++++++++ .../compiler/mlir/tensorflow/ir/tf_ops.td | 2 +- ...imization_layout_assignment_gpu_cc_60.mlir | 20 ++++++++ ...imization_layout_assignment_gpu_cc_70.mlir | 40 ++++++++++++++++ ...ptimization_layout_assignment_to_nchw.mlir | 42 +++++++++++++++- ...ptimization_layout_assignment_to_nhwc.mlir | 8 ++-- 7 files changed, 163 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 3592fa62a25..c4ba3f2ce9c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1099,7 +1099,7 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. }]; } -def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect]> { +def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect, TF_LayoutSensitiveInterface]> { let summary = [{ Computes the gradients of convolution with respect to the filter. }]; @@ -1125,6 +1125,14 @@ Computes the gradients of convolution with respect to the filter. 
); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let extraClassDeclaration = [{ + // TF_LayoutSensitiveInterface: + SmallVector GetLayoutDependentArgs() { return {0, 2}; } + SmallVector GetLayoutDependentResults() { return {}; } + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; } def TF_Conv2DBackpropInputOp : TF_Op<"Conv2DBackpropInput", [NoSideEffect]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 3e40d81bbf5..d7afc3c9c86 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -1099,6 +1099,54 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { return "NCHW"; } +//===----------------------------------------------------------------------===// +// Conv2dBackpropFilterOp +//===----------------------------------------------------------------------===// + +LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { + StringRef src_data_format = this->data_format(); + + auto perm = GetDataFormatPermutation(src_data_format, data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. + setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + // Permute filter sizes operand. + OpBuilder builder(getOperation()); + auto data_format_permute = builder.create( + getLoc(), filter_sizes(), StringAttr::get(src_data_format, getContext()), + StringAttr::get(data_format, getContext())); + setOperand(1, data_format_permute); + + return success(); +} + +StringRef Conv2DBackpropFilterOp::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Input must be a tensor. + auto input_ty = input().getType().dyn_cast(); + if (!input_ty) return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + const bool is_f16 = input_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // Otherwise always use "NCHW". 
+ return "NCHW"; +} + //===----------------------------------------------------------------------===// // Conv2dBackpropInputOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index b609a071975..e4feae42eda 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -87,7 +87,7 @@ def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect]> { let hasFolder = 1; } -def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect]> { +def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Permute input tensor from `src_format` to `dst_format`"; let description = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir index 3786a26d114..83338a95a05 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir @@ -22,4 +22,24 @@ func @transposeConv2D_3x3_f16(%input: tensor<1x28x28x64xf16>, %filter: tensor<3x return %0 : tensor<1x28x28x64xf16> } +// CHECK-LABEL: func @transposeConv2DBackpropFilter_f16 +func @transposeConv2DBackpropFilter_f16( + %input: tensor<1x28x28x64xf16>, + %filter_size: tensor<4xi32>, + %out_backprop: tensor<1x28x28x64xf16> +) -> tensor<1x1x64x64xf16> { + + // CHECK: "tf.Conv2DBackpropFilter" + // CHECK-SAME: data_format = "NCHW" + %0 = "tf.Conv2DBackpropFilter"(%input, %filter_size, %out_backprop) + { + data_format = "NHWC", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<1x28x28x64xf16>, tensor<4xi32>, tensor<1x28x28x64xf16>) + -> tensor<1x1x64x64xf16> + + return %0 : tensor<1x1x64x64xf16> +} + } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir index 0b2588c38cc..3d7cb1affa8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir @@ -63,4 +63,44 @@ func @transposeConv2D_3x3_f16(%input: tensor<1x64x28x28xf16>, %filter: tensor<3x return %0 : tensor<1x64x28x28xf16> } +// CHECK-LABEL: func @transposeConv2DBackpropFilter_f32 +func @transposeConv2DBackpropFilter_f32( + %input: tensor<1x28x28x64xf32>, + %filter_size: tensor<4xi32>, + %out_backprop: tensor<1x28x28x64xf32> +) -> tensor<1x1x64x64xf32> { + + // CHECK: "tf.Conv2DBackpropFilter" + // CHECK-SAME: data_format = "NCHW" + %0 = "tf.Conv2DBackpropFilter"(%input, %filter_size, %out_backprop) + { + data_format = "NHWC", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<1x28x28x64xf32>, tensor<4xi32>, tensor<1x28x28x64xf32>) + -> tensor<1x1x64x64xf32> + + return %0 : tensor<1x1x64x64xf32> +} + +// CHECK-LABEL: func @transposeConv2DBackpropFilter_f16 +func @transposeConv2DBackpropFilter_f16( + %input: tensor<1x64x28x28xf16>, + %filter_size: tensor<4xi32>, + %out_backprop: tensor<1x64x28x28xf16> +) -> tensor<1x1x64x64xf16> { + + // CHECK: "tf.Conv2DBackpropFilter" + // CHECK-SAME: data_format = "NHWC" + %0 = 
"tf.Conv2DBackpropFilter"(%input, %filter_size, %out_backprop) + { + data_format = "NCHW", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<1x64x28x28xf16>, tensor<4xi32>, tensor<1x64x28x28xf16>) + -> tensor<1x1x64x64xf16> + + return %0 : tensor<1x1x64x64xf16> +} + } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir index b66289ae34b..099b97ff0de 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir @@ -64,4 +64,44 @@ func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: ten } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> -} \ No newline at end of file +} + +// CHECK-LABEL: func @transposeConv2DBackpropFilter +func @transposeConv2DBackpropFilter( + %input: tensor<1x32x32x3xf32>, + %filter_sizes: tensor<4xi32>, + %out_backprop: tensor<1x32x32x8xf32> +) -> tensor<1x1x3x8xf32> { + + // CHECK: %[[FILTER_PERM:[0-9]*]] = "tf.DataFormatVecPermute" + // CHECK-SAME: dst_format = "NCHW" + // CHECK-SAME: src_format = "NHWC" + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[IN_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + // CHECK: %[[OUT_BP_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg2, %[[ARG_PERM]]) + + // CHECK: %[[CONV2D_BACKPROP:[0-9]*]] = "tf.Conv2DBackpropFilter" + // CHECK-SAME: (%[[IN_TRANSPOSE]], %[[FILTER_PERM]], %[[OUT_BP_TRANSPOSE]]) + // CHECK-SAME: data_format = "NCHW" + // CHECK-SAME: dilations = [1, 4, 2, 3] + // CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6] + // CHECK-SAME: padding = "EXPLICIT" + // CHECK-SAME: strides = [5, 8, 6, 7] + // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<4xi32>, tensor<1x8x32x32xf32>) + // CHECK-SAME: -> tensor<1x1x3x8xf32> + + // CHECK: return %[[CONV2D_BACKPROP]] + + %0 = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) + { + data_format = "NHWC", + dilations = [1, 2, 3, 4], + explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8], + padding = "EXPLICIT", + strides = [5, 6, 7, 8] + } : (tensor<1x32x32x3xf32>, tensor<4xi32>, tensor<1x32x32x8xf32>) + -> tensor<1x1x3x8xf32> + + return %0 : tensor<1x1x3x8xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir index 0ed7b833158..e27448e1d0f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir @@ -1,12 +1,12 @@ // RUN: tf-opt %s -tf-layout-assignment=force-data-format=NHWC -verify-diagnostics | FileCheck %s --dump-input=always +// IMPORTANT: Tensor shapes do not match convolution parameters (stride, +// dilations, etc...). This test only verifies that changing convolution data +// layout will update all the attributes. + // CHECK-LABEL: func @transposeConv2D func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> { - // IMPORTANT: Tensor shapes do not match convolution parameters (stride, - // dilations, etc...). 
This test only verifies that changing convolution data - // layout will update all the attributes. - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) From 9147d7d2a5a5eb76cea501d65d0095006fc66196 Mon Sep 17 00:00:00 2001 From: Nat Jeffries Date: Thu, 19 Mar 2020 16:23:26 -0700 Subject: [PATCH 269/492] Increase static op data array length to account for new models. PiperOrigin-RevId: 301918391 Change-Id: I29e35cc255abcc809afc516b637ff14bf66a9806 --- .../lite/micro/kernels/xtensa_hifimini/fully_connected.cc | 2 +- tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc index 91761b00c2a..bbbd0fdb496 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc @@ -147,7 +147,7 @@ constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; // This size will work for both the hotword (5) and ambient music (2): -constexpr int kMaxOpDataSize = 5; +constexpr int kMaxOpDataSize = 7; static int kStaticOpDataCounter = 0; static OpData kStaticOpData[kMaxOpDataSize]; diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc index 0859c54cce4..caa567726c4 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc @@ -113,7 +113,10 @@ struct OpData { int scale_multiplier = 0; }; -static OpData kStaticOpData; +// This size will work for both the hotword (1) and ambient music (1): +constexpr int kMaxOpDataSize = 2; +static int kStaticOpDataCounter = 0; +static OpData kStaticOpData[kMaxOpDataSize]; void* Init(TfLiteContext* context, const char* buffer, size_t length) { return nullptr; @@ -126,7 +129,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. - OpData* op_data = &kStaticOpData; + OpData* op_data = &kStaticOpData[kStaticOpDataCounter++]; node->user_data = op_data; op_data->scale_multiplier = From def3e4484a86a35f495f23cfa6fbb00e3a48cc9e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 19 Mar 2020 16:24:49 -0700 Subject: [PATCH 270/492] [TF:MLIR] Implement optimal layout assignment for Conv2DBackpropInput PiperOrigin-RevId: 301918632 Change-Id: Ia205a4927fdd8abc43f6aaaa857980b596631c82 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 10 +++- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 52 +++++++++++++++++-- ...imization_layout_assignment_gpu_cc_60.mlir | 20 +++++++ ...imization_layout_assignment_gpu_cc_70.mlir | 40 ++++++++++++++ ...ptimization_layout_assignment_to_nchw.mlir | 41 +++++++++++++++ 5 files changed, 158 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index c4ba3f2ce9c..e5bda71323e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1135,7 +1135,7 @@ Computes the gradients of convolution with respect to the filter. 
}]; } -def TF_Conv2DBackpropInputOp : TF_Op<"Conv2DBackpropInput", [NoSideEffect]> { +def TF_Conv2DBackpropInputOp : TF_Op<"Conv2DBackpropInput", [NoSideEffect, TF_LayoutSensitiveInterface]> { let summary = [{ Computes the gradients of convolution with respect to the input. }]; @@ -1165,6 +1165,14 @@ Computes the gradients of convolution with respect to the input. let verifier = [{ return Verify(*this); }]; + + let extraClassDeclaration = [{ + // TF_LayoutSensitiveInterface: + SmallVector GetLayoutDependentArgs() { return {2}; } + SmallVector GetLayoutDependentResults() { return {0}; } + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; } def TF_Conv3DOp : TF_Op<"Conv3D", [NoSideEffect]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index d7afc3c9c86..c2a94a8efe5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -1119,10 +1119,10 @@ LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { // Permute filter sizes operand. OpBuilder builder(getOperation()); - auto data_format_permute = builder.create( + auto filter_sizes_permuted = builder.create( getLoc(), filter_sizes(), StringAttr::get(src_data_format, getContext()), StringAttr::get(data_format, getContext())); - setOperand(1, data_format_permute); + setOperand(1, filter_sizes_permuted); return success(); } @@ -1148,7 +1148,7 @@ StringRef Conv2DBackpropFilterOp::GetOptimalLayout( } //===----------------------------------------------------------------------===// -// Conv2dBackpropInputOp +// Conv2DBackpropInputOp //===----------------------------------------------------------------------===// static LogicalResult Verify(Conv2DBackpropInputOp op) { @@ -1166,7 +1166,51 @@ static LogicalResult Verify(Conv2DBackpropInputOp op) { } return success(); -} // namespace TF +} + +LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { + StringRef src_data_format = this->data_format(); + + auto perm = GetDataFormatPermutation(src_data_format, data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. + setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + // Permute input sizes operand. + OpBuilder builder(getOperation()); + auto input_sizes_permuted = builder.create( + getLoc(), input_sizes(), StringAttr::get(src_data_format, getContext()), + StringAttr::get(data_format, getContext())); + setOperand(0, input_sizes_permuted); + + return success(); +} + +StringRef Conv2DBackpropInputOp::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Filter must be a tensor. + auto filter_ty = filter().getType().dyn_cast(); + if (!filter_ty) return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. 
+ const bool is_f16 = filter_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // Otherwise always use "NCHW". + return "NCHW"; +} //===----------------------------------------------------------------------===// // DataFormatVecPermuteOp diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir index 83338a95a05..3839b000f3a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir @@ -42,4 +42,24 @@ func @transposeConv2DBackpropFilter_f16( return %0 : tensor<1x1x64x64xf16> } +// CHECK-LABEL: func @transposeConv2DBackpropInput_f16 +func @transposeConv2DBackpropInput_f16( + %input_size: tensor<4xi32>, + %filter: tensor<1x28x28x64xf16>, + %out_backprop: tensor<1x28x28x64xf16> +) -> tensor<1x28x28x64xf16> { + + // CHECK: "tf.Conv2DBackpropInput" + // CHECK-SAME: data_format = "NCHW" + %0 = "tf.Conv2DBackpropInput"(%input_size, %filter, %out_backprop) + { + data_format = "NHWC", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<4xi32>, tensor<1x28x28x64xf16>, tensor<1x28x28x64xf16>) + -> tensor<1x28x28x64xf16> + + return %0 : tensor<1x28x28x64xf16> +} + } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir index 3d7cb1affa8..b52ef1c4f4a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir @@ -103,4 +103,44 @@ func @transposeConv2DBackpropFilter_f16( return %0 : tensor<1x1x64x64xf16> } +// CHECK-LABEL: func @transposeConv2DBackpropInput_f32 +func @transposeConv2DBackpropInput_f32( + %input_size: tensor<4xi32>, + %filter: tensor<1x28x28x64xf32>, + %out_backprop: tensor<1x28x28x64xf32> +) -> tensor<1x28x28x64xf32> { + + // CHECK: "tf.Conv2DBackpropInput" + // CHECK-SAME: data_format = "NCHW" + %0 = "tf.Conv2DBackpropInput"(%input_size, %filter, %out_backprop) + { + data_format = "NHWC", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<4xi32>, tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>) + -> tensor<1x28x28x64xf32> + + return %0 : tensor<1x28x28x64xf32> +} + +// CHECK-LABEL: func @transposeConv2DBackpropInput_f16 +func @transposeConv2DBackpropInput_f16( + %input_size: tensor<4xi32>, + %filter: tensor<1x64x28x28xf16>, + %out_backprop: tensor<1x64x28x28xf16> +) -> tensor<1x64x28x28xf16> { + + // CHECK: "tf.Conv2DBackpropInput" + // CHECK-SAME: data_format = "NHWC" + %0 = "tf.Conv2DBackpropInput"(%input_size, %filter, %out_backprop) + { + data_format = "NCHW", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<4xi32>, tensor<1x64x28x28xf16>, tensor<1x64x28x28xf16>) + -> tensor<1x64x28x28xf16> + + return %0 : tensor<1x64x28x28xf16> +} + } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir index 099b97ff0de..22be6537adb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir @@ -105,3 +105,44 @@ func @transposeConv2DBackpropFilter( return %0 : tensor<1x1x3x8xf32> } + +// CHECK-LABEL: func @transposeConv2DBackpropInput +func @transposeConv2DBackpropInput( + %input_sizes: tensor<4xi32>, + %filter: tensor<1x1x3x8xf32>, + %out_backprop: tensor<1x32x32x8xf32> +) -> tensor<1x32x32x3xf32> { + + // CHECK: %[[INPUT_PERM:[0-9]*]] = "tf.DataFormatVecPermute" + // CHECK-SAME: dst_format = "NCHW" + // CHECK-SAME: src_format = "NHWC" + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[OUT_BP_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg2, %[[ARG_PERM]]) + + // CHECK: %[[CONV2D_BACKPROP:[0-9]*]] = "tf.Conv2DBackpropInput" + // CHECK-SAME: (%[[INPUT_PERM]], %arg1, %[[OUT_BP_TRANSPOSE]]) + // CHECK-SAME: data_format = "NCHW" + // CHECK-SAME: dilations = [1, 4, 2, 3] + // CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6] + // CHECK-SAME: padding = "EXPLICIT" + // CHECK-SAME: strides = [5, 8, 6, 7] + // CHECK-SAME: (tensor<4xi32>, tensor<1x1x3x8xf32>, tensor<1x8x32x32xf32>) + // CHECK-SAME: -> tensor<1x3x32x32xf32> + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D_BACKPROP]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + %0 = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) + { + data_format = "NHWC", + dilations = [1, 2, 3, 4], + explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8], + padding = "EXPLICIT", + strides = [5, 6, 7, 8] + } : (tensor<4xi32>, tensor<1x1x3x8xf32>, tensor<1x32x32x8xf32>) + -> tensor<1x32x32x3xf32> + + return %0 : tensor<1x32x32x3xf32> +} From 92f2e1c31515799382a5a9f2ba0aaa58af6bd534 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Thu, 19 Mar 2020 16:27:56 -0700 Subject: [PATCH 271/492] Fix yet another seed forgotten in 301407257 PiperOrigin-RevId: 301919131 Change-Id: I4021932777af023efaaa41aa878b0941ff2ce36a --- tensorflow/compiler/tests/stateless_random_ops_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index a56c9206861..14b062e5cba 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -70,7 +70,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_uniform( shape=[1000], seed=seed_t, maxval=maxval, dtype=dtype) - y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) self.assertTrue(np.all(y >= 0)) self.assertTrue(np.all(y < maxval)) From dab07567118ff04e7865edb3ca3f1835e6e6bb87 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 19 Mar 2020 16:31:04 -0700 Subject: [PATCH 272/492] Fix bug in lift_to_graph. 
PiperOrigin-RevId: 301919704 Change-Id: I35c08facb9523101681a57587d207bd10d0f4441 --- tensorflow/python/eager/lift_to_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/eager/lift_to_graph.py b/tensorflow/python/eager/lift_to_graph.py index 9d4c0068685..12a3c58d095 100644 --- a/tensorflow/python/eager/lift_to_graph.py +++ b/tensorflow/python/eager/lift_to_graph.py @@ -276,7 +276,7 @@ def lift_to_graph(tensors, for inp in op_selector.graph_inputs(op): # Don't lift the TPUReplicateMetadata nodes out of the function, because # it has no registered kernels. - if inp.name == "TPUReplicateMetadata": + if inp.type == "TPUReplicateMetadata": continue unvisited_ops.add(inp) if (all(x in marked_ops for x in op_outputs[inp]) and @@ -351,7 +351,7 @@ def lift_to_graph(tensors, for mutation in control_mutations: # Don't lift the TPUReplicateMetadata nodes out of the function, because # it has no registered kernels. - if mutation.old_graph_op.name == "TPUReplicateMetadata": + if mutation.old_graph_op.type == "TPUReplicateMetadata": continue mutation.copied_op._add_control_input(op_map[mutation.old_graph_op]) # pylint: enable=protected-access From 5c66fbfd388d398124d39ac9826968fbbfc1b3d2 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Thu, 19 Mar 2020 16:44:14 -0700 Subject: [PATCH 273/492] Fix unnecessary outer loop. PiperOrigin-RevId: 301922075 Change-Id: I495260bf821ce97ca37862a2c9efddaebde86027 --- .../compiler/mlir/lite/converter_gen.cc | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 8ecff8757b7..db2b924278f 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -442,22 +442,20 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { verify_ctx.withOp("top"); for (int i = 0, e = op.getNumOperands(); i < e; ++i) { - for (int i = 0, e = op.getNumOperands(); i < e; ++i) { - auto &value = op.getOperand(i); - // Skip from from first variadic operands for now. Else getOperand index - // used below doesn't match. - if (value.isVariadic()) break; - if (!value.name.empty()) - verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i)); - } - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - auto &value = op.getResult(i); - // Skip from from first variadic results for now. Else getResult index - // used below doesn't match. - if (value.isVariadic()) break; - if (!value.name.empty()) - verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i)); - } + auto &value = op.getOperand(i); + // Skip from from first variadic operands for now. Else getOperand index + // used below doesn't match. + if (value.isVariadic()) break; + if (!value.name.empty()) + verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i)); + } + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + auto &value = op.getResult(i); + // Skip from from first variadic results for now. Else getResult index + // used below doesn't match. 
+ if (value.isVariadic()) break; + if (!value.name.empty()) + verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i)); } GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(), "operand"); From 12aa08b32bda44575a5d5807ebcc59fe5d1a897b Mon Sep 17 00:00:00 2001 From: Artem Belevich Date: Thu, 19 Mar 2020 16:55:26 -0700 Subject: [PATCH 274/492] Ifdef out cusolver wrapper functions disabled on windows. PiperOrigin-RevId: 301924035 Change-Id: I9842e63ea04f1e69129e1667947d953c11a38f1d --- tensorflow/stream_executor/cuda/cusparse_10_1.inc | 4 ++++ tensorflow/stream_executor/cuda/cusparse_10_2.inc | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tensorflow/stream_executor/cuda/cusparse_10_1.inc b/tensorflow/stream_executor/cuda/cusparse_10_1.inc index c63300697fe..3b7f3815829 100644 --- a/tensorflow/stream_executor/cuda/cusparse_10_1.inc +++ b/tensorflow/stream_executor/cuda/cusparse_10_1.inc @@ -7780,6 +7780,8 @@ cusparseStatus_t CUSPARSEAPI cusparseCsr2cscEx2_bufferSize( bufferSize); } +#if !defined(_WIN32) + cusparseStatus_t CUSPARSEAPI cusparseCreateSpVec(cusparseSpVecDescr_t *spVecDescr, int64_t size, int64_t nnz, void *indices, void *values, cusparseIndexType_t idxType, @@ -8223,4 +8225,6 @@ cusparseStatus_t CUSPARSEAPI cusparseConstrainedGeMM_bufferSize( bufferSize); } +#endif // _WIN32 + } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cusparse_10_2.inc b/tensorflow/stream_executor/cuda/cusparse_10_2.inc index c63300697fe..3b7f3815829 100644 --- a/tensorflow/stream_executor/cuda/cusparse_10_2.inc +++ b/tensorflow/stream_executor/cuda/cusparse_10_2.inc @@ -7780,6 +7780,8 @@ cusparseStatus_t CUSPARSEAPI cusparseCsr2cscEx2_bufferSize( bufferSize); } +#if !defined(_WIN32) + cusparseStatus_t CUSPARSEAPI cusparseCreateSpVec(cusparseSpVecDescr_t *spVecDescr, int64_t size, int64_t nnz, void *indices, void *values, cusparseIndexType_t idxType, @@ -8223,4 +8225,6 @@ cusparseStatus_t CUSPARSEAPI cusparseConstrainedGeMM_bufferSize( bufferSize); } +#endif // _WIN32 + } // extern "C" From 5c4931bbf69e0f006f210c6382a234e83dd4dc8e Mon Sep 17 00:00:00 2001 From: Kathy Ruan Date: Thu, 19 Mar 2020 16:56:44 -0700 Subject: [PATCH 275/492] Preserve original order of checkpoint layer dependencies by making an OrderedDict instead of sorting the dictionary. PiperOrigin-RevId: 301924263 Change-Id: I90582a85a73a260ee55bdc83400bb6d939f031ee --- tensorflow/python/keras/engine/network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 91cd1b77734..95e581a39dc 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -393,7 +393,7 @@ class Network(base_layer.Layer): weight_layer_index = 0 - dependencies = {} + dependencies = collections.OrderedDict() for layer_index, layer in enumerate(self.layers): try: if layer.weights: @@ -416,7 +416,7 @@ class Network(base_layer.Layer): def _checkpoint_dependencies(self): dependencies = [ trackable.TrackableReference(name=name, ref=layer) - for name, layer in sorted(self._layer_checkpoint_dependencies.items())] + for name, layer in self._layer_checkpoint_dependencies.items()] dependencies.extend(super(Network, self)._checkpoint_dependencies) return dependencies From bd0b515922ea8a480bdaacfc948488bcc365ec04 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 19 Mar 2020 17:38:34 -0700 Subject: [PATCH 276/492] Remove the TENSORFLOW_MEM_DEBUG compilation flag from the path that passes TF op name etc. to BFCAllocator, i.e. enable the passing by default. PiperOrigin-RevId: 301930881 Change-Id: Ib0005391dbfe873f043f90598897e46c0b999ea3 --- tensorflow/core/common_runtime/bfc_allocator.cc | 2 ++ .../core/common_runtime/eager/eager_operation.cc | 2 ++ .../core/common_runtime/eager/eager_operation.h | 2 ++ tensorflow/core/framework/allocator.cc | 2 ++ tensorflow/core/framework/allocator.h | 14 ++++++++++++-- 5 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index df2bec93f0c..1100ba9684c 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -460,7 +460,9 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name, ",bytes_available=", bytes_available, ",peak_bytes_in_use=", stats.peak_bytes_in_use, ",requested_bytes=", requested_bytes, +#ifdef TENSORFLOW_MEM_DEBUG ",tf_op=", pending_op_name, ",id=", pending_step_id, +#endif "#"); }, traceme_level); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 7c4d04646a7..94b85a190c1 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -57,7 +57,9 @@ Status EagerOperation::Reset( cancellation_manager_ = nullptr; executor_ = executor ? executor : &ctx_.Executor(); remote_func_params_ = remote_func_params; +#ifdef TENSORFLOW_MEM_DEBUG op_name_ = op; +#endif return SetDeviceName(raw_device_name, true); } diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 3e3474d6b61..4b46fc5c709 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -121,8 +121,10 @@ class EagerOperation { return remote_func_params_; } +#ifdef TENSORFLOW_MEM_DEBUG const char* op_name() const { return op_name_; } const char* op_name_ = nullptr; +#endif Status MaybeInferSingleInputAttrs(TensorHandle* handle); Status InferInputListAttrs(int num_inputs); diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 7224aa8051b..6757a9b593e 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -27,8 +27,10 @@ limitations under the License. namespace tensorflow { +#ifdef TENSORFLOW_MEM_DEBUG thread_local const char* pending_op_name = nullptr; thread_local uint64 pending_step_id = 0; +#endif string AllocatorStats::DebugString() const { return strings::Printf( diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 609fe716180..2e239a4d6de 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -62,8 +62,9 @@ struct AllocationAttributes { TF_DISALLOW_COPY_AND_ASSIGN(AllocationAttributes); }; -// The runtime will cache Op names in thread-local memory and some allocators -// will try to tag allocations with the requesting Op. +// If defined, the runtime will cache Op names in thread-local memory +// and some allocators will try to tag allocations with the requesting Op. 
+#ifdef TENSORFLOW_MEM_DEBUG extern thread_local const char* pending_op_name; extern thread_local uint64 pending_step_id; #define MEMDEBUG_CACHE_OP(N) \ @@ -75,6 +76,15 @@ extern thread_local uint64 pending_step_id; pending_step_id = (N); \ } while (0) #define MEMDEBUG_CACHE_VAL pending_op_name +#else +#define MEMDEBUG_CACHE_OP(N) \ + do { \ + } while (0) +#define MEMDEBUG_CACHE_STEPID(N) \ + do { \ + } while (0) +#define MEMDEBUG_CACHE_VAL nullptr +#endif // Runtime statistics collected by an allocator. Exactly the same as // stream_executor::AllocatorStats, but independently defined to preserve the From 2528d18d1f5e2bdb844df5dfbebfa3338304e5ba Mon Sep 17 00:00:00 2001 From: Robert David Date: Thu, 19 Mar 2020 17:49:59 -0700 Subject: [PATCH 277/492] Remove unused function. PiperOrigin-RevId: 301932345 Change-Id: I24b7858c3efc9f380a58f01a6a68b7fb6c729f7f --- .../kernels/internal/optimized/optimized_ops.h | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 09122686db5..c943f7e989c 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -875,21 +875,6 @@ inline void ShuffledFullyConnected( #ifdef USE_NEON -inline float32x4_t DivideSumForMeanImpl( - const float32x4_t sum, const float32x4_t num_elements_reverse, - const bool ordinary_mean, const float32x4_t scale_dup, - const float32x4_t zero_point_with_bias_dup) { - const float32x4_t val = vmulq_f32(sum, num_elements_reverse); - if (!ordinary_mean) { -#ifdef ARM_FEATURE_FMA - return vfmaq_f32(zero_point_with_bias_dup, scale_dup, val); -#else - return vmlaq_f32(zero_point_with_bias_dup, scale_dup, val); -#endif // ARM_FEATURE_FMA - } - return val; -} - inline int32x4_t RoundToNearest(const float32x4_t input) { #if defined(__aarch64__) || defined(__SSSE3__) // Note: vcvtnq_s32_f32 is not available in ARMv7 From 30701c7344f5a1bbcc3cd955164201984c6451dd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 17:56:58 -0700 Subject: [PATCH 278/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301933256 Change-Id: Ifec774fbeda583daee52e347ce7820cc2e08ad5c --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 4f552e456e5..b8b73bc472d 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 22607193745e0145260d1fef07e4fd384d9f78b6 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Thu, 19 Mar 2020 17:59:17 -0700 Subject: [PATCH 279/492] Revert : Skip testNonMatchingVariableCreation. 
PiperOrigin-RevId: 301933594 Change-Id: Ifa25900442f291d25ef997fba09e61eb5aa1aafb --- tensorflow/python/distribute/mirrored_variable_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py index 0777bf3b42a..6623422b45f 100644 --- a/tensorflow/python/distribute/mirrored_variable_test.py +++ b/tensorflow/python/distribute/mirrored_variable_test.py @@ -459,7 +459,6 @@ class MirroredVariableCreationTest(test.TestCase): aggregation="invalid") def testNonMatchingVariableCreation(self, distribution): - self.skipTest("b/123075960") def model_fn(name): v = variable_scope.variable(1.0, name=name) @@ -467,7 +466,7 @@ class MirroredVariableCreationTest(test.TestCase): return v with distribution.scope(): - names = values.DistributedValues(("foo", "bar")) + names = values.PerReplica(("foo", "bar")) with self.assertRaises(RuntimeError): _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) From 08ab10f127bf7ef566c3b9930e729a6ba4cdd596 Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Thu, 19 Mar 2020 18:02:53 -0700 Subject: [PATCH 280/492] Fix formatting of lite/tools/pip_package/REDAME.md PiperOrigin-RevId: 301934203 Change-Id: I9146b9d9afca2f95576edbb2588dff8d8759e79e --- tensorflow/lite/tools/pip_package/README.md | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/tools/pip_package/README.md b/tensorflow/lite/tools/pip_package/README.md index 88906ee2bb0..849bbf57813 100644 --- a/tensorflow/lite/tools/pip_package/README.md +++ b/tensorflow/lite/tools/pip_package/README.md @@ -6,44 +6,51 @@ Python without requiring the rest of TensorFlow. ## Steps To build a binary wheel run this script: -``` + +```sh sudo apt install swig libjpeg-dev zlib1g-dev python3-dev python3-numpy sh tensorflow/lite/tools/pip_package/build_pip_package.sh ``` That will print out some output and a .whl file. You can then install that -``` + +```sh pip install --upgrade ``` You can also build a wheel inside docker container using make tool. For example the following command will cross-compile tflite-runtime package for python2.7 and python3.7 (from Debian Buster) on Raspberry Pi: -``` + +```sh make BASE_IMAGE=debian:buster PYTHON=python TENSORFLOW_TARGET=rpi docker-build make BASE_IMAGE=debian:buster PYTHON=python3 TENSORFLOW_TARGET=rpi docker-build ``` Another option is to cross-compile for python3.5 (from Debian Stretch) on ARM64 board: -``` + +```sh make BASE_IMAGE=debian:stretch PYTHON=python3 TENSORFLOW_TARGET=aarch64 docker-build ``` To build for python3.6 (from Ubuntu 18.04) on x86_64 (native to the docker image) run: -``` + +```sh make BASE_IMAGE=ubuntu:18.04 PYTHON=python3 TENSORFLOW_TARGET=native docker-build ``` In addition to the wheel there is a way to build Debian package by adding BUILD_DEB=y to the make command (only for python3): -``` + +```sh make BASE_IMAGE=debian:buster PYTHON=python3 TENSORFLOW_TARGET=rpi BUILD_DEB=y docker-build ``` Note, unlike tensorflow this will be installed to a tflite_runtime namespace. You can then use the Tensorflow Lite interpreter as. 
-``` + +```python from tflite_runtime.interpreter import Interpreter interpreter = Interpreter(model_path="foo.tflite") ``` From 40129bc3fa0fb21fc6a687201c1af850ba36400a Mon Sep 17 00:00:00 2001 From: Taehee Jeong Date: Thu, 19 Mar 2020 18:15:15 -0700 Subject: [PATCH 281/492] Upgrade protobuf to 3.9.2 PiperOrigin-RevId: 301935849 Change-Id: I63be177fc6601e4c68498e06bcba15fee8c445f1 --- tensorflow/tools/pip_package/setup.py | 2 +- tensorflow/workspace.bzl | 21 +++++---------- third_party/protobuf/protobuf.patch | 37 +++++++++++++-------------- 3 files changed, 26 insertions(+), 34 deletions(-) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 64a4469e0da..3cb24bf8999 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -60,7 +60,7 @@ REQUIRED_PACKAGES = [ 'keras_preprocessing >= 1.1.0', 'numpy >= 1.16.0, < 2.0', 'opt_einsum >= 2.3.2', - 'protobuf >= 3.8.0', + 'protobuf >= 3.9.2', 'tensorboard >= 2.1.0, < 2.2.0', 'tensorflow_estimator >= 2.1.0, < 2.2.0', 'termcolor >= 1.1.0', diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 8c4c8473faa..1066479823a 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -471,26 +471,19 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): }, ) - # 310ba5ee72661c081129eb878c1bbcec936b20f0 is based on 3.8.0 with a fix for protobuf.bzl. - PROTOBUF_URLS = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/protocolbuffers/protobuf/archive/310ba5ee72661c081129eb878c1bbcec936b20f0.tar.gz", - "https://github.com/protocolbuffers/protobuf/archive/310ba5ee72661c081129eb878c1bbcec936b20f0.tar.gz", - ] - PROTOBUF_SHA256 = "b9e92f9af8819bbbc514e2902aec860415b70209f31dfc8c4fa72515a5df9d59" - PROTOBUF_STRIP_PREFIX = "protobuf-310ba5ee72661c081129eb878c1bbcec936b20f0" - - PROTOBUF_PATCH = "//third_party/protobuf:protobuf.patch" - tf_http_archive( name = "com_google_protobuf", - patch_file = clean_dep(PROTOBUF_PATCH), - sha256 = PROTOBUF_SHA256, - strip_prefix = PROTOBUF_STRIP_PREFIX, + patch_file = clean_dep("//third_party/protobuf:protobuf.patch"), + sha256 = "cfcba2df10feec52a84208693937c17a4b5df7775e1635c1e3baffc487b24c9b", + strip_prefix = "protobuf-3.9.2", system_build_file = clean_dep("//third_party/systemlibs:protobuf.BUILD"), system_link_files = { "//third_party/systemlibs:protobuf.bzl": "protobuf.bzl", }, - urls = PROTOBUF_URLS, + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/protocolbuffers/protobuf/archive/v3.9.2.zip", + "https://github.com/protocolbuffers/protobuf/archive/v3.9.2.zip", + ], ) tf_http_archive( diff --git a/third_party/protobuf/protobuf.patch b/third_party/protobuf/protobuf.patch index decd92e9d03..8ce4a843759 100644 --- a/third_party/protobuf/protobuf.patch +++ b/third_party/protobuf/protobuf.patch @@ -1,17 +1,25 @@ diff --git a/BUILD b/BUILD -index 2fb26050..c2744d5b 100644 +index dbae719ff..87dc38470 100644 --- a/BUILD +++ b/BUILD -@@ -19,7 +19,7 @@ config_setting( +@@ -23,7 +23,7 @@ config_setting( # ZLIB configuration ################################################################################ - + -ZLIB_DEPS = ["@zlib//:zlib"] +ZLIB_DEPS = ["@zlib"] - + ################################################################################ # Protobuf Runtime Library -@@ -209,6 +209,7 @@ cc_library( +@@ -143,6 +143,7 @@ cc_library( + copts = COPTS, + includes = ["src/"], + linkopts = LINK_OPTS, ++ alwayslink = 1, + visibility = ["//visibility:public"], + ) + +@@ 
-213,6 +214,7 @@ cc_library( copts = COPTS, includes = ["src/"], linkopts = LINK_OPTS, @@ -19,26 +27,17 @@ index 2fb26050..c2744d5b 100644 visibility = ["//visibility:public"], deps = [":protobuf_lite"] + PROTOBUF_DEPS, ) -@@ -219,7 +220,7 @@ cc_library( - # TODO(keveman): Remove this target once the support gets added to Bazel. - cc_library( - name = "protobuf_headers", -- hdrs = glob(["src/**/*.h"]), -+ hdrs = glob(["src/**/*.h", "src/**/*.inc"]), - includes = ["src/"], - visibility = ["//visibility:public"], - ) - diff --git a/protobuf.bzl b/protobuf.bzl -index e0653321f..4ac23594b 100644 +index e0653321f..253d9cbb5 100644 --- a/protobuf.bzl +++ b/protobuf.bzl -@@ -85,6 +85,8 @@ def _proto_gen_impl(ctx): +@@ -84,7 +84,9 @@ def _proto_gen_impl(ctx): + for dep in ctx.attr.deps: import_flags += dep.proto.import_flags deps += dep.proto.deps + import_flags = depset(import_flags).to_list() + deps = depset(deps).to_list() - + if not ctx.attr.gen_cc and not ctx.attr.gen_py and not ctx.executable.plugin: - return struct( + return struct( \ No newline at end of file From c826dad7f49869eef62777c4ca386ee3f988fe70 Mon Sep 17 00:00:00 2001 From: Ir1d Date: Fri, 20 Mar 2020 09:26:19 +0800 Subject: [PATCH 282/492] use correct indent for multiline doctest --- tensorflow/python/keras/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 45102fa2cfb..428f1ff7acf 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -1480,7 +1480,7 @@ def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): Example: >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3), - low=0.0, high=1.0) + ... low=0.0, high=1.0) >>> kvar @@ -1515,7 +1515,7 @@ def random_normal_variable(shape, mean, scale, dtype=None, name=None, Example: >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3), - mean=0.0, scale=1.0) + ... mean=0.0, scale=1.0) >>> kvar @@ -5691,7 +5691,7 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): Example: >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3), - minval=0.0, maxval=1.0) + ... minval=0.0, maxval=1.0) >>> random_uniform_tensor @@ -5724,7 +5724,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None): Example: >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3), - p=0.5) + ... p=0.5) >>> random_binomial_tensor From b28cbd90877456ea3f3407a475d40301ce1a9251 Mon Sep 17 00:00:00 2001 From: Ir1d Date: Fri, 20 Mar 2020 09:27:36 +0800 Subject: [PATCH 283/492] use correct indent for multiline doctest --- tensorflow/python/keras/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 428f1ff7acf..9859cc66184 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -5659,7 +5659,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): Example: >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3), - mean=0.0, stddev=1.0) + ... mean=0.0, stddev=1.0) >>> random_normal_tensor From c75cfa96c825042a676ef3403cb286d3136c6048 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Thu, 19 Mar 2020 18:34:53 -0700 Subject: [PATCH 284/492] A pass to fuse multiple hlo xla ops as one quant region We use table gen to define the source pattern to match the op sequence. 
Then a native code call is used to create a quant region with a specific kernel name. PiperOrigin-RevId: 301938424 Change-Id: I1649488220bc6b7e8ed7b63be11afd4fc098169c --- .../compiler/mlir/lite/quantization/xla/BUILD | 26 +++ .../quantization/xla/cpu_kernel_fusion.cc | 210 ++++++++++++++++++ .../quantization/xla/cpu_kernel_fusion.td | 24 ++ .../xla/tests/cpu_kernel_fusion.mlir | 16 ++ 4 files changed, 276 insertions(+) create mode 100644 tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc create mode 100644 tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td create mode 100644 tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD index 2c5bed86a84..36c897d5fec 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD @@ -1,3 +1,8 @@ +load( + "//third_party/mlir:tblgen.bzl", + "gentbl", +) + package( default_visibility = [ ":friends", @@ -18,6 +23,8 @@ package_group( cc_library( name = "hlo_xla_quantization_passes", srcs = [ + "cpu_kernel_fusion.cc", + "generated_cpu_kernel_fusion.inc", "materialize.cc", "op_quant_spec.inc", "propagate.cc", @@ -36,6 +43,8 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -65,3 +74,20 @@ cc_library( "@llvm-project//mlir:Transforms", ], ) + +gentbl( + name = "cpu_kernel_fusion_inc_gen", + tbl_outs = [ + ( + "-gen-rewriters", + "generated_cpu_kernel_fusion.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "cpu_kernel_fusion.td", + td_srcs = [ + "@llvm-project//mlir:StdOpsTdFiles", + "//tensorflow/compiler/mlir/xla:hlo_ops_td_files", + "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc new file mode 100644 index 00000000000..4ca5692584f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc @@ -0,0 +1,210 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/xla/client/lib/quantize.h" + +#define DEBUG_TYPE "quant-kernel-fusion" + +namespace mlir { +namespace xla_hlo { + +namespace { + +// Collects input values from outside for 'ops'. +void CollectInputs(llvm::ArrayRef ops, + llvm::SmallVectorImpl* inputs, + llvm::SmallVectorImpl* input_specs) { + for (auto* op : ops) { + for (auto operand : op->getOperands()) { + if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) { + continue; + } + if (auto* def_op = operand.getDefiningOp()) { + if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) { + inputs->push_back(operand); + } + } else { // argument value + inputs->push_back(operand); + } + } + } + + for (auto input : *inputs) { + ShapedType input_type = input.getType().cast(); + // TODO(fengliuai): detect whether it is from fake quant. + input_specs->push_back(TypeAttr::get(input_type.getElementType())); + } +} + +// Collects values that are produced by 'ops' and have use outside of 'ops'. +// TODO(fengliuai): if it is a single user and QDQ, write that to the specs. +void CollectRets(llvm::ArrayRef ops, + llvm::SmallVectorImpl* rets, + llvm::SmallVectorImpl* ret_types, + llvm::SmallVectorImpl* ret_specs) { + for (auto* op : ops) { + for (auto result : op->getResults()) { + for (auto* user : result.getUsers()) { + // If there are any user outside of 'ops' + if (std::find(ops.begin(), ops.end(), user) == ops.end()) { + ShapedType ret_type = result.getType().cast(); + rets->push_back(result); + ret_types->push_back(ret_type); + // TODO(fengliuai): detect whether it is used by fake quant. + ret_specs->push_back(TypeAttr::get(ret_type.getElementType())); + break; + } + } + } + } +} + +llvm::SmallVector fuseOps(PatternRewriter* rewriter, + const std::initializer_list& results, + StringRef kernel) { + // Collect all the operations to be fused. + llvm::SmallVector fused; + llvm::SmallVector locs; + fused.reserve(results.size()); + locs.reserve(results.size()); + for (auto value : results) { + Operation* op = value.getDefiningOp(); + fused.push_back(op); + locs.push_back(op->getLoc()); + } + + // Collect inputs from outside to 'ops'. 
+ llvm::SmallVector inputs; + llvm::SmallVector input_specs; + CollectInputs(fused, &inputs, &input_specs); + + // Collect outputs from 'ops' to outside. + llvm::SmallVector rets; + llvm::SmallVector ret_types; + llvm::SmallVector ret_specs; + CollectRets(fused, &rets, &ret_types, &ret_specs); + + // Create the region op with the return. + auto region = rewriter->create( + rewriter->getFusedLoc(locs), ret_types, inputs, + rewriter->getArrayAttr(input_specs), rewriter->getArrayAttr(ret_specs), + kernel); + auto* body = new Block(); + region.body().push_back(body); + + OpBuilder builder(body); + BlockAndValueMapping mapping; + + // Make block arguments and add it to the block value mapping. + for (Value input : inputs) { + mapping.map(input, body->addArgument(input.getType())); + } + + // Clone the operations 'ops' to the region. + for (Operation* op : fused) { + builder.clone(*op, mapping); + } + + llvm::SmallVector new_rets; + new_rets.reserve(rets.size()); + for (auto ret : llvm::enumerate(rets)) { + Value new_ret = mapping.lookupOrNull(ret.value()); + assert(new_ret && "couldn't find return value."); + new_rets.push_back(new_ret); + ret.value().replaceAllUsesWith(region.getResult(ret.index())); + } + builder.create(builder.getUnknownLoc(), new_rets); + + LLVM_DEBUG({ + assert(region.verify().Success && "failed to create quant region."); + llvm::dbgs() << "\ncreated region: "; + region.print(llvm::dbgs()); + llvm::dbgs() << "\n\n\n"; + }); + + SmallVector new_values(fused.back()->getNumResults()); + return new_values; +} + +struct CpuKernelFusionPass : public FunctionPass { + explicit CpuKernelFusionPass() = default; + CpuKernelFusionPass(const CpuKernelFusionPass&) {} + + void runOnFunction() override; + + private: + LogicalResult fuseCpuKernels(Operation* op); +}; + +#include "tensorflow/compiler/mlir/lite/quantization/xla/generated_cpu_kernel_fusion.inc" + +LogicalResult CpuKernelFusionPass::fuseCpuKernels(Operation* op) { + MLIRContext* ctx = op->getContext(); + OwningRewritePatternList patterns; + populateWithGenerated(ctx, &patterns); + + ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addLegalOp(); + return applyPartialConversion(op, target, patterns); +} + +void CpuKernelFusionPass::runOnFunction() { + if (failed(fuseCpuKernels(getFunction()))) signalPassFailure(); +} + +} // namespace + +// Creates an instance of the xla_hlo cpu kernel fusion pass. +std::unique_ptr> CreateCpuKernelFusionPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "xla-hlo-cpu-fusion", "Fuse xla hlo ops into cpu kernels"); + +} // namespace xla_hlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td new file mode 100644 index 00000000000..59a188792ab --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td @@ -0,0 +1,24 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" + +class Fused2Ops : NativeCodeCall< + "fuseOps(&$_builder, {$0, $1}, \"" # kernel # "\")">; +class Fused3Ops : NativeCodeCall< + "fuseOps(&$_builder, {$0, $1, $2}, \"" # kernel # "\")">; + +def : Pat<(HLO_AddOp:$add (HLO_MulOp:$mul $_, $_, $_), $_, $_), + (Fused2Ops<"generic.mul_add"> $mul, $add)>; diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir new file mode 100644 index 00000000000..3ca989b715c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir @@ -0,0 +1,16 @@ +// RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s + +// CHECK-LABEL: @mul_add +func @mul_add(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> + +// CHECK: %[[region:.*]] = "quant.region"(%arg0, %arg1, %arg2) ( { +// CHECK: ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors +// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> +// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> +// CHECK: "quant.return"(%[[add]]) : (tensor<4xf32>) -> () +// CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[region]] : tensor<4xf32> +} From b04c4e0e4338924d5281626445594a900bd673a6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 20:07:45 -0700 Subject: [PATCH 285/492] Flush denormals to +/- 0 when converting float to bfloat16. PiperOrigin-RevId: 301948798 Change-Id: Ic24b699b2e23683d3710d7abb4317833df252af0 --- tensorflow/core/framework/bfloat16_test.cc | 40 +++++++++++++++++++--- tensorflow/core/lib/bfloat16/bfloat16.h | 8 +++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc index bce18c74b49..8e780251498 100644 --- a/tensorflow/core/framework/bfloat16_test.cc +++ b/tensorflow/core/framework/bfloat16_test.cc @@ -23,6 +23,35 @@ limitations under the License. 
namespace tensorflow { namespace { +TEST(Bfloat16Test, ZeroRepresentations) { + ASSERT_EQ(bfloat16{0.0f}, bfloat16{0.0f}); + ASSERT_EQ(bfloat16{-0.0f}, bfloat16{0.0f}); + ASSERT_EQ(bfloat16{-0.0f}, bfloat16{-0.0f}); + ASSERT_EQ(bfloat16{0.0f}.value, 0x0000); + ASSERT_EQ(bfloat16{-0.0f}.value, 0x8000); +} + +TEST(Bfloat16Test, FlushDenormalsToZero) { + for (float denorm = -std::numeric_limits::denorm_min(); + denorm < std::numeric_limits::denorm_min(); + denorm = std::nextafterf(denorm, 1.0f)) { + bfloat16 bf_trunc = bfloat16::truncate_to_bfloat16(denorm); + ASSERT_EQ(float{bf_trunc}, 0.0f); + if (std::signbit(denorm)) { + ASSERT_EQ(bf_trunc.value, 0x8000) << denorm; + } else { + ASSERT_EQ(bf_trunc.value, 0x0000) << denorm; + } + bfloat16 bf_round = bfloat16::round_to_bfloat16(denorm); + ASSERT_EQ(float{bf_round}, 0.0f); + if (std::signbit(denorm)) { + ASSERT_EQ(bf_round.value, 0x8000) << denorm; + } else { + ASSERT_EQ(bf_round.value, 0x0000) << denorm; + } + } +} + TEST(Bfloat16Test, DefaultValueIsZero) { EXPECT_EQ(0.0f, static_cast(bfloat16())); } @@ -65,6 +94,7 @@ TEST_P(Bfloat16Test, TruncateTest) { EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated))); return; } + EXPECT_EQ(GetParam().expected_truncation, float(truncated)); bfloat16 rounded = bfloat16::round_to_bfloat16((GetParam().input)); @@ -114,14 +144,16 @@ INSTANTIATE_TEST_SUITE_P( BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000), BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000), BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + // The following two floats are denormals and will be flushed + // to zero. Bfloat16TestParam{ BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000), - BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000), - BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)}, + BinaryToFloat(0, 0b00000000, 0b0000000, 0b0000000000000000), + BinaryToFloat(0, 0b00000000, 0b0000000, 0b0000000000000000)}, Bfloat16TestParam{ BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000), - BinaryToFloat(0, 0b00000000, 0b1111111, 0b0000000000000000), - BinaryToFloat(0, 0b00000001, 0b0000000, 0b0000000000000000)})); + BinaryToFloat(0, 0b00000000, 0b0000000, 0b0000000000000000), + BinaryToFloat(0, 0b00000000, 0b0000000, 0b0000000000000000)})); TEST(Bfloat16Test, Conversion) { float a[100]; diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h index a25f4d947ed..89850ed7ed1 100644 --- a/tensorflow/core/lib/bfloat16/bfloat16.h +++ b/tensorflow/core/lib/bfloat16/bfloat16.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/core/platform/byte_order.h" @@ -53,6 +54,10 @@ struct bfloat16 { if (float_isnan(v)) { output.value = NAN_VALUE; return output; + } else if (std::fabs(v) < std::numeric_limits::min()) { + // Flush denormal to +/- 0. + output.value = std::signbit(v) ? 0x8000 : 0; + return output; } const uint16_t* p = reinterpret_cast(&v); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ @@ -196,6 +201,9 @@ struct bfloat16 { // qNaN magic: All exponent bits set + most significant bit of fraction // set. output.value = 0x7fc0; + } else if (std::fabs(v) < std::numeric_limits::min()) { + // Flush denormal to +/- 0.0 + output.value = std::signbit(v) ? 0x8000 : 0; } else { // Fast rounding algorithm that rounds a half value to nearest even. This // reduces expected error when we convert a large number of floats. 
Here From 3931d39379b9feb44d4f8edba0906e96629d6884 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Thu, 19 Mar 2020 20:34:36 -0700 Subject: [PATCH 286/492] Fix crash in Model.fit() if a gradient is None. The crash would occur with distributed strategy if multiple devices were used. PiperOrigin-RevId: 301951658 Change-Id: I2e596599bec19f1caa7cf41bd70771a7e6f7541d --- .../distribute/distribute_strategy_test.py | 23 +++++++++++++++++++ .../python/keras/optimizer_v2/optimizer_v2.py | 6 +++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 4ca3cf2b142..c696ff18f93 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -796,6 +796,29 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, atol=1e-4, rtol=1e-4) + @combinations.generate(all_strategy_combinations_plus_run_distributed()) + def test_gradient_is_none(self, distribution): + + if not context.executing_eagerly(): + self.skipTest('None gradients are not supported in graph mode') + + class DenseWithExtraWeight(keras.layers.Dense): + + def build(self, input_shape): + super(DenseWithExtraWeight, self).build(input_shape) + # Gradient w.r.t. extra_weight is None + self.extra_weight = self.add_weight('extra_weight', shape=(), + initializer='ones') + + with self.cached_session(): + with distribution.scope(): + model = keras.Sequential([DenseWithExtraWeight(4, input_shape=(4,))]) + model.compile('adam', 'mse') + + inputs = np.zeros((64, 4), dtype='float32') + targets = np.zeros((64, 4), dtype='float32') + model.fit(inputs, targets) + class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 2a4d4cf86e8..4e42b313a94 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -508,9 +508,10 @@ class OptimizerV2(trackable.Trackable): grads_and_vars: List of (gradient, variable) pairs. Returns: - A list of all-reduced gradients. + A list of all-reduced gradients. Any gradients which were None are + removed. """ - grads_and_vars = list(grads_and_vars) + grads_and_vars = _filter_grads(grads_and_vars) def all_reduce_fn(distribution, grads_and_vars): return distribution.extended.batch_reduce_to( ds_reduce_util.ReduceOp.SUM, grads_and_vars) @@ -520,6 +521,7 @@ class OptimizerV2(trackable.Trackable): # TODO(b/150507409): Do not switch to a cross-replica context once the bug # is fixed. if grads_and_vars: + # TODO(reedwm): Should we return the None gradients as well? return distribute_ctx.get_replica_context().merge_call( all_reduce_fn, args=(grads_and_vars,)) From cdf7d9830c40a3e51d82e29bff495fb3c7b5e899 Mon Sep 17 00:00:00 2001 From: Cesar Crusius Date: Thu, 19 Mar 2020 21:12:38 -0700 Subject: [PATCH 287/492] Intenral Copybara change. 
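Note on the Model.fit() fix in PATCH 286 above: the change to `_aggregate_gradients` amounts to dropping (gradient, variable) pairs whose gradient is None before they are handed to the cross-replica all-reduce. A minimal, framework-free sketch of that filtering step follows; `filter_grads` is a simplified stand-in for Keras' internal `_filter_grads` helper, not its exact implementation.

```python
# Simplified stand-in for Keras' _filter_grads: drop pairs whose gradient is
# None so the distributed all-reduce never sees them.
def filter_grads(grads_and_vars):
    filtered = [(g, v) for g, v in grads_and_vars if g is not None]
    if not filtered:
        raise ValueError("No gradients provided for any variable.")
    return filtered

# A variable like the test's 'extra_weight' never receives a gradient:
grads_and_vars = [(0.5, "kernel"), (None, "extra_weight"), (1.25, "bias")]
print(filter_grads(grads_and_vars))  # [(0.5, 'kernel'), (1.25, 'bias')]
```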
PiperOrigin-RevId: 301955874 Change-Id: I104e6217a7598ccc4d049d385fb992bfbd8114fc --- tensorflow/core/kernels/eigen_contraction_kernel.cc | 10 ++++++++-- tensorflow/core/platform/strcat.h | 2 +- tensorflow/lite/delegates/gpu/gl/kernels/resize.cc | 2 +- tensorflow/lite/delegates/gpu/metal/kernels/resize.cc | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/eigen_contraction_kernel.cc b/tensorflow/core/kernels/eigen_contraction_kernel.cc index aa6cb4b9cb9..4959651569c 100644 --- a/tensorflow/core/kernels/eigen_contraction_kernel.cc +++ b/tensorflow/core/kernels/eigen_contraction_kernel.cc @@ -28,7 +28,9 @@ limitations under the License. // the configuration through the environment variable. // // Example: -// bazel test --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test +// bazel test \ +// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ +// //path/to:test #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) @@ -37,7 +39,11 @@ namespace internal { // TODO(ezhulenev): This is a temporary workaround for disabling custom kernels // at runtime in tests. We should always rely on compile time flags for that. -// Example: ... --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test +// +// Example: +// bazel test \ +// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ +// //path/to:test EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels() { static bool use_custom_contraction_kernel = true; diff --git a/tensorflow/core/platform/strcat.h b/tensorflow/core/platform/strcat.h index 6b435dceca3..640355c9ea5 100644 --- a/tensorflow/core/platform/strcat.h +++ b/tensorflow/core/platform/strcat.h @@ -33,7 +33,7 @@ limitations under the License. // to your function, your callers will automatically convert bools, integers, // and floating point values to strings for you. // -// NOTE: Use of AlphaNum outside of the //strings package is unsupported except +// NOTE: Use of AlphaNum outside of the "strings" package is unsupported except // for the specific case of function parameters of type "AlphaNum" or "const // AlphaNum &". In particular, instantiating AlphaNum directly as a stack // variable is not supported. 
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc index b8949e41426..33d59518987 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc @@ -93,7 +93,7 @@ class Resize : public NodeShader { st.xy = max(icoord_floor, ivec2(0, 0)); st.zw = min(icoord_floor + ivec2(1, 1), borders); - vec2 t = coord - coord_floor; //interpolating factors + vec2 t = coord - coord_floor; // interpolating factors vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$; vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc index 2ed75ad65b1..24d7bcf13bc 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc @@ -54,7 +54,7 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) { int4 st; st.xy = max(itex_coord_floor, int2(0, 0)); st.zw = min(itex_coord_floor + int2(1, 1), borders); - const float2 t = tex_coord - tex_coord_floor; //interpolating factors + const float2 t = tex_coord - tex_coord_floor; // interpolating factors const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x; const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z; const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x; From f50031920375dae2f3c67afbbc8669086db20de4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 21:43:55 -0700 Subject: [PATCH 288/492] Make saved models more deterministic Nondeterministic serialized protos hurt caching. PiperOrigin-RevId: 301958908 Change-Id: I5032ec10250d5da86e709023d5cb7088ec6582bf --- tensorflow/python/eager/context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 46331461d4a..c0331d760b9 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -139,7 +139,8 @@ class FunctionCallOptions(object): @config_proto_serialized.setter def config_proto_serialized(self, config): if isinstance(config, config_pb2.ConfigProto): - self._config_proto_serialized = config.SerializeToString() + self._config_proto_serialized = config.SerializeToString( + deterministic=True) elif isinstance(config, str): self._config_proto_serialized = config elif config is None: From 308835c96a6488e8c7ec95fbb309e09072cd8799 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 21:44:19 -0700 Subject: [PATCH 289/492] Fix crash in Model.fit() if a gradient is None. The crash would occur with distributed strategy if multiple devices were used. 
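Note on PATCH 288 above: the one-line change matters because protobuf map fields (for example `ConfigProto.device_count`) serialize in unspecified order unless deterministic serialization is requested, so two identical configs can produce different bytes and defeat caching. A small sketch of the call in question, assuming the TensorFlow-generated `config_pb2` module is importable:

```python
# Hedged sketch: with deterministic=True, map fields are written in sorted key
# order, so equal protos always serialize to byte-identical strings.
from tensorflow.core.protobuf import config_pb2  # ships with TensorFlow

config = config_pb2.ConfigProto()
config.device_count["GPU"] = 1
config.device_count["CPU"] = 2

stable = config.SerializeToString(deterministic=True)  # usable as a cache key
print(stable.hex())  # identical across processes; the default serialization need not be
```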
PiperOrigin-RevId: 301958941 Change-Id: I92081f22b66f62e7749525e2c292d45262ab9ae7 --- .../distribute/distribute_strategy_test.py | 23 ------------------- .../python/keras/optimizer_v2/optimizer_v2.py | 6 ++--- 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index c696ff18f93..4ca3cf2b142 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -796,29 +796,6 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, atol=1e-4, rtol=1e-4) - @combinations.generate(all_strategy_combinations_plus_run_distributed()) - def test_gradient_is_none(self, distribution): - - if not context.executing_eagerly(): - self.skipTest('None gradients are not supported in graph mode') - - class DenseWithExtraWeight(keras.layers.Dense): - - def build(self, input_shape): - super(DenseWithExtraWeight, self).build(input_shape) - # Gradient w.r.t. extra_weight is None - self.extra_weight = self.add_weight('extra_weight', shape=(), - initializer='ones') - - with self.cached_session(): - with distribution.scope(): - model = keras.Sequential([DenseWithExtraWeight(4, input_shape=(4,))]) - model.compile('adam', 'mse') - - inputs = np.zeros((64, 4), dtype='float32') - targets = np.zeros((64, 4), dtype='float32') - model.fit(inputs, targets) - class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 4e42b313a94..2a4d4cf86e8 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -508,10 +508,9 @@ class OptimizerV2(trackable.Trackable): grads_and_vars: List of (gradient, variable) pairs. Returns: - A list of all-reduced gradients. Any gradients which were None are - removed. + A list of all-reduced gradients. """ - grads_and_vars = _filter_grads(grads_and_vars) + grads_and_vars = list(grads_and_vars) def all_reduce_fn(distribution, grads_and_vars): return distribution.extended.batch_reduce_to( ds_reduce_util.ReduceOp.SUM, grads_and_vars) @@ -521,7 +520,6 @@ class OptimizerV2(trackable.Trackable): # TODO(b/150507409): Do not switch to a cross-replica context once the bug # is fixed. if grads_and_vars: - # TODO(reedwm): Should we return the None gradients as well? return distribute_ctx.get_replica_context().merge_call( all_reduce_fn, args=(grads_and_vars,)) From 289d2827ce20b76a1b10b61380f6fbcae9577a16 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Thu, 19 Mar 2020 21:50:57 -0700 Subject: [PATCH 290/492] Always link the static filesystem registration target. 
--- tensorflow/c/experimental/filesystem/plugins/posix/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD index 49a412dfb6a..3afe114b5a6 100644 --- a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/c/experimental/filesystem:modular_filesystem", ], + alwayslink = 1, ) # Library implementing helper functionality, so that the above only contains From c48d2a48bbbb9036a9b86194a8ac6bd3160ade00 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 21:46:58 -0700 Subject: [PATCH 291/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301959179 Change-Id: Idfb2acaec4473be712f4293594d2fa8d46406812 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index b8b73bc472d..4f552e456e5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. 
-// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 61581a3b7c76aa5feea1c0db8b06064378d7cd47 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 22:30:55 -0700 Subject: [PATCH 292/492] Intenral Copybara change. PiperOrigin-RevId: 301964706 Change-Id: I39757e487623e2b728654446ec39d11f997317a3 --- tensorflow/core/kernels/eigen_contraction_kernel.cc | 10 ++-------- tensorflow/core/platform/strcat.h | 2 +- tensorflow/lite/delegates/gpu/gl/kernels/resize.cc | 2 +- tensorflow/lite/delegates/gpu/metal/kernels/resize.cc | 2 +- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/kernels/eigen_contraction_kernel.cc b/tensorflow/core/kernels/eigen_contraction_kernel.cc index 4959651569c..aa6cb4b9cb9 100644 --- a/tensorflow/core/kernels/eigen_contraction_kernel.cc +++ b/tensorflow/core/kernels/eigen_contraction_kernel.cc @@ -28,9 +28,7 @@ limitations under the License. // the configuration through the environment variable. 
// // Example: -// bazel test \ -// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ -// //path/to:test +// bazel test --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) @@ -39,11 +37,7 @@ namespace internal { // TODO(ezhulenev): This is a temporary workaround for disabling custom kernels // at runtime in tests. We should always rely on compile time flags for that. -// -// Example: -// bazel test \ -// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ -// //path/to:test +// Example: ... --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels() { static bool use_custom_contraction_kernel = true; diff --git a/tensorflow/core/platform/strcat.h b/tensorflow/core/platform/strcat.h index 640355c9ea5..6b435dceca3 100644 --- a/tensorflow/core/platform/strcat.h +++ b/tensorflow/core/platform/strcat.h @@ -33,7 +33,7 @@ limitations under the License. // to your function, your callers will automatically convert bools, integers, // and floating point values to strings for you. // -// NOTE: Use of AlphaNum outside of the "strings" package is unsupported except +// NOTE: Use of AlphaNum outside of the //strings package is unsupported except // for the specific case of function parameters of type "AlphaNum" or "const // AlphaNum &". In particular, instantiating AlphaNum directly as a stack // variable is not supported. diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc index 33d59518987..b8949e41426 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc @@ -93,7 +93,7 @@ class Resize : public NodeShader { st.xy = max(icoord_floor, ivec2(0, 0)); st.zw = min(icoord_floor + ivec2(1, 1), borders); - vec2 t = coord - coord_floor; // interpolating factors + vec2 t = coord - coord_floor; //interpolating factors vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$; vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc index 24d7bcf13bc..2ed75ad65b1 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc @@ -54,7 +54,7 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) { int4 st; st.xy = max(itex_coord_floor, int2(0, 0)); st.zw = min(itex_coord_floor + int2(1, 1), borders); - const float2 t = tex_coord - tex_coord_floor; // interpolating factors + const float2 t = tex_coord - tex_coord_floor; //interpolating factors const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x; const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z; const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x; From 356fb0d529aca082bc6d4e04c186f138a38d2bc9 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Thu, 19 Mar 2020 23:18:51 -0700 Subject: [PATCH 293/492] Utilize helper functions to parse delegate string and create delegate for evaluation-related tools. 
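Note on PATCH 293 below: it replaces per-tool `if (delegate == "nnapi") ... else if (delegate == "gpu") ...` chains with the shared `ParseStringToDelegateType` helper. That helper is C++; the Python snippet below is only a conceptual sketch of the mapping it centralizes, with plain strings standing in for the `TfliteInferenceParams` enum values.

```python
# Conceptual sketch only; the real mapping lives in the C++
# evaluation_delegate_provider, and "NONE" stands in for
# TfliteInferenceParams::NONE.
_DELEGATES = {"nnapi": "NNAPI", "gpu": "GPU",
              "hexagon": "HEXAGON", "xnnpack": "XNNPACK"}

def parse_delegate(name: str) -> str:
    return _DELEGATES.get(name, "NONE")

assert parse_delegate("hexagon") == "HEXAPACK".replace("PACK", "GON")  # "HEXAGON"
assert parse_delegate("") == "NONE"     # empty flag: run without a delegate
assert parse_delegate("foo") == "NONE"  # unknown names fall back; the tools log a warning
```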
PiperOrigin-RevId: 301969297 Change-Id: I1b66dde0cec92ec4b103c2a4336a5c133ab6a3e3 --- tensorflow/lite/tools/accuracy/ilsvrc/BUILD | 2 +- .../ilsvrc/imagenet_model_evaluator.cc | 24 ++++--------- .../ilsvrc/imagenet_model_evaluator.h | 2 +- tensorflow/lite/tools/benchmark/BUILD | 2 +- .../benchmark/xnnpack_delegate_provider.cc | 14 +++----- tensorflow/lite/tools/evaluation/stages/BUILD | 1 + .../stages/tflite_inference_stage.cc | 35 +++++-------------- .../tasks/coco_object_detection/BUILD | 1 + .../tasks/coco_object_detection/run_eval.cc | 22 ++++++------ .../tasks/imagenet_image_classification/BUILD | 1 + .../imagenet_image_classification/run_eval.cc | 26 +++++++------- .../evaluation/tasks/inference_diff/BUILD | 1 + .../tasks/inference_diff/run_eval.cc | 20 +++++------ 13 files changed, 56 insertions(+), 95 deletions(-) diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD index 9af47e2c1d7..f350914030b 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/BUILD +++ b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD @@ -24,7 +24,7 @@ cc_library( "//tensorflow/core:tflite_portable_logging", "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", - "//tensorflow/lite/tools/accuracy/ilsvrc/default:custom_delegates", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc index 6fbd18d6c2b..558ee8b1dd3 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc @@ -25,17 +25,15 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/tools/accuracy/ilsvrc/default/custom_delegates.h" #include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h" #include "tensorflow/lite/tools/evaluation/utils.h" namespace { - using tflite::evaluation::ImageLabel; -using tflite::evaluation::TfliteInferenceParams; constexpr char kNumImagesFlag[] = "num_images"; constexpr char kModelOutputLabelsFlag[] = "model_output_labels"; @@ -45,8 +43,6 @@ constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path"; constexpr char kModelFileFlag[] = "model_file"; constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads"; constexpr char kDelegateFlag[] = "delegate"; -constexpr char kNnapiDelegate[] = "nnapi"; -constexpr char kGpuDelegate[] = "gpu"; constexpr char kNumRanksFlag[] = "num_ranks"; template @@ -71,7 +67,6 @@ std::vector> Split(const std::vector& v, int n) { } return vecs; } - } // namespace namespace tensorflow { @@ -140,9 +135,10 @@ class CompositeObserver : public ImagenetModelEvaluator::Observer { tflite::Flag::CreateFlag( kInterpreterThreadsFlag, ¶ms.num_interpreter_threads, "Number of interpreter threads to use for inference."), - tflite::Flag::CreateFlag(kDelegateFlag, ¶ms.delegate, - "Delegate to use for inference, if available. " - "Must be one of {'nnapi', 'gpu'}"), + tflite::Flag::CreateFlag( + kDelegateFlag, ¶ms.delegate, + "Delegate to use for inference, if available. " + "Must be one of {'nnapi', 'gpu', 'hexagon', xnnpack'}"), tflite::Flag::CreateFlag(kNumRanksFlag, ¶ms.num_ranks, "Generates the top-1 to top-k accuracy values" "where k = num_ranks. Default: 10"), @@ -172,20 +168,14 @@ TfLiteStatus EvaluateModelForShard(const uint64_t shard_id, auto* inference_params = classification_params->mutable_inference_params(); inference_params->set_model_file_path(params.model_file_path); inference_params->set_num_threads(params.num_interpreter_threads); - if (params.delegate == kNnapiDelegate) { - inference_params->set_delegate(TfliteInferenceParams::NNAPI); - } else if (params.delegate == kGpuDelegate) { - inference_params->set_delegate(TfliteInferenceParams::GPU); - } + inference_params->set_delegate( + tflite::evaluation::ParseStringToDelegateType(params.delegate)); classification_params->mutable_topk_accuracy_eval_params()->set_k(num_ranks); tflite::evaluation::ImageClassificationStage eval(eval_config); eval.SetAllLabels(model_labels); TF_LITE_ENSURE_STATUS(eval.Init()); - TF_LITE_ENSURE_STATUS(tflite::evaluation::ApplyCustomDelegates( - params.delegate, params.num_interpreter_threads, &eval)); - for (const auto& image_label : image_labels) { eval.SetInputs(image_label.image, image_label.label); diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h index 8776a20ae33..b10b0f8f6e6 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h @@ -64,7 +64,7 @@ class ImagenetModelEvaluator { std::string blacklist_file_path; // Delegate used to perform inference (if available). - // Valid values: 'nnapi', 'gpu'. 
+ // Valid values: 'nnapi', 'gpu', 'hexagon', 'xnnpack' std::string delegate; // The maximum number of images to calculate accuracy. diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index df527c6896b..6d946b9702c 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -301,7 +301,7 @@ cc_library( ":benchmark_model_lib", ":delegate_provider_hdr", ":logging", - "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + "//tensorflow/lite/tools/evaluation:utils", ], alwayslink = 1, ) diff --git a/tensorflow/lite/tools/benchmark/xnnpack_delegate_provider.cc b/tensorflow/lite/tools/benchmark/xnnpack_delegate_provider.cc index 63270fc2cd4..8fa9e7de69a 100644 --- a/tensorflow/lite/tools/benchmark/xnnpack_delegate_provider.cc +++ b/tensorflow/lite/tools/benchmark/xnnpack_delegate_provider.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/tools/benchmark/benchmark_model.h" #include "tensorflow/lite/tools/benchmark/delegate_provider.h" #include "tensorflow/lite/tools/benchmark/logging.h" +#include "tensorflow/lite/tools/evaluation/utils.h" namespace tflite { namespace benchmark { @@ -55,17 +55,11 @@ void XnnpackDelegateProvider::LogParams(const BenchmarkParams& params) const { TfLiteDelegatePtr XnnpackDelegateProvider::CreateTfLiteDelegate( const BenchmarkParams& params) const { - TfLiteDelegatePtr delegate(nullptr, [](TfLiteDelegate*) {}); if (params.Get("use_xnnpack")) { - TfLiteXNNPackDelegateOptions options = - TfLiteXNNPackDelegateOptionsDefault(); - const auto num_threads = params.Get("num_threads"); - // Note that we don't want to use the thread pool for num_threads == 1. - options.num_threads = num_threads > 1 ? num_threads : 0; - delegate = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&options), - &TfLiteXNNPackDelegateDelete); + return evaluation::CreateXNNPACKDelegate( + params.Get("num_threads")); } - return delegate; + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); } } // namespace benchmark diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD index 7a93fec5a3d..ea3341f4e75 100644 --- a/tensorflow/lite/tools/evaluation/stages/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/BUILD @@ -111,6 +111,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/profiling:time", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc index a67397974dd..222e44c7168 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/profiling/time.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/utils.h" @@ -95,34 +96,14 @@ TfLiteStatus TfliteInferenceStage::Init() { } interpreter_->SetNumThreads(params.num_threads()); - if (params.delegate() == TfliteInferenceParams::NNAPI) { - Interpreter::TfLiteDelegatePtr delegate = CreateNNAPIDelegate(); - if (delegate) { - delegates_.push_back(std::move(delegate)); - } else { - LOG(WARNING) << "NNAPI not supported"; - } - } else if (params.delegate() == TfliteInferenceParams::GPU) { - Interpreter::TfLiteDelegatePtr delegate = CreateGPUDelegate(); - if (delegate) { - delegates_.push_back(std::move(delegate)); - } else { - LOG(WARNING) << "GPU not supported"; - } - } else if (params.delegate() == TfliteInferenceParams::HEXAGON) { - const std::string libhexagon_path("/data/local/tmp"); - Interpreter::TfLiteDelegatePtr delegate = - evaluation::CreateHexagonDelegate(libhexagon_path, false); - if (!delegate) { - // Refer to the Tensorflow Lite Hexagon delegate documentation for more - // information about how to get the required libraries. - LOG(WARNING) - << "Could not create Hexagon delegate: platform may not support " - "delegate or required libraries are missing"; - } else { - delegates_.push_back(std::move(delegate)); - } + std::string error_message; + auto delegate = CreateTfLiteDelegate(params, &error_message); + if (delegate) { + delegates_.push_back(std::move(delegate)); + } else { + LOG(WARNING) << error_message; } + for (int i = 0; i < delegates_.size(); ++i) { if (interpreter_->ModifyGraphWithDelegate(delegates_[i].get()) != kTfLiteOk) { diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD index 05bce542cd9..d8c42f9bc05 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD @@ -32,6 +32,7 @@ cc_binary( "//tensorflow/core:tflite_portable_logging", "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc index 6a61226d343..39b5082accb 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -20,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h" @@ -36,9 +38,6 @@ constexpr char kGroundTruthProtoFileFlag[] = "ground_truth_proto"; constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads"; constexpr char kDebugModeFlag[] = "debug_mode"; constexpr char kDelegateFlag[] = "delegate"; -constexpr char kNnapiDelegate[] = "nnapi"; -constexpr char kGpuDelegate[] = "gpu"; -constexpr char kHexagonDelegate[] = "hexagon"; std::string GetNameFromPath(const std::string& str) { int pos = str.find_last_of("/\\"); @@ -59,12 +58,11 @@ bool EvaluateModel(const std::string& model_file_path, auto* inference_params = detection_params->mutable_inference_params(); inference_params->set_model_file_path(model_file_path); inference_params->set_num_threads(num_interpreter_threads); - if (delegate == kNnapiDelegate) { - inference_params->set_delegate(TfliteInferenceParams::NNAPI); - } else if (delegate == kGpuDelegate) { - inference_params->set_delegate(TfliteInferenceParams::GPU); - } else if (delegate == kHexagonDelegate) { - inference_params->set_delegate(TfliteInferenceParams::HEXAGON); + inference_params->set_delegate(ParseStringToDelegateType(delegate)); + if (!delegate.empty() && + inference_params->delegate() == TfliteInferenceParams::NONE) { + LOG(WARNING) << "Unsupported TFLite delegate: " << delegate; + return false; } // Get ground truth data. @@ -167,17 +165,17 @@ int Main(int argc, char* argv[]) { std::vector model_labels; if (!ReadFileLines(model_output_labels_path, &model_labels)) { LOG(ERROR) << "Could not read model output labels file"; - return 0; + return EXIT_FAILURE; } if (!EvaluateModel(model_file_path, model_labels, image_paths, ground_truth_proto_file, delegate, output_file_path, num_interpreter_threads, debug_mode)) { LOG(ERROR) << "Could not evaluate model"; - return 0; + return EXIT_FAILURE; } - return 0; + return EXIT_SUCCESS; } } // namespace evaluation diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD index 5e2775870b5..8f6228c8857 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD @@ -25,6 +25,7 @@ cc_binary( "//tensorflow/core:tflite_portable_logging", "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc index 47a1161b2d7..5268039c500 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -19,6 +20,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h" @@ -36,9 +38,6 @@ constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path"; constexpr char kNumImagesFlag[] = "num_images"; constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads"; constexpr char kDelegateFlag[] = "delegate"; -constexpr char kNnapiDelegate[] = "nnapi"; -constexpr char kGpuDelegate[] = "gpu"; -constexpr char kHexagonDelegate[] = "hexagon"; template std::vector GetFirstN(const std::vector& v, int n) { @@ -59,12 +58,11 @@ bool EvaluateModel(const std::string& model_file_path, auto* inference_params = classification_params->mutable_inference_params(); inference_params->set_model_file_path(model_file_path); inference_params->set_num_threads(num_interpreter_threads); - if (delegate == kNnapiDelegate) { - inference_params->set_delegate(TfliteInferenceParams::NNAPI); - } else if (delegate == kGpuDelegate) { - inference_params->set_delegate(TfliteInferenceParams::GPU); - } else if (delegate == kHexagonDelegate) { - inference_params->set_delegate(TfliteInferenceParams::HEXAGON); + inference_params->set_delegate(ParseStringToDelegateType(delegate)); + if (!delegate.empty() && + inference_params->delegate() == TfliteInferenceParams::NONE) { + LOG(WARNING) << "Unsupported TFLite delegate: " << delegate; + return false; } classification_params->mutable_topk_accuracy_eval_params()->set_k(10); @@ -144,11 +142,11 @@ int Main(int argc, char* argv[]) { StripTrailingSlashes(ground_truth_images_path), &image_files)); if (!ReadFileLines(ground_truth_labels_path, &ground_truth_image_labels)) { LOG(ERROR) << "Could not read ground truth labels file"; - return 0; + return EXIT_FAILURE; } if (image_files.size() != ground_truth_image_labels.size()) { LOG(ERROR) << "Number of images and ground truth labels is not same"; - return 0; + return EXIT_FAILURE; } std::vector image_labels; image_labels.reserve(image_files.size()); @@ -166,16 +164,16 @@ int Main(int argc, char* argv[]) { std::vector model_labels; if (!ReadFileLines(model_output_labels_path, &model_labels)) { LOG(ERROR) << "Could not read model output labels file"; - return 0; + return EXIT_FAILURE; } if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate, output_file_path, num_interpreter_threads)) { LOG(ERROR) << "Could not evaluate model"; - return 0; + return EXIT_FAILURE; } - return 0; + return EXIT_SUCCESS; } } // namespace evaluation diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD index 042aa1d85e6..72a2f9c2d74 100644 --- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD @@ -23,6 +23,7 @@ cc_binary( "//tensorflow/core:tflite_portable_logging", "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", + 
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc index 13dbd89b20f..cdd83d52d6f 100644 --- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h" @@ -31,9 +32,6 @@ constexpr char kOutputFilePathFlag[] = "output_file_path"; constexpr char kNumRunsFlag[] = "num_runs"; constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads"; constexpr char kDelegateFlag[] = "delegate"; -constexpr char kNnapiDelegate[] = "nnapi"; -constexpr char kGpuDelegate[] = "gpu"; -constexpr char kHexagonDelegate[] = "hexagon"; bool EvaluateModel(const std::string& model_file_path, const std::string& delegate, int num_runs, @@ -49,14 +47,11 @@ bool EvaluateModel(const std::string& model_file_path, // This ensures that latency measurement isn't hampered by the time spent in // generating random data. inference_params->set_invocations_per_run(3); - if (delegate == kNnapiDelegate) { - inference_params->set_delegate(TfliteInferenceParams::NNAPI); - } - if (delegate == kGpuDelegate) { - inference_params->set_delegate(TfliteInferenceParams::GPU); - } - if (delegate == kHexagonDelegate) { - inference_params->set_delegate(TfliteInferenceParams::HEXAGON); + inference_params->set_delegate(ParseStringToDelegateType(delegate)); + if (!delegate.empty() && + inference_params->delegate() == TfliteInferenceParams::NONE) { + LOG(WARNING) << "Unsupported TFLite delegate: " << delegate; + return false; } InferenceProfilerStage eval(eval_config); if (eval.Init() != kTfLiteOk) return false; @@ -102,9 +97,10 @@ int Main(int argc, char* argv[]) { if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path, num_interpreter_threads)) { LOG(ERROR) << "Could not evaluate model!"; + return EXIT_FAILURE; } - return 0; + return EXIT_SUCCESS; } } // namespace evaluation From 578a4c32b4ca276ba7ecac7838999f7a13e17852 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Mar 2020 23:46:10 -0700 Subject: [PATCH 294/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301971528 Change-Id: Ib361d11ec414ed3b09f4ee9db94883b6a5548529 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 4f552e456e5..b8b73bc472d 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. 
-// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. 
Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 7f1209593e1b3138a991f18e044e964fbd8a64c9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 00:06:08 -0700 Subject: [PATCH 295/492] Update clang to newer version. PiperOrigin-RevId: 301973692 Change-Id: I29bd6ad109b68663688ed42102c93f927690fdda --- .../Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 index b529147e57e..c7e84936bf5 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 @@ -47,7 +47,7 @@ RUN apt-get update && apt-get install -y \ rm -rf /var/lib/apt/lists/* # Copy and run the install scripts. -ENV CLANG_VERSION="r373795" +ENV CLANG_VERSION="ra21beccea2020f950845cbb68db663d0737e174c" COPY install/*.sh /install/ ARG DEBIAN_FRONTEND=noninteractive RUN /install/install_bootstrap_deb_packages.sh From 9c7354462df25d2ddcdaafb89e04960dfb056621 Mon Sep 17 00:00:00 2001 From: Yi Situ Date: Fri, 20 Mar 2020 00:06:40 -0700 Subject: [PATCH 296/492] Added a better heuristic of detecting TensorCore eligible Einsum operations. Added equation as XStat type. PiperOrigin-RevId: 301973756 Change-Id: I5f3b73ac15697045dd41bb7264a395dac16d079e --- .../convert/xplane_to_kernel_stats_db.cc | 7 ++++++- .../core/profiler/utils/kernel_stats_utils.cc | 16 ++++++++++++++-- .../core/profiler/utils/kernel_stats_utils.h | 3 +++ tensorflow/core/profiler/utils/xplane_schema.cc | 1 + tensorflow/core/profiler/utils/xplane_schema.h | 1 + 5 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc index 4b87033a508..e0fd2bb6339 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc @@ -42,6 +42,7 @@ KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb( absl::string_view tf_op_fullname; KernelReport kernel; + absl::string_view equation; event.ForEachStat([&](const tensorflow::profiler::XStatVisitor& stat) { if (stat.Type() == StatType::kLevel0) { tf_op_fullname = stat.StrValue(); @@ -53,14 +54,18 @@ KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb( kernel.set_min_duration_ns(event.DurationNs()); kernel.set_max_duration_ns(event.DurationNs()); ParseKernelLaunchParams(stat.StrValue(), &kernel); + } else if (stat.Type() == StatType::kEquation) { + equation = stat.StrValue(); } }); if (!tf_op_fullname.empty()) { tensorflow::profiler::TfOp tf_op = ParseTfOpFullname(tf_op_fullname); + if (kernel.total_duration_ns()) { kernel.set_op_name(tf_op.name.data(), tf_op.name.size()); - bool tensor_core_eligible = IsOpTensorCoreEligible(kernel.op_name()); + bool tensor_core_eligible = IsEinsumTensorCoreEligible(equation) || + IsOpTensorCoreEligible(kernel.op_name()); #if defined(LOG_IF) LOG_IF(INFO, !tensor_core_eligible && kernel.is_kernel_using_tensor_core()) diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc 
b/tensorflow/core/profiler/utils/kernel_stats_utils.cc index 3921c7d6aab..4721047a856 100644 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc +++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc @@ -111,8 +111,6 @@ bool IsOpTensorCoreEligible(absl::string_view tf_op_name) { || absl::EndsWith(tf_op_name, "DepthwiseConv2dNative") || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropFilter") || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropInput") - // Using Contains because of numeric suffix and possible Xla prefix. - || absl::StrContains(tf_op_name, "Einsum") // Using Contains to match V2/V3 suffixes. || absl::StrContains(tf_op_name, "BatchMatMul") // MatMul requires exact matching. @@ -128,6 +126,20 @@ bool IsOpTensorCoreEligible(absl::string_view tf_op_name) { // clang-format on } +bool IsEinsumTensorCoreEligible(absl::string_view equation) { + if (equation.empty()) { + return false; + } + const std::vector input_output = + absl::StrSplit(equation, "->"); + if (input_output.size() != 2) { + return false; + } + const std::vector lhs_rhs = + absl::StrSplit(input_output[0], ','); + return lhs_rhs.size() == 2; +} + bool KernelReportLessThanComparator::operator()(const KernelReport& lhs, const KernelReport& rhs) { // Disable formatting to keep vertical alignment for better readability, diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.h b/tensorflow/core/profiler/utils/kernel_stats_utils.h index 7b121b49e85..5b66596d683 100644 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.h +++ b/tensorflow/core/profiler/utils/kernel_stats_utils.h @@ -34,6 +34,9 @@ bool IsKernelUsingTensorCore(absl::string_view kernel_name); // Returns true if operation is eligible to use TensorCores. bool IsOpTensorCoreEligible(absl::string_view tf_op_name); +// Returns true if Einsum equation is eligible to use TensorCores. +bool IsEinsumTensorCoreEligible(absl::string_view equation); + // Less than comparator for Kernel Reports. struct KernelReportLessThanComparator { bool operator()(const KernelReport& lhs, const KernelReport& rhs); diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index 8eb7bd7c76d..9de8028f8eb 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -139,6 +139,7 @@ const StatTypeMap& GetStatTypeMap() { {"tf_op", kTfOp}, {"hlo_op", kHloOp}, {"hlo_module", kHloModule}, + {"equation", kEquation}, // Performance counter related. {"Raw Value", kRawValue}, {"Scaled Value", kScaledValue}, diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index ad8efd60033..03e7b8ee720 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -130,6 +130,7 @@ enum StatType { kTfOp, kHloOp, kHloModule, + kEquation, // Performance counter related. kRawValue, kScaledValue, From f819114a2d9d393a60e954d3a3e42d8700ff3b19 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 00:11:55 -0700 Subject: [PATCH 297/492] Remove direct dependency on the static libcudart; it is now linked dynamically via the stub everywhere. 
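For illustration only (not part of any patch above): the equation heuristic added to kernel_stats_utils in PATCH 296 treats an Einsum as TensorCore eligible only when its equation has a single "->" and exactly two comma-separated operands on the left-hand side. The snippet below is a minimal, std-only approximation of that check, included as a hedged sketch rather than the actual TensorFlow implementation (which uses absl::StrSplit as shown in the diff).

#include <cassert>
#include <string>

// Std-only approximation of IsEinsumTensorCoreEligible, for illustration.
bool EinsumLooksTensorCoreEligible(const std::string& equation) {
  // Eligible iff there is exactly one "->" and exactly two comma-separated
  // operands on the left-hand side, e.g. "abc,cd->abd".
  const auto arrow = equation.find("->");
  if (equation.empty() || arrow == std::string::npos) return false;
  const std::string lhs = equation.substr(0, arrow);
  const std::string rhs = equation.substr(arrow + 2);
  if (rhs.find("->") != std::string::npos) return false;  // more than one "->"
  const auto comma = lhs.find(',');
  if (comma == std::string::npos) return false;            // unary einsum
  return lhs.find(',', comma + 1) == std::string::npos;    // exactly two inputs
}

int main() {
  assert(EinsumLooksTensorCoreEligible("abc,cd->abd"));   // binary contraction
  assert(!EinsumLooksTensorCoreEligible("abc->acb"));     // unary transpose
  assert(!EinsumLooksTensorCoreEligible(""));             // empty equation
  return 0;
}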
PiperOrigin-RevId: 301974345 Change-Id: I041786954d8aaa22bf76fdeeab48b08fbe7c2ec0 --- third_party/nccl/archive.BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/nccl/archive.BUILD b/third_party/nccl/archive.BUILD index 9fd5b6f44ea..4936844b6b2 100644 --- a/third_party/nccl/archive.BUILD +++ b/third_party/nccl/archive.BUILD @@ -94,6 +94,5 @@ cc_library( ":device", ":include_hdrs", ":src_hdrs", - "@local_config_cuda//cuda:cudart_static", ], ) From 91e9cf21561a85c41d0915769b0032c58c52b9ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 00:16:40 -0700 Subject: [PATCH 298/492] Update RBE docker image to include a newer version of clang. PiperOrigin-RevId: 301974757 Change-Id: Ic1dfb379fc66104f9efe9faf3c4022d40504ef1c --- third_party/toolchains/preconfig/generate/containers.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/toolchains/preconfig/generate/containers.bzl b/third_party/toolchains/preconfig/generate/containers.bzl index 397b36d9f3b..be19af8ceeb 100644 --- a/third_party/toolchains/preconfig/generate/containers.bzl +++ b/third_party/toolchains/preconfig/generate/containers.bzl @@ -8,7 +8,7 @@ container_digests = { "cuda10.0-cudnn7-centos6": "sha256:a1909ba09c703340ee0074ce63dd94fe8fea48035a25264677907a609e2375e0", "cuda10.1-cudnn7-centos6": "sha256:454b899657e87893ee5e68dc0f87df59b6a0a7418ae09cafcc3dd65ac71feca9", "cuda10.0-cudnn7-ubuntu16.04-manylinux2010": "sha256:5812d9d0ef0a3276fc5faaf4cd01f3d6e03d635893a6e2d2e04f6f01d626c432", - "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:23db3de806535c9d26170567ba55cf653e503057345a0e9c129124c08ea118a3", + "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:177e1e55894b3c6edcfd7aa5d6db53716924b02553922bbf907e16b3d319e18c", "rocm-ubuntu16.04": "sha256:e645447dd6127325f3e97b8bf23424f637a8579d963b34fcc6772cf7cfaa0ebe", "windows-1803": "sha256:f109576c7c0c8a1783ff22b666e8923b52dbbe7933f69a1c7a7275202c304a12", } From 765ddddb2278551441ecaf45528898cd839df5c9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 00:20:56 -0700 Subject: [PATCH 299/492] Add clang RBE configuration. 
PiperOrigin-RevId: 301975133 Change-Id: I6bb5929d5def7669a81a360a34edda10c068aaae --- .bazelrc | 29 ++++++++++++++----- .../toolchains/remote_config/configs.bzl | 12 ++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/.bazelrc b/.bazelrc index 8be3dadaf4e..8fd166c10a5 100644 --- a/.bazelrc +++ b/.bazelrc @@ -356,7 +356,15 @@ build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/to build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" -build:rbe_linux_cuda_nvcc --config=rbe_linux +build:rbe_linux_cuda_base --config=rbe_linux +build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1 +build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10 +build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7 +build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 +build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1 +test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" + +build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" @@ -365,13 +373,20 @@ build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda1 build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" -build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1 -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10 -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7 -build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_CUDA=1 build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true -test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base + +build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base +build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_clang 
--repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" +build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" +build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" +build:rbe_linux_cuda_clang --define=using_cuda_clang=true +test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc diff --git a/third_party/toolchains/remote_config/configs.bzl b/third_party/toolchains/remote_config/configs.bzl index 973efb40af1..4ebf5c1c068 100644 --- a/third_party/toolchains/remote_config/configs.bzl +++ b/third_party/toolchains/remote_config/configs.bzl @@ -34,6 +34,18 @@ def initialize_rbe_configs(): tensorrt_version = "6.0", ) + tensorflow_rbe_config( + name = "ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0", + compiler = "/clang_ra21beccea2020f950845cbb68db663d0737e174c/bin/clang", + cuda_version = "10.1", + cudnn_version = "7", + os = "ubuntu16.04-manylinux2010", + python_version = "3", + tensorrt_install_path = "/usr", + tensorrt_version = "6.0", + sysroot = "/dt7", + ) + tensorflow_rbe_config( name = "ubuntu16.04-py3_opt-gcc5-rocm", compiler = "gcc", From 1751e3c8de7a3e2a69dd22a6062aec7d2c7709ef Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Thu, 19 Mar 2020 16:56:51 +0100 Subject: [PATCH 300/492] Add -lrt to linkflags Fixes compilation on e.g. older CentOS, see for details #15129 --- tensorflow/tensorflow.bzl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 390acacefe8..d10650479d6 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -537,7 +537,7 @@ def tf_cc_shared_object( srcs = [], deps = [], data = [], - linkopts = [], + linkopts = if_not_windows(["-lrt"]), framework_so = tf_binary_additional_srcs(), soversion = None, kernels = [], @@ -641,7 +641,7 @@ def tf_cc_binary( srcs = [], deps = [], data = [], - linkopts = [], + linkopts = if_not_windows(["-lrt"]), copts = tf_copts(), kernels = [], per_os_targets = False, # Generate targets with SHARED_LIBRARY_NAME_PATTERNS @@ -737,7 +737,7 @@ def tf_gen_op_wrapper_cc( tf_cc_binary( name = tool, copts = tf_copts(), - linkopts = if_not_windows(["-lm", "-Wl,-ldl"]), + linkopts = if_not_windows(["-lm", "-Wl,-ldl", "-lrt"]), linkstatic = 1, # Faster to link this one-time-use binary dynamically deps = [op_gen] + deps, ) @@ -924,7 +924,7 @@ def tf_gen_op_wrapper_py( tf_cc_binary( name = tool_name, copts = tf_copts(), - linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + cc_linkopts, + linkopts = if_not_windows(["-lm", "-Wl,-ldl", "-lrt"]) + cc_linkopts, linkstatic = 1, # Faster to link this one-time-use binary dynamically visibility = [clean_dep("//tensorflow:internal")], deps = ([ @@ -1221,7 +1221,7 @@ def tf_cc_tests( tags = [], size = "medium", args = None, - linkopts = [], + linkopts = if_not_windows(["-lrt"]), kernels = [], create_named_test_suite = False, visibility = None): From cc2f3e2ccf8d91ce5fd38bb73cd5ef9b30349857 Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Fri, 20 Mar 2020 09:29:22 +0100 Subject: [PATCH 301/492] Guard -lrt adding in check for linux variant --- tensorflow/tensorflow.bzl | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 
d10650479d6..714f4039a29 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -262,6 +262,17 @@ def if_nccl(if_true, if_false = []): "//conditions:default": if_true, }) +# Linux systems may required -lrt linker flag for e.g. clock_gettime +# see https://github.com/tensorflow/tensorflow/issues/15129 +def lrt_if_needed(): + lrt = ["-lrt"] + return select({ + clean_dep("//tensorflow:linux_aarch64"): lrt, + clean_dep("//tensorflow:linux_x86_64"): lrt, + clean_dep("//tensorflow:linux_ppc64le"): lrt, + "//conditions:default": [], + }) + def get_win_copts(is_external = False): WINDOWS_COPTS = [ "/DPLATFORM_WINDOWS", @@ -537,7 +548,7 @@ def tf_cc_shared_object( srcs = [], deps = [], data = [], - linkopts = if_not_windows(["-lrt"]), + linkopts = lrt_if_needed(), framework_so = tf_binary_additional_srcs(), soversion = None, kernels = [], @@ -641,7 +652,7 @@ def tf_cc_binary( srcs = [], deps = [], data = [], - linkopts = if_not_windows(["-lrt"]), + linkopts = lrt_if_needed(), copts = tf_copts(), kernels = [], per_os_targets = False, # Generate targets with SHARED_LIBRARY_NAME_PATTERNS @@ -737,7 +748,7 @@ def tf_gen_op_wrapper_cc( tf_cc_binary( name = tool, copts = tf_copts(), - linkopts = if_not_windows(["-lm", "-Wl,-ldl", "-lrt"]), + linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + lrt_if_needed(), linkstatic = 1, # Faster to link this one-time-use binary dynamically deps = [op_gen] + deps, ) @@ -910,7 +921,7 @@ def tf_gen_op_wrapper_py( hidden_file = None, generated_target_name = None, op_whitelist = [], - cc_linkopts = [], + cc_linkopts = lrt_if_needed(), api_def_srcs = []): _ = require_shape_functions # Unused. @@ -924,7 +935,7 @@ def tf_gen_op_wrapper_py( tf_cc_binary( name = tool_name, copts = tf_copts(), - linkopts = if_not_windows(["-lm", "-Wl,-ldl", "-lrt"]) + cc_linkopts, + linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + cc_linkopts, linkstatic = 1, # Faster to link this one-time-use binary dynamically visibility = [clean_dep("//tensorflow:internal")], deps = ([ @@ -1221,7 +1232,7 @@ def tf_cc_tests( tags = [], size = "medium", args = None, - linkopts = if_not_windows(["-lrt"]), + linkopts = lrt_if_needed(), kernels = [], create_named_test_suite = False, visibility = None): From 1c80c08adc2ff218c46a180160f83dd6bcffb9a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 01:46:33 -0700 Subject: [PATCH 302/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301984377 Change-Id: I424103432723829772ce6799ef4d7658a849b82b --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index b8b73bc472d..4f552e456e5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 130e868c9c5a2f01fcd92c19b5dcbfa0a6a59feb Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 20 Mar 2020 02:02:46 -0700 Subject: [PATCH 303/492] compat: Update forward compatibility horizon to 2020-03-20 PiperOrigin-RevId: 301986086 Change-Id: Ie1731dc86b206961afbaa1291052f10b0d6121b1 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 6121c71a404..3195e9ce5b9 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 19) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 20) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 61349d9e34f84a727ca498247f5771f9049387c1 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Fri, 20 Mar 2020 02:40:31 -0700 Subject: [PATCH 304/492] DynamicBroadcastInDimOp: make broadcast_dimensions mandatory All code that operates on this attribute would end up having to special-case when it is missing, so remove it. PiperOrigin-RevId: 301990055 Change-Id: I0e5d1639e140248927eac900f3561b3be2793282 --- tensorflow/compiler/mlir/xla/ir/hlo_ops.cc | 14 ++------------ tensorflow/compiler/mlir/xla/ir/hlo_ops.td | 6 +----- .../mlir/xla/transforms/hlo_legalize_to_lhlo.cc | 8 ++------ 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 87f6eaecc52..17d0b958084 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -545,19 +545,9 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) { auto operandRank = operandType.getRank(); auto resultRank = resultType.getRank(); - if (!op.broadcast_dimensions()) { - if (operandRank == 0) { - return success(); - } - return op.emitOpError( - llvm::formatv("broadcast_dimensions is absent, but required because " - "operand has non-zero rank ({0})", - operandRank)); - } - // Verify broadcast_dimensions. - auto bcastDimensions = *op.broadcast_dimensions(); - auto bcastDimensionsType = op.broadcast_dimensions()->getType(); + auto bcastDimensions = op.broadcast_dimensions(); + auto bcastDimensionsType = op.broadcast_dimensions().getType(); auto bcastDimensionsRank = bcastDimensionsType.getRank(); // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1. if (bcastDimensionsRank != 1) { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 8f8f6ac62e3..bc05a1c100c 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -793,15 +793,11 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", shaped original, but is being phased as a separate op in order to support compatibility with lowerings and translations that precede dynamic shapes. - - Note that the `broadcast_dimensions` attribute is optional and if omitted, - it is assumed to be an ordered, right-aligned mapping from input to - output dimensions. 
}]; let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_dimensions, - OptionalAttr:$broadcast_dimensions + BroadcastDimAttr:$broadcast_dimensions ); let results = (outs HLO_Tensor); diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index b2b17a8dd75..4dad8c5a996 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -164,14 +164,10 @@ struct HloToLhloDynamicBroadcastInDimOpConverter xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); - auto broadcast_dimensions = op.broadcast_dimensions(); - if (!broadcast_dimensions.hasValue()) { - return failure(); - } Value resultBuffer = InsertDynamicAllocAndDealloc( loc, op.getResult(), op.output_dimensions(), &rewriter); - rewriter.create( - loc, operands[0], resultBuffer, broadcast_dimensions.getValue()); + rewriter.create(loc, operands[0], resultBuffer, + op.broadcast_dimensions()); rewriter.replaceOp(op, {resultBuffer}); From 7d5fb910040112b1169a92443221a3f349766ee5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 02:44:53 -0700 Subject: [PATCH 305/492] Update metadata schema documentation to better define the spec for ScoreCalibrationOptions. PiperOrigin-RevId: 301990545 Change-Id: I16c52dfaac60a749fad69a8dca27bffcfb3ed436 --- .../support/metadata/metadata_schema.fbs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs b/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs index a70dd044849..f3f3bbcc6ff 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs +++ b/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs @@ -68,10 +68,13 @@ enum AssociatedFileType : byte { // Contains sigmoid-based score calibration parameters, formatted as CSV. // Lines contain for each index of an output tensor the scale, slope, offset - // and min_score parameters to be used for sigmoid fitting (in this order and - // in `strtof`-compatible [1] format). + // and (optional) min_score parameters to be used for sigmoid fitting (in this + // order and in `strtof`-compatible [1] format). // A line may be left empty to default calibrated scores for this index to - // default_score. See documentation for ScoreCalibrationOptions for details. + // default_score. + // In summary, each line should thus contain 0, 3 or 4 comma-separated values. + // + // See documentation for ScoreCalibrationOptions for details. // // [1]: https://en.cppreference.com/w/c/string/byte/strtof TENSOR_AXIS_SCORE_CALIBRATION = 4, @@ -332,11 +335,12 @@ enum ScoreTransformationType : byte { // output, e.g. image classification or detection models. // // For each index in the output tensor, this applies: -// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score`, -// * `f(x) = default_score` otherwise or if no scale, slope, offset and -// min_score have been specified. +// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score` or if no +// `min_score` has been specified, +// * `f(x) = default_score` otherwise or if no scale, slope and offset have been +// specified. 
// Where: -// * scale, slope, offset and min_score are index-specific parameters +// * scale, slope, offset and (optional) min_score are index-specific parameters // * g(x) is an index-independent transform among those defined in // ScoreTransformationType // * default_score is an index-independent parameter. From 4f468e12edd725e997f08328e6bd74e738ee4550 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 02:53:49 -0700 Subject: [PATCH 306/492] pfor: Add converter for LeakyRelu and LeakyReluGrad. PiperOrigin-RevId: 301991557 Change-Id: Ifad6d0dff4d56f128247e56652c68d0f45fd5350 --- .../python/ops/parallel_for/math_test.py | 23 ++++++++++--------- tensorflow/python/ops/parallel_for/pfor.py | 17 ++++++++++---- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py index 753b61cc572..773195283d6 100644 --- a/tensorflow/python/ops/parallel_for/math_test.py +++ b/tensorflow/python/ops/parallel_for/math_test.py @@ -53,17 +53,17 @@ class MathTest(PForTestCase, parameterized.TestCase): def loop_fn(i): with g: - x1 = array_ops.gather(x, i) - y1 = op(x1) - outputs = [op(x), y1] - if y1.dtype == dtypes.float32: - loss = math_ops.reduce_sum(y1 * y1) - else: - loss = None - if loss is not None: - grad = g.gradient(loss, x1) - if grad is not None: - outputs.append(grad) + y = op(x) + x_i = array_ops.gather(x, i) + y_i = op(x_i) + outputs = [y_i] + # Build cross product of loop variant/invariant outputs and gradients. + for out in (y, y_i): + if out.dtype == dtypes.float32: + for output_gradients in (None, out * math_ops.cast(i, out.dtype)): + grad = g.gradient(out, x_i, output_gradients=output_gradients) + if grad is not None: + outputs.append(grad) return outputs # pylint: enable=cell-var-from-loop @@ -128,6 +128,7 @@ class MathTest(PForTestCase, parameterized.TestCase): nn.elu, nn.relu, nn.relu6, + lambda t: nn.leaky_relu(t, alpha=0.1), nn.selu, nn.softplus, nn.softsign, diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 87642778257..35d5e64334e 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -2733,10 +2733,18 @@ def _convert_cwise(pfor_input, op_type, op_func): # and hence don't need extra arguments passed to the cwise_op call below. for attr in pfor_input.op.node_def.attr.keys(): assert attr in [u"T", u"Tout", u"_xla_compile_id"], (op_type, attr) - pfor_input.expanddim_inputs_for_broadcast() + if pfor_input.num_inputs > 1: + pfor_input.expanddim_inputs_for_broadcast() return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) +@RegisterPFor("LeakyRelu") +def _convert_leaky_relu(pfor_input): + t = pfor_input.stacked_input(0) + alpha = pfor_input.get_attr("alpha") + return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True) + + @RegisterPFor("Equal") def _convert_equal(pfor_input): pfor_input.expanddim_inputs_for_broadcast() @@ -2831,16 +2839,17 @@ def _convert_biasaddgrad(pfor_input): # Some required ops are not exposed under the tf namespace. Hence relying on # _create_op to create them. 
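For illustration only (not part of any patch above, and unrelated to the pfor diff that follows): a minimal standalone sketch of the per-index score calibration formula documented for ScoreCalibrationOptions in PATCH 305, assuming the index-independent transform g(x) is the identity and that all four parameters are present for the index. All parameter values below are hypothetical.

#include <cmath>
#include <cstdio>

// Sketch of: f(x) = scale / (1 + e^-(slope*g(x) + offset)) if x > min_score,
//            f(x) = default_score otherwise (g taken as identity here).
float CalibratedScore(float x, float scale, float slope, float offset,
                      float min_score, float default_score) {
  if (x <= min_score) return default_score;
  return scale / (1.0f + std::exp(-(slope * x + offset)));
}

int main() {
  // Hypothetical parameters for one output index.
  std::printf("%f\n", CalibratedScore(0.8f, 1.0f, 5.0f, -2.0f, 0.1f, 0.0f));
  return 0;
}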
@RegisterPForWithArgs("EluGrad") +@RegisterPForWithArgs("LeakyReluGrad") +@RegisterPForWithArgs("ReciprocalGrad") @RegisterPForWithArgs("Relu6Grad") @RegisterPForWithArgs("ReluGrad") +@RegisterPForWithArgs("RsqrtGrad") @RegisterPForWithArgs("SeluGrad") @RegisterPForWithArgs("SigmoidGrad") @RegisterPForWithArgs("SoftplusGrad") @RegisterPForWithArgs("SoftsignGrad") -@RegisterPForWithArgs("TanhGrad") @RegisterPForWithArgs("SqrtGrad") -@RegisterPForWithArgs("RsqrtGrad") -@RegisterPForWithArgs("ReciprocalGrad") +@RegisterPForWithArgs("TanhGrad") def _convert_grads(pfor_input, op_type, *args, **kw_args): del args del kw_args From c0306ef626b02ab5ab10aac2cec6d08f56136a5c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 03:46:22 -0700 Subject: [PATCH 307/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 301997806 Change-Id: I8244ef1e364ef561acb4e93a000ce6f84e6e9f8a --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 4f552e456e5..b8b73bc472d 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. 
-// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 0e0fc5a791015d7a866cb4e6c7c87c803ac1f81c Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 20 Mar 2020 04:28:09 -0700 Subject: [PATCH 308/492] [XLA][MLIR] Emit ReduceWindow HLO instruction as xla_lhlo.ReduceWindowOp. 
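For context, reduce-window slides a window over the operand and folds each
window into a single output element using the nested reduction computation.
Below is a rough NumPy sketch of the 2-D case exercised by the new
reduce_window.hlo test (max reduction, window 3, stride 2, padding 0/1 on the
spatial dimensions); the helper name and the simplification to two dimensions
are illustrative only and not part of this change:

    import numpy as np

    def max_reduce_window_2d(x, init=-np.inf, window=3, stride=2, pad=(0, 1)):
        # Pad with the init value so padded cells cannot win the max reduction.
        padded = np.pad(x, (pad, pad), constant_values=init)
        out_h = (padded.shape[0] - window) // stride + 1
        out_w = (padded.shape[1] - window) // stride + 1
        out = np.full((out_h, out_w), init)
        for i in range(out_h):
            for j in range(out_w):
                win = padded[i * stride:i * stride + window,
                             j * stride:j * stride + window]
                # Each output element is the reduction (here: max) of one window.
                out[i, j] = max(init, win.max())
        return out

    # A 112x112 spatial slice maps to 56x56, matching the
    # f32[128,64,112,112] -> f32[128,64,56,56] shapes in the test.
    x = np.arange(112 * 112, dtype=np.float32).reshape(112, 112)
    print(max_reduce_window_2d(x).shape)  # (56, 56)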
PiperOrigin-RevId: 302002238 Change-Id: I93d202b3729b63d85ae56e1515b1c1971e794aa0 --- tensorflow/compiler/mlir/xla/hlo_utils.cc | 10 +- tensorflow/compiler/mlir/xla/hlo_utils.h | 5 +- .../xla/service/mlir_gpu/kernel_lowering.cc | 8 +- .../service/mlir_gpu/lhlo_dialect_emitter.cc | 94 ++++++++++++++----- .../service/mlir_gpu/lhlo_dialect_emitter.h | 1 + .../compiler/xla/service/mlir_gpu/tests/BUILD | 1 + .../mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc | 6 ++ .../service/mlir_gpu/tests/reduce_window.hlo | 34 +++++++ 8 files changed, 124 insertions(+), 35 deletions(-) create mode 100644 tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index e0c5c4a00f0..3caa4f58725 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -108,12 +108,12 @@ StatusOr CreateDenseElementsAttrFromLiteral( } mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( - const llvm::ArrayRef vector, mlir::Builder builder) { + const llvm::ArrayRef vector, mlir::Builder builder, + llvm::ArrayRef shape) { return mlir::DenseIntElementsAttr::get( - mlir::RankedTensorType::get(vector.size(), - builder.getIntegerType(64)), - vector) - .cast(); + mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape, + builder.getIntegerType(64)), + vector); } StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.h b/tensorflow/compiler/mlir/xla/hlo_utils.h index 003eda0b992..764c40ed93b 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/hlo_utils.h @@ -30,8 +30,11 @@ namespace xla { StatusOr CreateDenseElementsAttrFromLiteral( const LiteralBase& literal, mlir::Builder builder); +// Creates an DenseIntElementsAttr using the elements of the vector and the +// optional shape. 
mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( - const llvm::ArrayRef vector, mlir::Builder builder); + const llvm::ArrayRef vector, mlir::Builder builder, + llvm::ArrayRef shape = {}); StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, mlir::Builder builder); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 151d82fd2a1..c3a2607ad73 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -70,12 +70,10 @@ struct FusionToLhloConverter target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>(); ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns); - getFunction().walk([&](FusionOp op) { - if (failed(applyPartialConversion(op, target, patterns, nullptr))) { - signalPassFailure(); + getFunction().walk([&](mlir::Operation* op) { + if (op->getNumRegions() == 0) { + return; } - }); - getFunction().walk([&](mlir::xla_lhlo::ReduceOp op) { if (failed(applyPartialConversion(op, target, patterns, nullptr))) { signalPassFailure(); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 1f681bfab00..55cc0af4d55 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -42,6 +42,7 @@ namespace { using ::mlir::ArrayRef; using ::mlir::Attribute; using ::mlir::Builder; +using ::mlir::DenseIntElementsAttr; using ::mlir::FuncOp; using ::mlir::Identifier; using ::mlir::Location; @@ -143,6 +144,37 @@ StatusOr> GetInstructionArgTypes( return arg_types; } +// Converts HloComputation into a block with HLO dialect ops. The block gets +// memref arguments corresponding to HloComputation arguments and results. +Status SpliceHloComputation(OpBuilder builder, mlir::Location loc, + const HloComputation& hlo_computation, + xla::mlir_gpu::EmissionContext* emission_context) { + auto block = builder.getInsertionBlock(); + llvm::SmallVector arg_values; + // First map parameters to memrefs on the operation. + for (auto param : hlo_computation.parameter_instructions()) { + TF_ASSIGN_OR_RETURN( + auto arg_type, ConvertShapeToType(param->shape(), builder)); + auto block_arg = block->addArgument(arg_type); + arg_values.push_back(builder.create<::mlir::TensorLoadOp>(loc, block_arg)); + } + HloDialectEmitter hlo_emitter(emission_context, builder, arg_values); + + TF_ASSIGN_OR_RETURN(auto result, + hlo_emitter.EmitComputation(hlo_computation)); + + // Now add a block arg and store for the result. 
+ builder.setInsertionPoint(block->getTerminator()); + TF_ASSIGN_OR_RETURN( + auto result_type, + ConvertShapeToType( + hlo_computation.root_instruction()->shape(), builder)); + auto block_arg = block->addArgument(result_type); + builder.create<::mlir::TensorStoreOp>(loc, result, block_arg); + + return Status::OK(); +} + } // namespace mlir::Location LhloDialectEmitter::getLocation( @@ -268,33 +300,47 @@ Status LhloDialectEmitter::HandleReduce(HloInstruction* reduce) { auto reduce_op = builder.create(loc, inputs, init_values, results, dimensions_attr); reduce_op.ensureTerminator(reduce_op.body(), builder, getLocation(reduce)); + return SpliceHloComputation(OpBuilder{&reduce_op.body()}, loc, + *reduce->to_apply(), emission_context_); +} - OpBuilder body_builder(reduce_op.body()); - auto block = body_builder.getInsertionBlock(); - auto to_apply = reduce->to_apply(); - llvm::SmallVector reduce_arg_values; - // First map parameters to memrefs on the operation. - for (auto param : to_apply->parameter_instructions()) { - TF_ASSIGN_OR_RETURN(auto arg_type, ConvertShapeToType( - param->shape(), builder_)); - auto block_arg = block->addArgument(arg_type); - reduce_arg_values.push_back( - body_builder.create<::mlir::TensorLoadOp>(loc, block_arg)); +Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* reduce_window) { + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*reduce_window)); + llvm::SmallVector arg_values{function.args_begin(), + function.args_end()}; + OpBuilder builder(function.getBody()); + auto loc = getLocation(reduce_window); + + // Collect attribute values. + llvm::SmallVector window_dimensions, window_strides, base_dilations, + window_dilations; + llvm::SmallVector padding; + int64 rank = reduce_window->window().dimensions_size(); + window_dimensions.reserve(rank); + window_strides.reserve(rank); + base_dilations.reserve(rank); + window_dilations.reserve(rank); + padding.reserve(2 * rank); + for (const auto& window : reduce_window->window().dimensions()) { + window_dimensions.push_back(window.size()); + window_strides.push_back(window.stride()); + base_dilations.push_back(window.base_dilation()); + window_dilations.push_back(window.window_dilation()); + padding.push_back(window.padding_low()); + padding.push_back(window.padding_high()); } - HloDialectEmitter hlo_emitter(emission_context_, body_builder, - reduce_arg_values); - TF_ASSIGN_OR_RETURN(auto result, hlo_emitter.EmitComputation(*to_apply)); - - // Now add a block arg and store for the result. 
- body_builder.setInsertionPoint(block->getTerminator()); - TF_ASSIGN_OR_RETURN(auto result_type, - ConvertShapeToType( - to_apply->root_instruction()->shape(), builder)); - auto block_arg = block->addArgument(result_type); - body_builder.create<::mlir::TensorStoreOp>(loc, result, block_arg); - - return Status::OK(); + auto reduce_window_op = builder.create( + loc, /*operand=*/arg_values[0], /*init_value=*/arg_values[1], + /*out=*/arg_values[2], + CreateDenseIntElementsAttrFromVector(window_dimensions, builder), + CreateDenseIntElementsAttrFromVector(window_strides, builder), + CreateDenseIntElementsAttrFromVector(base_dilations, builder), + CreateDenseIntElementsAttrFromVector(window_dilations, builder), + CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2})); + reduce_window_op.ensureTerminator(reduce_window_op.body(), builder, loc); + return SpliceHloComputation(OpBuilder{&reduce_window_op.body()}, loc, + *reduce_window->to_apply(), emission_context_); } Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h index 48d275ef5e0..dc7300490f8 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -61,6 +61,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, Status HandleIota(HloInstruction* iota) override; Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; + Status HandleReduceWindow(HloInstruction* reduce_window) override; Status HandleTuple(HloInstruction* tuple) override; Status FinishVisit(HloInstruction* root) override; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD index e2523d82b91..dbaea44ea3a 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD @@ -47,6 +47,7 @@ tf_cc_test( "iota_add_multiply.hlo", "log.hlo", "neg.hlo", + "reduce_window.hlo", "rem.hlo", "rsqrt.hlo", "select.hlo", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc index 206d46debdf..b45acf98664 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc @@ -175,6 +175,12 @@ TEST_F(LhloGenTest, Neg) { "neg.hlo")); } +TEST_F(LhloGenTest, ReduceWindow) { + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "reduce_window.hlo")); +} + TEST_F(LhloGenTest, Rem) { CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo new file mode 100644 index 00000000000..1d4786e8151 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo @@ -0,0 +1,34 @@ +HloModule ReduceWindow + +%max (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %max = f32[] maximum(f32[] %x, f32[] %y) +} + +ENTRY %ReduceWindow (x: f32[128,64,112,112], y: f32[]) -> f32[128,64,56,56] { + %x = f32[128,64,112,112] parameter(0) + %y = f32[] 
parameter(1) + ROOT %reduce-window = f32[128,64,56,56] reduce-window( + f32[128,64,112,112] %x, + f32[] %y + ), + window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1}, to_apply=%max +} + +// CHECK: func @"reduce-window"( +// CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[CST:%.*]]: memref, [[RES:%.*]]: [[REST:.*]]) { +// CHECK: "xla_lhlo.reduce_window"([[LHS:%.*]], [[RHS:%.*]], [[OUT:%.*]]) ( { +// CHECK: ^bb0([[LHS:%.*]]: memref, [[RHS:%.*]]: memref, [[OUT:%.*]]: memref): +// CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]] +// CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]] +// CHECK: [[OUT_TENSOR:%.*]] = xla_hlo.maximum [[LHS_TENSOR]], [[RHS_TENSOR]] +// CHECK: tensor_store [[OUT_TENSOR]], [[OUT]] +// CHECK: "xla_lhlo.terminator"() : () -> () +// CHECK: }) { +// CHECK-SAME: base_dilations = dense<1> : tensor<4xi64> +// CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]> +// CHECK-SAME: window_dilations = dense<1> : tensor<4xi64> +// CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]> +// CHECK-SAME: window_strides = dense<[1, 1, 2, 2]> +// CHECK: } : ([[ARGT]], memref, [[REST]]) -> () From 5c91eab5d43c6d84193527f22d3ced693a8df26d Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 20 Mar 2020 05:28:43 -0700 Subject: [PATCH 309/492] [XLA][MLIR] Generalize FusionToLhloConverter into NestedHloRegionsConverter. This converter already supports not only FusionOp but ReduceOp and soon it should start supporting ReduceWindowOp and SelectAndScatterOp. It makes sense to extend it so that we can convert all LHLO ops that have HLO ops in their bodies. PiperOrigin-RevId: 302008870 Change-Id: I61fe4e84b1d8473bf7c934e4f98165af167e3e21 --- .../xla/service/mlir_gpu/kernel_lowering.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index c3a2607ad73..748306561d4 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -58,11 +58,12 @@ using ::mlir::xla_lhlo::FusionOp; // Following are some small transformations that are required to clean up code // after lowering from linalg to loops. -// A simple pass that applies lowering of HLO to LHLO only within Fusion -// operations. This is needed, as FusionOp is not closed from above and hence -// nested pass managers can not be applied. -struct FusionToLhloConverter - : public mlir::FunctionPass { +// A simple pass that applies lowering of HLO to LHLO only within LHLO ops that +// contain regions with HLO ops, e.g. FusionOp, ReduceOp, SelectAndScatterOp. +// This is needed, as these ops are not closed from above and hence nested pass +// managers can not be applied. +struct NestedHloRegionsConverter + : public mlir::FunctionPass { void runOnFunction() override { auto& ctx = getContext(); mlir::OwningRewritePatternList patterns; @@ -270,8 +271,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) { mlir::PassManager pm(module.getContext()); EnableIRPrinting(&pm); - // First, lower bodies of fusion operations from hlo to lhlo. - pm.addPass(absl::make_unique()); + // First, lower bodies of lhlo operations that contain hlo ops. + pm.addPass(absl::make_unique()); // Next, we can strip the outer fusion operation. pm.addPass(absl::make_unique()); // Remove unnecessary Lhlo copies. 
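The SelectAndScatterOp mentioned above nests two computations: a select that
picks one element per window of the operand, and a scatter that folds the
corresponding source element into the output at the selected position (the
output starts out filled with the init value). Below is a rough NumPy sketch of
the common GE-select / add-scatter case in 2-D, with window parameters taken
from the spatial dimensions of the select-and-scatter test that follows; the
helper name and the restriction to two dimensions are illustrative assumptions,
not part of the change:

    import numpy as np

    def select_and_scatter_2d(operand, source, init=0.0,
                              window=3, stride=2, pad=(0, 1)):
        # Pad with -inf so a padded cell can never be selected by the GE (max) select.
        padded = np.pad(operand, (pad, pad), constant_values=-np.inf)
        out = np.full(operand.shape, init, dtype=operand.dtype)
        for i in range(source.shape[0]):
            for j in range(source.shape[1]):
                win = padded[i * stride:i * stride + window,
                             j * stride:j * stride + window]
                # `select` (GE) keeps the first maximum inside the window ...
                r, c = np.unravel_index(np.argmax(win), win.shape)
                # ... and `scatter` (add) combines the source value into the output
                # there, after mapping back to un-padded operand coordinates.
                out[i * stride + r - pad[0], j * stride + c - pad[0]] += source[i, j]
        return out

    operand = np.arange(64, dtype=np.float32).reshape(8, 8)
    source = np.ones((4, 4), dtype=np.float32)  # one value per 3x3 window, stride 2
    print(select_and_scatter_2d(operand, source))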
From abf182a8826570bec2196feff6af32a769276663 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 20 Mar 2020 05:32:03 -0700 Subject: [PATCH 310/492] [XLA][MLIR] Emit SelectAndScatter HLO instruction as lhlo.SelectAndScatterOp. PiperOrigin-RevId: 302009183 Change-Id: If4ef8d3d23118c5815e33affcfb58ac7c7612352 --- .../service/mlir_gpu/lhlo_dialect_emitter.cc | 48 ++++++++++++++++- .../service/mlir_gpu/lhlo_dialect_emitter.h | 1 + .../compiler/xla/service/mlir_gpu/tests/BUILD | 1 + .../mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc | 6 +++ .../mlir_gpu/tests/select_and_scatter.hlo | 53 +++++++++++++++++++ 5 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 55cc0af4d55..3d4e1078ca2 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -150,6 +150,7 @@ Status SpliceHloComputation(OpBuilder builder, mlir::Location loc, const HloComputation& hlo_computation, xla::mlir_gpu::EmissionContext* emission_context) { auto block = builder.getInsertionBlock(); + builder.setInsertionPoint(block->getTerminator()); llvm::SmallVector arg_values; // First map parameters to memrefs on the operation. for (auto param : hlo_computation.parameter_instructions()) { @@ -242,7 +243,7 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { } Status LhloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) { - mlir::DenseIntElementsAttr broadcast_dim = + DenseIntElementsAttr broadcast_dim = CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_); TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*broadcast)); @@ -343,6 +344,51 @@ Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* reduce_window) { *reduce_window->to_apply(), emission_context_); } +Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* hlo) { + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*hlo)); + llvm::SmallVector arg_values{function.args_begin(), + function.args_end()}; + OpBuilder builder(function.getBody()); + auto loc = getLocation(hlo); + + // Collect attribute values. + llvm::SmallVector window_dimensions, window_strides, padding; + int64 rank = hlo->window().dimensions_size(); + window_dimensions.reserve(rank); + window_strides.reserve(rank); + padding.reserve(2 * rank); + for (const auto& window : hlo->window().dimensions()) { + window_dimensions.push_back(window.size()); + window_strides.push_back(window.stride()); + padding.push_back(window.padding_low()); + padding.push_back(window.padding_high()); + } + + auto select_scatter_op = builder.create( + loc, /*operand=*/arg_values[0], /*source=*/arg_values[1], + /*init_value=*/arg_values[2], + /*out=*/arg_values[3], + CreateDenseIntElementsAttrFromVector(window_dimensions, builder), + CreateDenseIntElementsAttrFromVector(window_strides, builder), + CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2})); + + // Convert `select` computation. + builder.createBlock(&select_scatter_op.select()); + OpBuilder select_builder{&select_scatter_op.select()}; + select_builder.create(loc); + TF_RETURN_IF_ERROR(SpliceHloComputation(select_builder, loc, *hlo->select(), + emission_context_)); + + // Convert `scatter` computation. 
+ builder.createBlock(&select_scatter_op.scatter()); + OpBuilder scatter_builder{&select_scatter_op.scatter()}; + scatter_builder.create(loc); + TF_RETURN_IF_ERROR(SpliceHloComputation(scatter_builder, loc, *hlo->scatter(), + emission_context_)); + + return Status::OK(); +} + Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) { return ThunkEmitter(this).HandleCustomCall(custom_call); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h index dc7300490f8..ee0dbd6f320 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -62,6 +62,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleReduceWindow(HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(HloInstruction* hlo) override; Status HandleTuple(HloInstruction* tuple) override; Status FinishVisit(HloInstruction* root) override; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD index dbaea44ea3a..921ea01c8d3 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD @@ -51,6 +51,7 @@ tf_cc_test( "rem.hlo", "rsqrt.hlo", "select.hlo", + "select_and_scatter.hlo", "sign.hlo", "sqrt.hlo", "tanh.hlo", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc index b45acf98664..b73e4efe2d3 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc @@ -193,6 +193,12 @@ TEST_F(LhloGenTest, Rsqrt) { "rsqrt.hlo")); } +TEST_F(LhloGenTest, SelectAndScatter) { + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "select_and_scatter.hlo")); +} + TEST_F(LhloGenTest, Sign) { CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo new file mode 100644 index 00000000000..21979a2815f --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo @@ -0,0 +1,53 @@ +HloModule SelectAndScatter + +%ge (x: f32[], y: f32[]) -> pred[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %compare = pred[] compare(f32[] %x, f32[] %y), direction=GE +} + +%add (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %SelectAndScatter (x: f32[128,64,112,112], + y: f32[128,64,56,56], + z: f32[]) -> f32[128,64,112,112] { + %x = f32[128,64,112,112] parameter(0) + %y = f32[128,64,56,56] parameter(1) + %z = f32[] parameter(2) + ROOT %result = f32[128,64,112,112] select-and-scatter( + f32[128,64,112,112] %x, + f32[128,64,56,56] %y, + f32[] %z), + window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1}, + select=%ge, + scatter=%add +} + +// CHECK: func @"select-and-scatter"( +// CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[SRC:%.*]]: [[SRCT:.*]], [[CST:%.*]]: memref, [[RES:%.*]]: [[REST:.*]]) { +// 
CHECK: "xla_lhlo.select_and_scatter"([[ARG]], [[SRC]], [[CST]], [[RES]]) ( { +// CHECK: ^bb0([[LHS:%.*]]: memref, [[RHS:%.*]]: memref, +// CHECK-SAME: [[OUT:%.*]]: memref): +// CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]] +// CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]] +// CHECK: [[OUT_TENSOR:%.*]] = "xla_hlo.compare" +// CHECK-SAME: ([[LHS_TENSOR]], [[RHS_TENSOR]]) {comparison_direction = "GE"} +// CHECK: tensor_store [[OUT_TENSOR]], [[OUT]] +// CHECK: xla_lhlo.terminator +// CHECK: }, { +// CHECK: ^bb0([[LHS_:%.*]]: memref, [[RHS_:%.*]]: memref, +// CHECK-SAME: [[OUT_:%.*]]: memref): +// CHECK: [[LHS_TENSOR_:%.*]] = tensor_load [[LHS_]] +// CHECK: [[RHS_TENSOR_:%.*]] = tensor_load [[RHS_]] +// CHECK: [[OUT_TENSOR_:%.*]] = xla_hlo.add [[LHS_TENSOR_]], [[RHS_TENSOR_]] +// CHECK: tensor_store [[OUT_TENSOR_]], [[OUT_]] +// CHECK: xla_lhlo.terminator +// CHECK: }) { +// CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]> +// CHECK-SAME: window_strides = dense<[1, 1, 2, 2]> +// CHECK-SAME: } : ([[ARGT]], [[SRCT]], memref, [[REST]]) -> () From 86ece822c81e76961263d93aadf492d2d6602274 Mon Sep 17 00:00:00 2001 From: sunchenggen Date: Fri, 20 Mar 2020 20:54:12 +0800 Subject: [PATCH 311/492] add unit test for bugfix --- tensorflow/cc/framework/gradients_test.cc | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 26e3170ad8e..75291678177 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -503,6 +503,42 @@ TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) { EXPECT_EQ(grad_result[0].flat()(0), 17610.0f); } +TEST_F(GradientsTest, AddSymbolicGradientsTest) { + Scope scope = Scope::NewRootScope(); + for (int cnt = 0; cnt < 100; ++cnt) { + int N = 5 + rand() % 10; + // Construct forward graph. + OutputList inputs; + for (int i = 0; i < N; ++i) { + auto a = Const(scope, i, {1}); + inputs.push_back(a); + } + + auto pack = Stack(scope, inputs); + TF_ASSERT_OK(scope.status()); + + // Construct grad inputs. + OutputList output_grads; + Tensor ts(DT_INT32, {N, 1}); + auto v = ts.matrix(); + for (int i = 0; i < N; ++i) { + v(i, 0) = i; + } + auto dy = Const(scope, ts); + output_grads.push_back(dy); + // Call AddSymbolicGradients. + std::vector grad_outputs; + TF_ASSERT_OK(AddSymbolicGradients(scope, {pack.output}, inputs, + output_grads, &grad_outputs)); + ClientSession session((scope)); + std::vector in_grad; + TF_ASSERT_OK(session.Run(grad_outputs, &in_grad)); + for (int i = 0; i < N; ++i) { + test::ExpectTensorEqual(in_grad[i], test::AsTensor({i}, {1})); + } + } +} + // StopGradientSingleOutputMultiEdgeTest tests combinations of valid and // 'NoGradient' (induced by StopGradient op) returned along multiple edges from // a single nodes output. From b57d910db53d8f91c6a611b00a184e55fcaee06a Mon Sep 17 00:00:00 2001 From: Xunkai Zhang Date: Fri, 20 Mar 2020 05:52:10 -0700 Subject: [PATCH 312/492] Opensource TFLite Support codegen. 
PiperOrigin-RevId: 302011153 Change-Id: Idb2f649dc48fdc449fac2d6e9009719d29afb2ad --- .../lite/experimental/support/codegen/BUILD | 87 ++ .../experimental/support/codegen/README.md | 13 + .../support/codegen/android_java_generator.cc | 978 ++++++++++++++++++ .../support/codegen/android_java_generator.h | 107 ++ .../support/codegen/code_generator.cc | 179 ++++ .../support/codegen/code_generator.h | 80 ++ .../support/codegen/code_generator_test.cc | 126 +++ .../support/codegen/metadata_helper.cc | 92 ++ .../support/codegen/metadata_helper.h | 51 + .../experimental/support/codegen/python/BUILD | 38 + .../support/codegen/python/codegen.py | 96 ++ .../support/codegen/python/codegen_lib.cc | 49 + .../experimental/support/codegen/utils.cc | 194 ++++ .../lite/experimental/support/codegen/utils.h | 127 +++ .../support/codegen/utils_test.cc | 97 ++ 15 files changed, 2314 insertions(+) create mode 100644 tensorflow/lite/experimental/support/codegen/BUILD create mode 100644 tensorflow/lite/experimental/support/codegen/README.md create mode 100644 tensorflow/lite/experimental/support/codegen/android_java_generator.cc create mode 100644 tensorflow/lite/experimental/support/codegen/android_java_generator.h create mode 100644 tensorflow/lite/experimental/support/codegen/code_generator.cc create mode 100644 tensorflow/lite/experimental/support/codegen/code_generator.h create mode 100644 tensorflow/lite/experimental/support/codegen/code_generator_test.cc create mode 100644 tensorflow/lite/experimental/support/codegen/metadata_helper.cc create mode 100644 tensorflow/lite/experimental/support/codegen/metadata_helper.h create mode 100644 tensorflow/lite/experimental/support/codegen/python/BUILD create mode 100644 tensorflow/lite/experimental/support/codegen/python/codegen.py create mode 100644 tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc create mode 100644 tensorflow/lite/experimental/support/codegen/utils.cc create mode 100644 tensorflow/lite/experimental/support/codegen/utils.h create mode 100644 tensorflow/lite/experimental/support/codegen/utils_test.cc diff --git a/tensorflow/lite/experimental/support/codegen/BUILD b/tensorflow/lite/experimental/support/codegen/BUILD new file mode 100644 index 00000000000..96bb3e35952 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/BUILD @@ -0,0 +1,87 @@ +# The tools for generating wrapper classes for a TFLite model with metadata. 
+ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "utils", + srcs = [ + "utils.cc", + ], + hdrs = [ + "utils.h", + ], + deps = [ + ], +) + +cc_library( + name = "code_generator", + srcs = [ + "code_generator.cc", + ], + hdrs = [ + "code_generator.h", + ], + deps = [ + ":utils", + "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", + ], +) + +cc_library( + name = "metadata_helper", + srcs = [ + "metadata_helper.cc", + ], + hdrs = [ + "metadata_helper.h", + ], + deps = [ + ":utils", + "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_library( + name = "android_java_generator", + srcs = [ + "android_java_generator.cc", + ], + hdrs = [ + "android_java_generator.h", + ], + deps = [ + ":code_generator", + ":metadata_helper", + ":utils", + "//tensorflow/core/platform:logging", + "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "code_generator_test", + size = "small", + srcs = ["code_generator_test.cc"], + data = ["//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"], + deps = [ + ":code_generator", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + deps = [ + ":utils", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/support/codegen/README.md b/tensorflow/lite/experimental/support/codegen/README.md new file mode 100644 index 00000000000..425dab37b04 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/README.md @@ -0,0 +1,13 @@ +# TensorFlow Lite Android Wrapper Code Generator + +For TensorFlow Lite model enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md), +developers can use the TensorFlow Lite Android wrapper code generator to create +platform specific wrapper code. The wrapper code removes the need to interact +directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow +Lite model with typed objects such as `Bitmap` and `Rect`. + +The usefulness of the code generator depend on the completeness of the +TensorFlow Lite model's metadata entry. Refer to the `` section +under relevant fields in +[metadata_schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs), +to see how the codegen tool parses each field. diff --git a/tensorflow/lite/experimental/support/codegen/android_java_generator.cc b/tensorflow/lite/experimental/support/codegen/android_java_generator.cc new file mode 100644 index 00000000000..b16db570aaa --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/android_java_generator.cc @@ -0,0 +1,978 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h" + +#include + +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" +#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h" +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace { + +using details_android_java::ModelInfo; +using details_android_java::TensorInfo; + +// Helper class to organize the C++ code block as a generated code block. +// Using ctor and dtor to simulate an enter/exit schema like `with` in Python. +class AsBlock { + public: + AsBlock(CodeWriter* code_writer, const std::string& before, + bool trailing_blank_line = false) + : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) { + code_writer_->AppendNoNewLine(before); + code_writer_->Append(" {"); + code_writer_->Indent(); + } + ~AsBlock() { + code_writer_->Outdent(); + code_writer_->Append("}"); + if (trailing_blank_line_) { + code_writer_->NewLine(); + } + } + + private: + CodeWriter* code_writer_; + bool trailing_blank_line_; +}; + +// Declare the functions first, so that the functions can follow a logical +// order. +bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*); + +std::string GetModelVersionedName(const ModelMetadata* metadata) { + std::string model_name = "MyModel"; + if (metadata->name() != nullptr && !(metadata->name()->str().empty())) { + model_name = metadata->name()->str(); + } + std::string model_version = "unknown"; + if (metadata->version() != nullptr && !(metadata->version()->str().empty())) { + model_version = metadata->version()->str(); + } + return model_name + " (Version: " + model_version + ")"; +} + +TensorInfo CreateTensorInfo(const TensorMetadata* metadata, + const std::string& name, bool is_input, int index, + ErrorReporter* err) { + TensorInfo tensor_info; + std::string tensor_identifier = is_input ? "input" : "output"; + tensor_identifier += " " + std::to_string(index); + tensor_info.associated_axis_label_index = FindAssociatedFile( + metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err); + tensor_info.associated_value_label_index = FindAssociatedFile( + metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err); + if (is_input && (tensor_info.associated_axis_label_index >= 0 || + tensor_info.associated_value_label_index >= 0)) { + err->Warning( + "Found label file on input tensor (%s). Label file for input " + "tensor is not supported yet. The " + "file will be ignored.", + tensor_identifier.c_str()); + } + if (tensor_info.associated_axis_label_index >= 0 && + tensor_info.associated_value_label_index >= 0) { + err->Warning( + "Found both axis label file and value label file for tensor (%s), " + "which is not supported. 
Only the axis label file will be used.", + tensor_identifier.c_str()); + } + tensor_info.is_input = is_input; + tensor_info.name = SnakeCaseToCamelCase(name); + tensor_info.upper_camel_name = tensor_info.name; + tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]); + tensor_info.normalization_unit = + FindNormalizationUnit(metadata, tensor_identifier, err); + if (metadata->content()->content_properties_type() == + ContentProperties_ImageProperties) { + if (metadata->content() + ->content_properties_as_ImageProperties() + ->color_space() == ColorSpaceType_RGB) { + tensor_info.content_type = "image"; + tensor_info.wrapper_type = "TensorImage"; + tensor_info.processor_type = "ImageProcessor"; + return tensor_info; + } else { + err->Warning( + "Found Non-RGB image on tensor (%s). Codegen currently does not " + "support it, and regard it as a plain numeric tensor.", + tensor_identifier.c_str()); + } + } + tensor_info.content_type = "tensor"; + tensor_info.wrapper_type = "TensorBuffer"; + tensor_info.processor_type = "TensorProcessor"; + return tensor_info; +} + +ModelInfo CreateModelInfo(const ModelMetadata* metadata, + const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path, + ErrorReporter* err) { + ModelInfo model_info; + if (!CodeGenerator::VerifyMetadata(metadata, err)) { + // TODO(b/150116380): Create dummy model info. + err->Error("Validating metadata failed."); + return model_info; + } + model_info.package_name = package_name; + model_info.model_class_name = model_class_name; + model_info.model_asset_path = model_asset_path; + model_info.model_versioned_name = GetModelVersionedName(metadata); + const auto* graph = metadata->subgraph_metadata()->Get(0); + auto names = CodeGenerator::NameInputsAndOutputs( + graph->input_tensor_metadata(), graph->output_tensor_metadata()); + std::vector input_tensor_names = std::move(names.first); + std::vector output_tensor_names = std::move(names.second); + for (int i = 0; i < graph->input_tensor_metadata()->size(); i++) { + model_info.inputs.push_back( + CreateTensorInfo(graph->input_tensor_metadata()->Get(i), + input_tensor_names[i], true, i, err)); + } + for (int i = 0; i < graph->output_tensor_metadata()->size(); i++) { + model_info.outputs.push_back( + CreateTensorInfo(graph->output_tensor_metadata()->Get(i), + output_tensor_names[i], false, i, err)); + } + return model_info; +} + +void SetCodeWriterWithTensorInfo(CodeWriter* code_writer, + const TensorInfo& tensor_info) { + code_writer->SetTokenValue("NAME", tensor_info.name); + code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name); + code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type); + code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type); + std::string wrapper_name = tensor_info.wrapper_type; + wrapper_name[0] = tolower(wrapper_name[0]); + code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name); + code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type); + code_writer->SetTokenValue("NORMALIZATION_UNIT", + std::to_string(tensor_info.normalization_unit)); + code_writer->SetTokenValue( + "ASSOCIATED_AXIS_LABEL_INDEX", + std::to_string(tensor_info.associated_axis_label_index)); + code_writer->SetTokenValue( + "ASSOCIATED_VALUE_LABEL_INDEX", + std::to_string(tensor_info.associated_value_label_index)); +} + +void SetCodeWriterWithModelInfo(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->SetTokenValue("PACKAGE", model_info.package_name); + 
code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path); + code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name); +} + +constexpr char JAVA_DEFAULT_PACKAGE[] = "default"; + +std::string ConvertPackageToPath(const std::string& package) { + if (package == JAVA_DEFAULT_PACKAGE) { + return ""; + } + std::string path = package; + std::replace(path.begin(), path.end(), '.', '/'); + return path; +} + +bool IsImageUsed(const ModelInfo& model) { + for (const auto& input : model.inputs) { + if (input.content_type == "image") { + return true; + } + } + for (const auto& output : model.outputs) { + if (output.content_type == "image") { + return true; + } + } + return false; +} + +bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("// Generated by TFLite Support."); + code_writer->Append("package {{PACKAGE}};"); + code_writer->NewLine(); + + if (!GenerateWrapperImports(code_writer, model, err)) { + err->Error("Fail to generate imports for wrapper class."); + return false; + } + if (!GenerateWrapperClass(code_writer, model, err)) { + err->Error("Fail to generate wrapper class."); + return false; + } + code_writer->NewLine(); + return true; +} + +bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + const std::string support_pkg = "org.tensorflow.lite.support."; + std::vector imports{ + "android.content.Context", + "java.io.IOException", + "java.nio.ByteBuffer", + "java.nio.FloatBuffer", + "java.util.Arrays", + "java.util.HashMap", + "java.util.List", + "java.util.Map", + "org.checkerframework.checker.nullness.qual.Nullable", + "org.tensorflow.lite.DataType", + "org.tensorflow.lite.Tensor.QuantizationParams", + support_pkg + "common.FileUtil", + support_pkg + "common.TensorProcessor", + support_pkg + "common.ops.CastOp", + support_pkg + "common.ops.DequantizeOp", + support_pkg + "common.ops.NormalizeOp", + support_pkg + "common.ops.QuantizeOp", + support_pkg + "label.TensorLabel", + support_pkg + "metadata.MetadataExtractor", + support_pkg + "metadata.schema.NormalizationOptions", + support_pkg + "model.Model", + support_pkg + "model.Model.Device", + support_pkg + "tensorbuffer.TensorBuffer", + }; + if (IsImageUsed(model)) { + for (const auto& target : + {"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp", + "image.ops.ResizeOp.ResizeMethod"}) { + imports.push_back(support_pkg + target); + } + imports.push_back("android.graphics.Bitmap"); + } + + std::sort(imports.begin(), imports.end()); + for (const auto target : imports) { + code_writer->SetTokenValue("TARGET", target); + code_writer->Append("import {{TARGET}};"); + } + code_writer->NewLine(); + return true; +} + +bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->SetTokenValue("MODEL_VERSIONED_NAME", + model.model_versioned_name); + code_writer->Append( + R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)"); + const auto code_block = + AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}"); + code_writer->Append(R"(private final Metadata metadata; +private final Model model; +private static final String MODEL_NAME = "{{MODEL_PATH}}";)"); + for (const auto& tensor : model.outputs) { + if (tensor.associated_axis_label_index >= 0) { + code_writer->SetTokenValue("NAME", tensor.name); + code_writer->Append("private final List {{NAME}}Labels;"); + } + } + for (const auto& tensor : model.inputs) { + 
SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append( + "@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append( + "@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;"); + } + code_writer->NewLine(); + if (!GenerateWrapperInputs(code_writer, model, err)) { + err->Error("Failed to generate input classes"); + return false; + } + code_writer->NewLine(); + if (!GenerateWrapperOutputs(code_writer, model, err)) { + err->Error("Failed to generate output classes"); + return false; + } + code_writer->NewLine(); + if (!GenerateWrapperMetadata(code_writer, model, err)) { + err->Error("Failed to generate the metadata class"); + return false; + } + code_writer->NewLine(); + if (!GenerateWrapperAPI(code_writer, model, err)) { + err->Error("Failed to generate the common APIs"); + return false; + } + return true; +} + +bool GenerateWrapperInputs(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("/** Input wrapper of {@link {{MODEL_CLASS_NAME}}} */"); + auto class_block = AsBlock(code_writer, "public class Inputs"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append("private {{WRAPPER_TYPE}} {{NAME}};"); + } + code_writer->NewLine(); + // Ctor + { + auto ctor_block = AsBlock(code_writer, "public Inputs()"); + code_writer->Append( + "Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + if (tensor.content_type == "image") { + code_writer->Append( + "{{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());"); + } else { + code_writer->Append( + "{{NAME}} = " + "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), " + "metadata.get{{NAME_U}}Type());"); + } + } + } + for (const auto& tensor : model.inputs) { + code_writer->NewLine(); + SetCodeWriterWithTensorInfo(code_writer, tensor); + // Loaders + if (tensor.content_type == "image") { + { + auto bitmap_loader_block = + AsBlock(code_writer, "public void load{{NAME_U}}(Bitmap bitmap)"); + code_writer->Append(R"({{NAME}}.load(bitmap); +{{NAME}} = preprocess{{NAME_U}}({{NAME}});)"); + } + code_writer->NewLine(); + { + auto tensor_image_loader_block = AsBlock( + code_writer, "public void load{{NAME_U}}(TensorImage tensorImage)"); + code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorImage);"); + } + } else { // content_type == "FEATURE" or "UNKNOWN" + auto tensorbuffer_loader_block = AsBlock( + code_writer, "public void load{{NAME_U}}(TensorBuffer tensorBuffer)"); + code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorBuffer);"); + } + code_writer->NewLine(); + // Processor + code_writer->Append( + R"(private {{WRAPPER_TYPE}} preprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}}) { + if ({{NAME}}Preprocessor == null) { + return {{WRAPPER_NAME}}; + } + return {{NAME}}Preprocessor.process({{WRAPPER_NAME}}); +} +)"); + } + { + const auto get_buffer_block = AsBlock(code_writer, "Object[] getBuffer()"); + code_writer->AppendNoNewLine("return new Object[] {"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->AppendNoNewLine("{{NAME}}.getBuffer(), "); + } + code_writer->Backspace(2); + code_writer->Append("};"); + } + return true; +} + +bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model, + 
ErrorReporter* err) { + code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */"); + auto class_block = AsBlock(code_writer, "public class Outputs"); + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};"); + } + code_writer->NewLine(); + { + const auto ctor_block = AsBlock(code_writer, "public Outputs()"); + code_writer->Append( + "Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;"); + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + if (tensor.content_type == "image") { + code_writer->Append( + R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type()); +{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)"); + } else { // FEATURE, UNKNOWN + code_writer->Append( + "{{NAME}} = " + "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), " + "metadata.get{{NAME_U}}Type());"); + } + } + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->NewLine(); + if (tensor.associated_axis_label_index >= 0) { + if (tensor.content_type == "image") { + err->Warning( + "Axis label for images is not supported. The labels will " + "be ignored."); + } else { + code_writer->Append(R"(public Map get{{NAME_U}}() { + return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getMapWithFloatValue(); +})"); + } + } else { + code_writer->Append(R"(public {{WRAPPER_TYPE}} get{{NAME_U}}() { + return postprocess{{NAME_U}}({{NAME}}); +})"); + } + code_writer->NewLine(); + { + auto processor_block = + AsBlock(code_writer, + "private {{WRAPPER_TYPE}} " + "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})"); + code_writer->Append(R"(if ({{NAME}}Postprocessor == null) { + return {{WRAPPER_NAME}}; +} +return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)"); + } + } + code_writer->NewLine(); + { + const auto get_buffer_block = + AsBlock(code_writer, "Map getBuffer()"); + code_writer->Append("Map outputs = new HashMap<>();"); + for (int i = 0; i < model.outputs.size(); i++) { + SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());"); + } + code_writer->Append("return outputs;"); + } + return true; +} + +bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append( + "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */"); + const auto class_block = AsBlock(code_writer, "public static class Metadata"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"(private final int[] {{NAME}}Shape; +private final DataType {{NAME}}DataType; +private final QuantizationParams {{NAME}}QuantizationParams;)"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"(private final float[] {{NAME}}Mean; +private final float[] {{NAME}}Stddev;)"); + } + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"(private final int[] {{NAME}}Shape; +private final DataType {{NAME}}DataType; +private final QuantizationParams {{NAME}}QuantizationParams;)"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"(private final float[] {{NAME}}Mean; +private final float[] {{NAME}}Stddev;)"); + } + 
if (tensor.associated_axis_label_index >= 0 || + tensor.associated_value_label_index >= 0) { + code_writer->Append("private final List {{NAME}}Labels;"); + } + } + code_writer->NewLine(); + { + const auto ctor_block = AsBlock( + code_writer, + "public Metadata(ByteBuffer buffer, Model model) throws IOException"); + code_writer->Append( + "MetadataExtractor extractor = new MetadataExtractor(buffer);"); + for (int i = 0; i < model.inputs.size(); i++) { + SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append( + R"({{NAME}}Shape = extractor.getInputTensorShape({{ID}}); +{{NAME}}DataType = extractor.getInputTensorType({{ID}}); +{{NAME}}QuantizationParams = extractor.getInputTensorQuantizationParams({{ID}});)"); + if (model.inputs[i].normalization_unit >= 0) { + code_writer->Append( + R"(NormalizationOptions {{NAME}}NormalizationOptions = + (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions()); +FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer(); +{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()]; +{{NAME}}MeanBuffer.get({{NAME}}Mean); +FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer(); +{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()]; +{{NAME}}StddevBuffer.get({{NAME}}Stddev);)"); + } + } + for (int i = 0; i < model.outputs.size(); i++) { + SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append( + R"({{NAME}}Shape = model.getOutputTensorShape({{ID}}); +{{NAME}}DataType = extractor.getOutputTensorType({{ID}}); +{{NAME}}QuantizationParams = extractor.getOutputTensorQuantizationParams({{ID}});)"); + if (model.outputs[i].normalization_unit >= 0) { + code_writer->Append( + R"(NormalizationOptions {{NAME}}NormalizationOptions = + (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions()); +FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer(); +{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()]; +{{NAME}}MeanBuffer.get({{NAME}}Mean); +FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer(); +{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()]; +{{NAME}}StddevBuffer.get({{NAME}}Stddev);)"); + } + if (model.outputs[i].associated_axis_label_index >= 0) { + code_writer->Append(R"(String {{NAME}}LabelsFileName = + extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name(); +{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)"); + } else if (model.outputs[i].associated_value_label_index >= 0) { + code_writer->Append(R"(String {{NAME}}LabelsFileName = + extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name(); +{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)"); + } + } + } + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public int[] get{{NAME_U}}Shape() { + return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length); +} + +public DataType get{{NAME_U}}Type() { + return {{NAME}}DataType; +} + +public QuantizationParams get{{NAME_U}}QuantizationParams() { + 
return {{NAME}}QuantizationParams; +})"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"( +public float[] get{{NAME_U}}Mean() { + return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length); +} + +public float[] get{{NAME_U}}Stddev() { + return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length); +})"); + } + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public int[] get{{NAME_U}}Shape() { + return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length); +} + +public DataType get{{NAME_U}}Type() { + return {{NAME}}DataType; +} + +public QuantizationParams get{{NAME_U}}QuantizationParams() { + return {{NAME}}QuantizationParams; +})"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"( +public float[] get{{NAME_U}}Mean() { + return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length); +} + +public float[] get{{NAME_U}}Stddev() { + return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length); +})"); + } + if (tensor.associated_axis_label_index >= 0 || + tensor.associated_value_label_index >= 0) { + code_writer->Append(R"( +public List get{{NAME_U}}Labels() { + return {{NAME}}Labels; +})"); + } + } + return true; +} + +bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append(R"(public Metadata getMetadata() { + return metadata; +} +)"); + code_writer->Append(R"(/** + * Creates interpreter and loads associated files if needed. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public {{MODEL_CLASS_NAME}}(Context context) throws IOException { + this(context, MODEL_NAME, Device.CPU, 1); +} + +/** + * Creates interpreter and loads associated files if needed, but loading another model in the same + * input / output structure with the original one. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public {{MODEL_CLASS_NAME}}(Context context, String modelPath) throws IOException { + this(context, modelPath, Device.CPU, 1); +} + +/** + * Creates interpreter and loads associated files if needed, with device and number of threads + * configured. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) throws IOException { + this(context, MODEL_NAME, device, numThreads); +} + +/** + * Creates interpreter for a user-specified model. + * + * @throws IOException if an I/O error occurs when loading the tflite model. 
+ */ +public {{MODEL_CLASS_NAME}}(Context context, String modelPath, Device device, int numThreads) throws IOException { + model = new Model.Builder(context, modelPath).setDevice(device).setNumThreads(numThreads).build(); + metadata = new Metadata(model.getData(), model);)"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( + {{PROCESSOR_TYPE}}.Builder {{NAME}}PreprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder())"); + if (tensor.content_type == "image") { + code_writer->Append(R"( .add(new ResizeOp( + metadata.get{{NAME_U}}Shape()[1], + metadata.get{{NAME_U}}Shape()[2], + ResizeMethod.NEAREST_NEIGHBOR)))"); + } + if (tensor.normalization_unit >= 0) { + code_writer->Append( + R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))"); + } + code_writer->Append( + R"( .add(new QuantizeOp( + metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(), + metadata.get{{NAME_U}}QuantizationParams().getScale())) + .add(new CastOp(metadata.get{{NAME_U}}Type())); + {{NAME}}Preprocessor = {{NAME}}PreprocessorBuilder.build();)"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->AppendNoNewLine(R"( + {{PROCESSOR_TYPE}}.Builder {{NAME}}PostprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder() + .add(new DequantizeOp( + metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(), + metadata.get{{NAME_U}}QuantizationParams().getScale())))"); + if (tensor.normalization_unit >= 0) { + code_writer->AppendNoNewLine(R"( + .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))"); + } + code_writer->Append(R"(; + {{NAME}}Postprocessor = {{NAME}}PostprocessorBuilder.build();)"); + if (tensor.associated_axis_label_index >= 0) { + code_writer->Append(R"( + {{NAME}}Labels = metadata.get{{NAME_U}}Labels();)"); + } + } + code_writer->Append("}"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public void reset{{NAME_U}}Preprocessor(@Nullable {{PROCESSOR_TYPE}} processor) { + {{NAME}}Preprocessor = processor; +})"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public void reset{{NAME_U}}Postprocessor(@Nullable {{PROCESSOR_TYPE}} processor) { + {{NAME}}Postprocessor = processor; +})"); + } + code_writer->Append(R"( +/** Creates inputs */ +public Inputs createInputs() { + return new Inputs(); +} + +/** Triggers the model. */ +public Outputs run(Inputs inputs) { + Outputs outputs = new Outputs(); + model.run(inputs.getBuffer(), outputs.getBuffer()); + return outputs; +} + +/** Closes the model. 
*/ +public void close() { + model.close(); +})"); + return true; +} + +bool GenerateBuildGradleContent(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->Append(R"(buildscript { + repositories { + google() + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:3.2.1' + } +} + +allprojects { + repositories { + google() + jcenter() + flatDir { + dirs 'libs' + } + } +} + +apply plugin: 'com.android.library' + +android { + compileSdkVersion 29 + defaultConfig { + targetSdkVersion 29 + versionCode 1 + versionName "1.0" + } + aaptOptions { + noCompress "tflite" + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + lintOptions { + abortOnError false + } +} + +configurations { + libMetadata +} + +dependencies { + libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic' +} + +task downloadLibs(type: Sync) { + from configurations.libMetadata + into "$buildDir/libs" + rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar" +} + +preBuild.dependsOn downloadLibs + +dependencies { + compileOnly 'org.checkerframework:checker-qual:2.5.8' + api 'org.tensorflow:tensorflow-lite:0.0.0-nightly' + api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly' + api files("$buildDir/libs/tensorflow-lite-support-metadata.jar") + implementation 'org.apache.commons:commons-compress:1.19' +})"); + return true; +} + +bool GenerateAndroidManifestContent(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->Append(R"( + +)"); + return true; +} + +bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) { + code_writer->Append("# {{MODEL_CLASS_NAME}} Usage"); + code_writer->AppendNoNewLine(R"( +``` +import {{PACKAGE}}.{{MODEL_CLASS_NAME}}; + +// 1. Initialize the Model +{{MODEL_CLASS_NAME}} model = null; + +try { + model = new {{MODEL_CLASS_NAME}}(context); // android.content.Context + // Create the input container. + {{MODEL_CLASS_NAME}}.Inputs inputs = model.createInputs(); +} catch (IOException e) { + e.printStackTrace(); +} + +if (model != null) { + + // 2. Set the inputs)"); + for (const auto& t : model_info.inputs) { + SetCodeWriterWithTensorInfo(code_writer, t); + if (t.content_type == "image") { + code_writer->Append(R"( + // Load input tensor "{{NAME}}" from a Bitmap with ARGB_8888 format. + Bitmap bitmap = ...; + inputs.load{{NAME_U}}(bitmap); + // Alternatively, load the input tensor "{{NAME}}" from a TensorImage. + // Check out TensorImage documentation to load other image data structures. + // TensorImage tensorImage = ...; + // inputs.load{{NAME_U}}(tensorImage);)"); + } else { + code_writer->Append(R"( + // Load input tensor "{{NAME}}" from a TensorBuffer. + // Check out TensorBuffer documentation to load other data structures. + TensorBuffer tensorBuffer = ...; + inputs.load{{NAME_U}}(tensorBuffer);)"); + } + } + code_writer->Append(R"( + // 3. Run the model + {{MODEL_CLASS_NAME}}.Outputs outputs = model.run(inputs);)"); + code_writer->Append(R"( + // 4. 
Retrieve the results)"); + for (const auto& t : model_info.outputs) { + SetCodeWriterWithTensorInfo(code_writer, t); + if (t.associated_axis_label_index >= 0) { + code_writer->SetTokenValue("WRAPPER_TYPE", "Map"); + } + code_writer->Append( + R"( {{WRAPPER_TYPE}} {{NAME}} = outputs.get{{NAME_U}}();)"); + } + code_writer->Append(R"(} +```)"); + return true; +} + +GenerationResult::File GenerateWrapperFile(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto java_path = JoinPath(module_root, "src/main/java"); + const auto package_path = + JoinPath(java_path, ConvertPackageToPath(model_info.package_name)); + const auto file_path = + JoinPath(package_path, model_info.model_class_name + JAVA_EXT); + + CodeWriter code_writer(err); + code_writer.SetIndentString(" "); + SetCodeWriterWithModelInfo(&code_writer, model_info); + + if (!GenerateWrapperFileContent(&code_writer, model_info, err)) { + err->Error("Generating Java wrapper content failed."); + } + + const auto java_file = code_writer.ToString(); + return GenerationResult::File{file_path, java_file}; +} + +GenerationResult::File GenerateBuildGradle(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto file_path = JoinPath(module_root, "build.gradle"); + CodeWriter code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateBuildGradleContent(&code_writer, model_info)) { + err->Error("Generating build.gradle failed."); + } + const auto content = code_writer.ToString(); + return GenerationResult::File{file_path, content}; +} + +GenerationResult::File GenerateAndroidManifest(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml"); + CodeWriter code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateAndroidManifestContent(&code_writer, model_info)) { + err->Error("Generating AndroidManifest.xml failed."); + } + return GenerationResult::File{file_path, code_writer.ToString()}; +} + +GenerationResult::File GenerateDoc(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + std::string lower = model_info.model_class_name; + for (int i = 0; i < lower.length(); i++) { + lower[i] = std::tolower(lower[i]); + } + const auto file_path = JoinPath(module_root, lower + ".md"); + CodeWriter code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateDocContent(&code_writer, model_info)) { + err->Error("Generating doc failed."); + } + return GenerationResult::File{file_path, code_writer.ToString()}; +} + +} // namespace + +AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root) + : CodeGenerator(), module_root_(module_root) {} + +GenerationResult AndroidJavaGenerator::Generate( + const Model* model, const std::string& package_name, + const std::string& model_class_name, const std::string& model_asset_path) { + GenerationResult result; + const ModelMetadata* metadata = GetMetadataFromModel(model); + if (metadata == nullptr) { + err_.Error( + "Cannot find TFLite Metadata in the model. 
Codegen will generate " + "nothing."); + return result; + } + details_android_java::ModelInfo model_info = CreateModelInfo( + metadata, package_name, model_class_name, model_asset_path, &err_); + result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_)); + result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_)); + result.files.push_back( + GenerateAndroidManifest(module_root_, model_info, &err_)); + result.files.push_back(GenerateDoc(module_root_, model_info, &err_)); + return result; +} + +GenerationResult AndroidJavaGenerator::Generate( + const char* model_storage, const std::string& package_name, + const std::string& model_class_name, const std::string& model_asset_path) { + const Model* model = GetModel(model_storage); + return Generate(model, package_name, model_class_name, model_asset_path); +} + +std::string AndroidJavaGenerator::GetErrorMessage() { + return err_.GetMessage(); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/android_java_generator.h b/tensorflow/lite/experimental/support/codegen/android_java_generator.h new file mode 100644 index 00000000000..f8821a0de70 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/android_java_generator.h @@ -0,0 +1,107 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ + +#include +#include +#include + +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace details_android_java { + +/// The intermediate data structure for generating code from TensorMetadata. +/// Should only be used as const reference when created. +struct TensorInfo { + std::string name; + std::string upper_camel_name; + std::string content_type; + std::string wrapper_type; + std::string processor_type; + bool is_input; + /// Optional. Set to -1 if not applicable. + int normalization_unit; + /// Optional. Set to -1 if associated_axis_label is empty. + int associated_axis_label_index; + /// Optional. Set to -1 if associated_value_label is empty. + int associated_value_label_index; +}; + +/// The intermediate data structure for generating code from ModelMetadata. +/// Should only be used as const reference when created. 
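+/// Holds everything the generators need from the model: the target package
+/// and class names, the asset path, and the per-tensor TensorInfo entries
+/// for inputs and outputs.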
+struct ModelInfo { + std::string package_name; + std::string model_asset_path; + std::string model_class_name; + std::string model_versioned_name; + std::vector inputs; + std::vector outputs; +}; + +} // namespace details_android_java + +constexpr char JAVA_EXT[] = ".java"; + +/// Generates Android supporting codes and modules (in Java) based on TFLite +/// metadata. +class AndroidJavaGenerator : public CodeGenerator { + public: + /// Creates an AndroidJavaGenerator. + /// Args: + /// - module_root: The root of destination Java module. + explicit AndroidJavaGenerator(const std::string& module_root); + + /// Generates files. Returns the file paths and contents. + /// Args: + /// - model: The TFLite model with Metadata filled. + /// - package_name: The name of the Java package which generated classes + /// belong to. + /// - model_class_name: A readable name of the generated wrapper class, such + /// as "ImageClassifier", "MobileNetV2" or "MyModel". + /// - model_asset_path: The relevant path to the model file in the asset. + // TODO(b/141225157): Automatically generate model_class_name. + GenerationResult Generate(const Model* model, const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path); + + /// Generates files and returns the file paths and contents. + /// It's mostly identical with the previous one, but the model here is + /// provided as binary flatbuffer content without parsing. + GenerationResult Generate(const char* model_storage, + const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path); + + std::string GetErrorMessage(); + + private: + const std::string module_root_; + ErrorReporter err_; +}; + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ diff --git a/tensorflow/lite/experimental/support/codegen/code_generator.cc b/tensorflow/lite/experimental/support/codegen/code_generator.cc new file mode 100644 index 00000000000..687724815ef --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/code_generator.cc @@ -0,0 +1,179 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" + +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace { + +void ResolveConflictedNamesByAddingIndex(std::vector* names_ptr) { + auto& names = *names_ptr; + std::unordered_map indexes; + std::unordered_map first_appearance; + for (int i = 0; i < names.size(); i++) { + if (indexes.find(names[i]) == indexes.end()) { + indexes[names[i]] = 1; + first_appearance[names[i]] = i; + } else { + indexes[names[i]] += 1; + names[i].append(std::to_string(indexes[names[i]])); + } + } + for (const auto it : first_appearance) { + const auto& name = it.first; + const auto i = it.second; + if (indexes[name] > 1) { + names[i].append("1"); + } + } +} + +} // namespace + +CodeGenerator::CodeGenerator() {} + +bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata, + ErrorReporter* err) { + if (metadata == nullptr) { + err->Error("Loading nullptr is not allowed"); + return false; + } + if (metadata->subgraph_metadata()->size() != 1) { + err->Error("Only exact 1 subgraph is supported"); + return false; + } + return true; +} + +std::pair, std::vector> +CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs, + const TensorMetadataList* outputs) { + std::vector input_names; + std::vector output_names; + if (inputs != nullptr) { + input_names.reserve(inputs->size()); + for (const auto* tensor : *inputs) { + input_names.push_back(NameTensor(*tensor, "input")); + } + } + if (outputs != nullptr) { + output_names.reserve(outputs->size()); + for (const auto* tensor : *outputs) { + output_names.push_back(NameTensor(*tensor, "output")); + } + } + // Solve conflict + ResolveConflictedInputAndOutputNames(&input_names, &output_names); + return std::make_pair(input_names, output_names); +} + +std::string CodeGenerator::ConvertToValidName(const std::string& name) { + // lowercase all + std::string result = name; + for (int i = 0; i < result.size(); i++) { + result[i] = std::tolower(result[i]); + } + // replace all non-alpha or non-numeric with underscores, except underscore + // itself + for (int i = 0; i < result.size(); i++) { + if (result[i] != '_' && !std::isalnum(result[i])) { + result[i] = '_'; + } + } + // remove leading underscores + int leading_underscores = 0; + while (leading_underscores < result.size() && + result[leading_underscores] == '_') { + leading_underscores++; + } + result.erase(0, leading_underscores); + if (result.empty()) { + return ""; + } + // first char should be alpha + if (std::isalpha(result[0])) { + return result; + } + return "tensor_" + result; +} + +std::string CodeGenerator::NameTensor(const TensorMetadata& tensor, + const std::string& default_name) { + if (tensor.name() != nullptr && tensor.name()->size() > 0) { + // TODO(b/141225157) Validate tensor name. It should be in lower case. 
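+    // ConvertToValidName() lowercases the metadata name and turns any
+    // non-alphanumeric characters into underscores; an empty result means the
+    // name was unusable, so we fall through to content-based naming below.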
+ auto suggested_name = ConvertToValidName(tensor.name()->str()); + if (!suggested_name.empty()) { + return suggested_name; + } + } + auto* content = tensor.content(); + if (content == nullptr || content->content_properties() == nullptr) { + return default_name; + } + switch (content->content_properties_type()) { + case ContentProperties_ImageProperties: + return "image"; + case ContentProperties_FeatureProperties: + return "feature"; + default: + return default_name; + } +} + +void CodeGenerator::ResolveConflictedInputAndOutputNames( + std::vector* inputs, std::vector* outputs) { + std::unordered_set io_conflict; + auto& input_names = *inputs; + auto& output_names = *outputs; + for (const auto input : input_names) { + if (io_conflict.find(input) != io_conflict.end()) { + continue; + } + for (const auto output : output_names) { + if (input == output) { + io_conflict.insert(input); + break; + } + } + } + for (int i = 0; i < input_names.size(); i++) { + if (io_conflict.find(input_names[i]) != io_conflict.end()) { + input_names[i] = "input_" + input_names[i]; + } + } + for (int i = 0; i < output_names.size(); i++) { + if (io_conflict.find(output_names[i]) != io_conflict.end()) { + output_names[i] = "output_" + output_names[i]; + } + } + // 2. Second, add index if input[i] == input[j] + ResolveConflictedNamesByAddingIndex(&input_names); + ResolveConflictedNamesByAddingIndex(&output_names); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/code_generator.h b/tensorflow/lite/experimental/support/codegen/code_generator.h new file mode 100644 index 00000000000..5bb151e50a0 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/code_generator.h @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_ + +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +struct GenerationResult { + struct File { + std::string path; + std::string content; + }; + std::vector files; +}; + +/// Defines language-independent codegen strategies, like class naming, .etc. +/// Should not be used directly. +class CodeGenerator { + public: + CodeGenerator(); + + using TensorMetadataList = + typename flatbuffers::Vector>; + + virtual ~CodeGenerator() {} + + // Strategies. + /// Names all the IO tensors. It's useful when they don't have names, or the + /// names have conflicts. We have to name every tensor for code generation. + // TODO(b/141225157): Add reserved keywords check. 
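+  /// For example, if an input and an output are both named "image", they are
+  /// renamed "input_image" and "output_image"; duplicates within the same
+  /// list are then numbered, e.g. {"image", "image"} becomes
+  /// {"image1", "image2"}.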
+ static std::pair, std::vector> + NameInputsAndOutputs(const TensorMetadataList* inputs, + const TensorMetadataList* outputs); + + /// Loads a metadata for code generation. + /// Returns false if the metadata is not good for generation. + static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err); + + protected: + /// Converts a name into a valid form. Rules: + /// - lower all letters. + /// - replace all non alphabet nor numeric characters with underscores. + /// - remove prefix underscores. + /// - add prefix if the leading character is a number. + /// Returns empty string if not possible. + static std::string ConvertToValidName(const std::string& name); + static std::string NameTensor(const TensorMetadata& tensor, + const std::string& default_name); + static void ResolveConflictedInputAndOutputNames( + std::vector* input, std::vector* output); +}; + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_ diff --git a/tensorflow/lite/experimental/support/codegen/code_generator_test.cc b/tensorflow/lite/experimental/support/codegen/code_generator_test.cc new file mode 100644 index 00000000000..57c5cec60e4 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/code_generator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" + +#include +#include + +namespace tflite { +namespace support { +namespace codegen { +namespace { + +using ::testing::ElementsAreArray; + +class CodeGeneratorTest : public ::testing::Test { + public: + class TestingCodeGenerator : public CodeGenerator { + public: + explicit TestingCodeGenerator() : CodeGenerator() {} + + // Make tested method public. 
+ static std::string ConvertToValidName(const std::string& name) { + return CodeGenerator::ConvertToValidName(name); + } + static void ResolveConflictedInputAndOutputNames( + std::vector* input, std::vector* output) { + CodeGenerator::ResolveConflictedInputAndOutputNames(input, output); + } + }; +}; + +TEST_F(CodeGeneratorTest, UpperCasesShouldLower) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"), + "alphabetcool"); +} + +TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_"); +} + +TEST_F(CodeGeneratorTest, NoLeadingUnderscore) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z"); +} + +TEST_F(CodeGeneratorTest, NoLeadingNumbers) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"), + "tensor_3000_cool_tensors"); +} + +TEST_F(CodeGeneratorTest, TestSimpleIONames) { + std::vector inputs = {"image"}; + std::vector outputs = {"output"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"image"})); + EXPECT_THAT(outputs, ElementsAreArray({"output"})); +} + +TEST_F(CodeGeneratorTest, TestIOConflict) { + std::vector inputs = {"image"}; + std::vector outputs = {"image"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"input_image"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image"})); +} + +TEST_F(CodeGeneratorTest, TestInternalConflict) { + std::vector inputs = {"image", "image"}; + std::vector outputs = {"output"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflictNTo1) { + std::vector inputs = {"image", "image"}; + std::vector outputs = {"image"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflict) { + std::vector inputs = {"image", "audio", "image", "audio", + "audio"}; + std::vector outputs = {"image", "image", "audio", "feature", + "feature"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, + ElementsAreArray({"input_image1", "input_audio1", "input_image2", + "input_audio2", "input_audio3"})); + EXPECT_THAT(outputs, + ElementsAreArray({"output_image1", "output_image2", + "output_audio", "feature1", "feature2"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflictReversed) { + std::vector inputs = {"image", "image", "audio", "feature", + "feature"}; + std::vector outputs = {"image", "audio", "image", "audio", + "audio"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, + ElementsAreArray({"input_image1", "input_image2", "input_audio", + "feature1", "feature2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1", + "output_image2", "output_audio2", + "output_audio3"})); +} + +} // namespace +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/metadata_helper.cc b/tensorflow/lite/experimental/support/codegen/metadata_helper.cc new file mode 100644 index 00000000000..3fcc7aee3bf --- 
/dev/null +++ b/tensorflow/lite/experimental/support/codegen/metadata_helper.cc @@ -0,0 +1,92 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h" + +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +constexpr char BUFFER_KEY[] = "TFLITE_METADATA"; +const ModelMetadata* GetMetadataFromModel(const Model* model) { + if (model->metadata() == nullptr) { + return nullptr; + } + for (auto i = 0; i < model->metadata()->size(); i++) { + if (model->metadata()->Get(i)->name()->str() == BUFFER_KEY) { + const auto buffer_index = model->metadata()->Get(i)->buffer(); + const auto* buffer = model->buffers()->Get(buffer_index)->data()->data(); + return GetModelMetadata(buffer); + } + } + return nullptr; +} + +int FindAssociatedFile(const TensorMetadata* metadata, + const AssociatedFileType file_type, + const std::string& tensor_identifier, + ErrorReporter* err) { + int result = -1; + if (metadata->associated_files() == nullptr || + metadata->associated_files()->size() == 0) { + return result; + } + for (int i = 0; i < metadata->associated_files()->size(); i++) { + const auto* file_metadata = metadata->associated_files()->Get(i); + if (file_metadata->type() == file_type) { + if (result >= 0) { + err->Warning( + "Multiple associated file of type %d found on tensor %s. Only the " + "first one will be used.", + file_type, tensor_identifier.c_str()); + continue; + } + result = i; + } + } + return result; +} + +int FindNormalizationUnit(const TensorMetadata* metadata, + const std::string& tensor_identifier, + ErrorReporter* err) { + int result = -1; + if (metadata->process_units() == nullptr || + metadata->process_units()->size() == 0) { + return result; + } + for (int i = 0; i < metadata->process_units()->size(); i++) { + const auto* process_uint = metadata->process_units()->Get(i); + if (process_uint->options_type() == + ProcessUnitOptions_NormalizationOptions) { + if (result >= 0) { + err->Warning( + "Multiple normalization unit found in tensor %s. Only the first " + "one will be effective.", + tensor_identifier.c_str()); + continue; + } + result = i; + } + } + return result; +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/metadata_helper.h b/tensorflow/lite/experimental/support/codegen/metadata_helper.h new file mode 100644 index 00000000000..0d5e06b4506 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/metadata_helper.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_ + +#include + +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's +/// lifetime is scoped by the model. +/// Returns nullptr if we cannot find any metadata. +const ModelMetadata* GetMetadataFromModel(const Model* model); + +/// Finds an associated file from a TensorMetadata of certain type. If there're +/// multiple files meet the criteria, only the first one is used. If there's no +/// file meets the criteria, -1 will be returned. +int FindAssociatedFile(const TensorMetadata* metadata, + const AssociatedFileType file_type, + const std::string& tensor_identifier, + ErrorReporter* err); + +/// Find the first normalization unit. If none, return -1. +int FindNormalizationUnit(const TensorMetadata* metadata, + const std::string& tensor_identifier, + ErrorReporter* err); + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_ diff --git a/tensorflow/lite/experimental/support/codegen/python/BUILD b/tensorflow/lite/experimental/support/codegen/python/BUILD new file mode 100644 index 00000000000..d364d82eaeb --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/python/BUILD @@ -0,0 +1,38 @@ +load("//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "_pywrap_codegen", + srcs = [ + "codegen_lib.cc", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_codegen", + deps = [ + "//tensorflow/lite/experimental/support/codegen:android_java_generator", + "//tensorflow/lite/experimental/support/codegen:code_generator", + "//tensorflow/python:pybind11_lib", + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + +py_binary( + name = "codegen", + srcs = [ + "codegen.py", + ], + python_version = "PY3", + deps = [ + ":_pywrap_codegen", + "@absl_py//absl:app", + "@absl_py//absl/flags", + "@absl_py//absl/logging", + ], +) diff --git a/tensorflow/lite/experimental/support/codegen/python/codegen.py b/tensorflow/lite/experimental/support/codegen/python/codegen.py new file mode 100644 index 00000000000..f28bafe5cff --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/python/codegen.py @@ -0,0 +1,96 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generates Android Java sources from a TFLite model with metadata.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +from absl import app +from absl import flags +from absl import logging + +from tensorflow.lite.experimental.support.codegen.python import _pywrap_codegen + +FLAGS = flags.FLAGS + +flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.') +flags.DEFINE_string('destination', None, 'Path of destination of generation.') +flags.DEFINE_string('package_name', 'org.tensorflow.lite.support', + 'Name of generated java package to put the wrapper class.') +flags.DEFINE_string( + 'model_class_name', 'MyModel', + 'Name of generated wrapper class (should not contain package name).') +flags.DEFINE_string( + 'model_asset_path', '', + '(Optional) Path to the model in generated assets/ dir. If not set, ' + 'generator will use base name of input model.' +) + + +def get_model_buffer(path): + if not os.path.isfile(path): + logging.error('Cannot find model at path %s.', path) + with open(path, 'rb') as f: + buf = f.read() + return buf + + +def prepare_directory_for_file(file_path): + target_dir = os.path.dirname(file_path) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + return + if not os.path.isdir(target_dir): + logging.error('Cannot write to %s', target_dir) + + +def main(argv): + if len(argv) > 1: + logging.error('None flag arguments found: [%s]', ', '.join(argv[1:])) + + codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination) + model_buffer = get_model_buffer(FLAGS.model) + model_asset_path = FLAGS.model_asset_path + if not model_asset_path: + model_asset_path = os.path.basename(FLAGS.model) + result = codegen.generate(model_buffer, FLAGS.package_name, + FLAGS.model_class_name, model_asset_path) + error_message = codegen.get_error_message().strip() + if error_message: + logging.error(error_message) + if not result.files: + logging.error('Generation failed!') + return + + for each in result.files: + prepare_directory_for_file(each.path) + with open(each.path, 'w') as f: + f.write(each.content) + + logging.info('Generation succeeded!') + model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets', + model_asset_path) + prepare_directory_for_file(model_asset_path) + shutil.copy(FLAGS.model, model_asset_path) + logging.info('Model copied into assets!') + + +if __name__ == '__main__': + flags.mark_flag_as_required('model') + flags.mark_flag_as_required('destination') + app.run(main) diff --git a/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc b/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc new file mode 100644 index 00000000000..e3db29b1959 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "include/pybind11/detail/common.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" +#include "include/pybind11/stl.h" +#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h" +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" + +namespace tflite { +namespace support { +namespace codegen { + +template +using overload_cast_ = pybind11::detail::overload_cast_impl; + +PYBIND11_MODULE(_pywrap_codegen, m) { + pybind11::class_(m, "AndroidJavaGenerator") + .def(pybind11::init()) + .def("generate", + overload_cast_()( + &AndroidJavaGenerator::Generate)) + .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage); + pybind11::class_(m, "GenerationResult") + .def(pybind11::init<>()) + .def_readwrite("files", &GenerationResult::files); + pybind11::class_(m, "GenerationResultFile") + .def(pybind11::init<>()) + .def_readwrite("path", &GenerationResult::File::path) + .def_readwrite("content", &GenerationResult::File::content); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/utils.cc b/tensorflow/lite/experimental/support/codegen/utils.cc new file mode 100644 index 00000000000..394c147a33f --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/utils.cc @@ -0,0 +1,194 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/support/codegen/utils.h" + +#include + +namespace tflite { +namespace support { +namespace codegen { + +int ErrorReporter::Warning(const char* format, ...) { + va_list args; + va_start(args, format); + return Report("[WARN] ", format, args); +} + +int ErrorReporter::Error(const char* format, ...) 
{ + va_list args; + va_start(args, format); + return Report("[ERROR] ", format, args); +} + +int ErrorReporter::Report(const char* prefix, const char* format, + va_list args) { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << prefix << buf << std::endl; + return formatted; +} + +std::string ErrorReporter::GetMessage() { + std::string value = buffer_.str(); + buffer_.str(""); + return value; +} + +CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {} + +void CodeWriter::SetTokenValue(const std::string& token, + const std::string& value) { + value_map_[token] = value; +} + +const std::string CodeWriter::GetTokenValue(const std::string& token) const { + auto iter = value_map_.find(token); + if (iter == value_map_.end()) { + // Typically only Code Generator's call this function (or `Append`). It's + // their duty to make sure the token is valid, and requesting for an invalid + // token implicits flaws in the code generation logic. + err_->Error("Internal: Cannot find value with token '%s'", token.c_str()); + return ""; + } + return iter->second; +} + +void CodeWriter::SetIndentString(const std::string& indent_str) { + indent_str_ = indent_str; +} + +void CodeWriter::Indent() { indent_++; } + +void CodeWriter::Outdent() { indent_--; } + +std::string CodeWriter::GenerateIndent() const { + std::string res; + res.reserve(indent_str_.size() * indent_); + for (int i = 0; i < indent_; i++) { + res.append(indent_str_); + } + return res; +} + +void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); } + +void CodeWriter::AppendNoNewLine(const std::string& text) { + AppendInternal(text, false); +} + +void CodeWriter::AppendInternal(const std::string& text, bool newline) { + // Prefix indent + if ((buffer_.empty() // nothing in the buffer + || buffer_.back() == '\n') // is on new line + // is writing on current line + && (!text.empty() && text[0] != '\n' && text[0] != '\r')) { + buffer_.append(GenerateIndent()); + } + // State machine variables + bool in_token = false; + int i = 0; + // Rough memory reserve + buffer_.reserve(buffer_.size() + text.size()); + std::string token_buffer; + // A simple LL1 analysis + while (i < text.size()) { + char cur = text[i]; + char cur_next = i == text.size() - 1 ? '\0' : text[i + 1]; // Set guardian + if (in_token == false) { + if (cur == '{' && cur_next == '{') { // Enter token + in_token = true; + i += 2; + } else if (cur == '\n') { // We need to apply global indent here + buffer_.push_back(cur); + if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') { + buffer_.append(GenerateIndent()); + } + i += 1; + } else { + buffer_.push_back(cur); + i += 1; + } + } else { + if (cur == '}' && cur_next == '}') { // Close token + in_token = false; + const auto value = GetTokenValue(token_buffer); + buffer_.append(value); + token_buffer.clear(); + i += 2; + } else { + token_buffer.push_back(cur); + i += 1; + } + } + } + if (!token_buffer.empty()) { + // Typically only Code Generator's call this function. It's + // their duty to make sure the code (or template) has valid syntax, and + // unclosed "{{...}}" implicits severe error in the template. + err_->Error("Internal: Invalid template: {{token}} is not closed."); + } + if (newline) { + buffer_.push_back('\n'); + } +} + +void CodeWriter::NewLine() { Append(""); } + +void CodeWriter::Backspace(int n) { + buffer_.resize(buffer_.size() > n ? 
buffer_.size() - n : 0); +} + +std::string CodeWriter::ToString() const { return buffer_; } + +bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); } + +void CodeWriter::Clear() { + buffer_.clear(); + value_map_.clear(); + indent_ = 0; +} + +std::string SnakeCaseToCamelCase(const std::string& s) { + std::string t; + t.reserve(s.length()); + size_t i = 0; + // Note: Use simple string += for simplicity. + bool cap = false; + while (i < s.size()) { + const char c = s[i++]; + if (c == '_') { + cap = true; + } else if (cap) { + t += toupper(c); + cap = false; + } else { + t += c; + } + } + return t; +} + +std::string JoinPath(const std::string& a, const std::string& b) { + if (a.empty()) return b; + std::string a_fixed = a; + if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back(); + std::string b_fixed = b; + if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1); + return a_fixed + "/" + b_fixed; +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/utils.h b/tensorflow/lite/experimental/support/codegen/utils.h new file mode 100644 index 00000000000..17153bd6ad0 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/utils.h @@ -0,0 +1,127 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_ + +#include +#include +#include + +namespace tflite { +namespace support { +namespace codegen { + +/// Collects runtime error logs which could be showed later. +// TODO(b/150538286): Consider a better mechanism to simplify callsite code. +class ErrorReporter { + public: + int Warning(const char* format, ...); + int Error(const char* format, ...); + std::string GetMessage(); + + private: + int Report(const char* prefix, const char* format, va_list args); + std::stringstream buffer_; +}; + +/// Implements basic code generating with text templates. +/// +/// It could accept code templates and concatenate them into complete codes. A +/// template could contain named values. +/// +/// Example code: +/// CodeWriter code; +/// code.SetValue("NAME", "Foo"); +/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }"); +/// code.SetValue("NAME", "Bar"); +/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }"); +/// +/// Output: +/// void Foo() { printf("%s", "Foo"); } +/// void Bar() { printf("%s", "Bar"); } +class CodeWriter { + public: + explicit CodeWriter(ErrorReporter* err); + /// Sets value to a token. When generating code with template, a string in a + /// pair of {{ and }} will be regarded as a token and replaced with the + /// corresponding value in code generation. + /// It rewrites if the token already has a value. 
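+  /// For example, SetTokenValue("NAME", "Foo") makes a later "{{NAME}}"
+  /// expand to "Foo"; calling SetTokenValue("NAME", "Bar") afterwards changes
+  /// the expansion for subsequent Append() calls.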
+ void SetTokenValue(const std::string& token, const std::string& value); + + /// Gets the current value set on the given token. + const std::string GetTokenValue(const std::string& token) const; + + /// Sets the unit indent string. For example, in Java it should be " ". + void SetIndentString(const std::string& indent); + + /// Increases the indent by a unit (the string set in SetIndentString). + void Indent(); + + /// Decreases the indent by a unit (the string set in SetIndentString). + void Outdent(); + + /// Generates the indentation string. + std::string GenerateIndent() const; + + /// Appends a piece of template codes to the stream. Every named value will be + /// replaced via the real value. A new line will always be appended at the + /// end. + void Append(const std::string& text); + + /// Appends a piece of template codes to the stream. Same with `Append`, but a + /// new line will not be appended at the end. + void AppendNoNewLine(const std::string& text); + + /// Appends a new line to the stream. + void NewLine(); + + /// Deletes the last N charaters in the stream. If the stream has less than N + /// characters, deletes all. + void Backspace(int n); + + std::string ToString() const; + + /// Checks if the internal string stream is empty. Note: This method has + // overhead. + bool IsStreamEmpty() const; + + /// Clears all the internal string stream and value map. + void Clear(); + + private: + void AppendInternal(const std::string& text, bool newline); + + std::string indent_str_; + int indent_; + + std::map value_map_; + std::string buffer_; + + ErrorReporter* err_; +}; + +/// Converts foo_bar_name to fooBarName. It's callers duty to make sure given +/// string "s" is already in snake case; or unexpected behavior may occur. +std::string SnakeCaseToCamelCase(const std::string& s); + +/// Joins 2 parts of file path into one, connected by unix path seperator '/'. +/// It's callers duty to ensure the two parts are valid. +std::string JoinPath(const std::string& a, const std::string& b); + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_ diff --git a/tensorflow/lite/experimental/support/codegen/utils_test.cc b/tensorflow/lite/experimental/support/codegen/utils_test.cc new file mode 100644 index 00000000000..8cdb838129c --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/utils_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/utils.h" + +#include + +namespace tflite { +namespace support { +namespace codegen { +namespace { + +TEST(ErrorReporterTest, TestReportError) { + ErrorReporter err; + err.Error("some text"); + EXPECT_EQ(err.GetMessage(), "[ERROR] some text\n"); + EXPECT_EQ(err.GetMessage(), ""); +} + +TEST(CodeGeneratorTest, TestExample) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })"; + writer.Append(text); + writer.SetTokenValue("NAME", "Bar"); + writer.Append(text); + EXPECT_EQ( + "void Foo() { printf(\"%s\", \"Foo\"); }\n" + "void Bar() { printf(\"%s\", \"Bar\"); }\n", + writer.ToString()); +} + +TEST(CodeGeneratorTest, TestInexistentToken) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{name}}() {})"; + writer.Append(text); + EXPECT_EQ(err.GetMessage(), + "[ERROR] Internal: Cannot find value with token 'name'\n"); +} + +TEST(CodeGeneratorTest, TestUnclosedToken) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{NAME}() {})"; + writer.Append(text); + EXPECT_EQ(err.GetMessage(), + "[ERROR] Internal: Invalid template: {{token}} is not closed.\n"); +} + +TEST(CodeGeneratorTest, TestIndentControl) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetIndentString(" "); + writer.Indent(); + writer.AppendNoNewLine("abcde"); // Will indent + EXPECT_EQ(" abcde", writer.ToString()); + writer.Clear(); + writer.Indent(); + writer.AppendNoNewLine("abc\n\nde"); + // The blank line will not indent + EXPECT_EQ(" abc\n\n de", writer.ToString()); + writer.Clear(); + writer.Indent(); + writer.Append("abc"); + writer.Outdent(); + writer.AppendNoNewLine("def"); + EXPECT_EQ(" abc\ndef", writer.ToString()); +} + +TEST(CaseConversionTest, TestSnakeToCamel) { + EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel")); + EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel_")); + EXPECT_EQ("ImACamel", SnakeCaseToCamelCase("_im_a_camel")); + EXPECT_EQ("", SnakeCaseToCamelCase("_")); + EXPECT_EQ("camel", SnakeCaseToCamelCase("camel")); +} + +} // namespace +} // namespace codegen +} // namespace support +} // namespace tflite From 22b65b2d1b2abb60055966b7acd5d7042902c666 Mon Sep 17 00:00:00 2001 From: Leslie-Fang Date: Fri, 20 Mar 2020 21:09:29 +0800 Subject: [PATCH 313/492] add more unsupport type --- tensorflow/compiler/tf2xla/kernels/reduction_ops.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index bad324b1aca..f6cc2f008e8 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -89,7 +89,9 @@ class MaxOp : public XlaReductionOp { } Status TypeCheck(xla::PrimitiveType xla_reduction_type_) { - if (xla_reduction_type_ == xla::C64) { + if (xla_reduction_type_ == xla::C64 || xla_reduction_type_ == xla::C128 || + xla_reduction_type_ == xla::TUPLE || + xla_reduction_type_ == xla::OPAQUE_TYPE) { return errors::InvalidArgument( "Unsupported PrimitiveType in MaxOp: '", xla::PrimitiveType_Name(xla_reduction_type_), "'"); From ec799ba3e10e9ba1508f09c9bc0372795a28d691 Mon Sep 17 00:00:00 2001 From: Leslie-Fang Date: Fri, 20 
Mar 2020 22:06:54 +0800 Subject: [PATCH 314/492] fix the method name and input name --- tensorflow/compiler/tf2xla/kernels/reduction_ops.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index f6cc2f008e8..4f63c0d1b66 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -85,16 +85,16 @@ class MaxOp : public XlaReductionOp { public: explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) { - OP_REQUIRES_OK(ctx, TypeCheck(xla_reduction_type_)); + OP_REQUIRES_OK(ctx, PrimitiveTypeCheck(xla_reduction_type_)); } - Status TypeCheck(xla::PrimitiveType xla_reduction_type_) { - if (xla_reduction_type_ == xla::C64 || xla_reduction_type_ == xla::C128 || - xla_reduction_type_ == xla::TUPLE || - xla_reduction_type_ == xla::OPAQUE_TYPE) { + static Status PrimitiveTypeCheck(xla::PrimitiveType xla_reduction_type) { + if (xla_reduction_type == xla::C64 || xla_reduction_type == xla::C128 || + xla_reduction_type == xla::TUPLE || + xla_reduction_type == xla::OPAQUE_TYPE) { return errors::InvalidArgument( "Unsupported PrimitiveType in MaxOp: '", - xla::PrimitiveType_Name(xla_reduction_type_), "'"); + xla::PrimitiveType_Name(xla_reduction_type), "'"); } else { return Status::OK(); } From b7855c709849aff9fd788f61d16f96cff9bdcfbf Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 20 Mar 2020 08:53:16 -0700 Subject: [PATCH 315/492] Disable the thread sanitizer for some tests until they get fixed. PiperOrigin-RevId: 302036289 Change-Id: Ib24ae9ad19048dd9b37871efa8d3a6ab32579326 --- tensorflow/compiler/tests/BUILD | 10 ---------- tensorflow/compiler/xrt/tests/BUILD | 4 ---- 2 files changed, 14 deletions(-) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 3018fb5f857..77cd3dc074c 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -346,8 +346,6 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - # TODO(b/151948649): Fails on 2020-03-19. - "notsan", ], deps = [ ":xla_test", @@ -914,8 +912,6 @@ tf_xla_py_test( shard_count = 10, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - # TODO(b/151948649): Fails on 2020-03-19. - "notsan", ], deps = [ ":xla_test", @@ -1552,8 +1548,6 @@ cuda_py_test( tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", - # TODO(b/151948649): Fails on 2020-03-19. - "notsan", ], xla_enable_strict_auto_jit = False, xla_enabled = True, @@ -1579,8 +1573,6 @@ cuda_py_test( tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", - # TODO(b/151948649): Fails on 2020-03-19. - "notsan", ], xla_enable_strict_auto_jit = False, xla_enabled = True, @@ -1772,8 +1764,6 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - # TODO(b/151948649): Fails on 2020-03-19. 
- "notsan", ], deps = [ ":xla_test", diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index 918c802604a..2f1faf1cdf1 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -58,10 +58,6 @@ tf_cc_test( "--xla_test_device=XLA_CPU", "--xla_platform=CPU", ], - tags = [ - # TODO(b/151948649): Fails on 2020-03-19. - "notsan", - ], deps = [ ":raw_api_test_lib", "//tensorflow/compiler/jit:xla_cpu_device", From bb97495f7700b1d87fede5f35489628cba327660 Mon Sep 17 00:00:00 2001 From: Martin Wicke Date: Fri, 20 Mar 2020 08:59:37 -0700 Subject: [PATCH 316/492] Adding test for BaseResourceVariable.__repr__ *ding ding* shame shame shame *ding ding* PiperOrigin-RevId: 302037315 Change-Id: Ifddf2475af96487b522d6b89ae307bf9bf46d5a9 --- .../resource_variable_ops_test.py | 20 +++++++++++++++++++ .../python/ops/resource_variable_ops.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index cbd8f6a2ebe..41ce9eb8a57 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -208,6 +208,26 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, resource_variable_ops.assign_variable_op( handle, constant_op.constant([1.], dtype=dtypes.float32)) + def testRepr(self): + with context.eager_mode(): + v = resource_variable_ops.ResourceVariable(1) + text = "%r" % v + self.assertEqual( + "", text) + + def testReprUnavailable(self): + with context.eager_mode(): + v = resource_variable_ops.ResourceVariable(1) + + # Monkey-patch this variable to not have an available value + def broken_read(): + raise ValueError("This doesn't work") + + v.read_value = broken_read + text = "%r" % v + self.assertEqual(">", text) + def testUnprintableHandle(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 591e0f5786b..f99f886f210 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -1355,7 +1355,7 @@ class ResourceVariable(BaseResourceVariable): which is the initial value for the Variable. Can also be a callable with no argument that returns the initial value when called. (Note that initializer functions from init_ops.py must first be bound - to a shape before being used here.) + to a shape before being used here.) trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. 
From d46aa971be6e7d5a2ec2b9029f38c24bdfb8c277 Mon Sep 17 00:00:00 2001 From: Sachin Joglekar Date: Fri, 20 Mar 2020 09:07:06 -0700 Subject: [PATCH 317/492] Adds GraphTransformation to add QuantizeAndDequantize nodes in GPU graph PiperOrigin-RevId: 302038856 Change-Id: I009684ea5b611a3bfc05c88b4fd8a40c570cfd86 --- tensorflow/lite/delegates/gpu/common/BUILD | 1 + tensorflow/lite/delegates/gpu/common/model.h | 10 ++ .../gpu/common/transformations/BUILD | 31 ++++ .../transformations/add_quant_adjustments.cc | 110 ++++++++++++ .../transformations/add_quant_adjustments.h | 45 +++++ .../add_quant_adjustments_test.cc | 166 ++++++++++++++++++ 6 files changed, 363 insertions(+) create mode 100644 tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc create mode 100644 tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h create mode 100644 tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 08945c70d0b..08612e37b3e 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -92,6 +92,7 @@ cc_library( ":tensor", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:any", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/lite/delegates/gpu/common/model.h b/tensorflow/lite/delegates/gpu/common/model.h index f5aad207168..6989584a24c 100644 --- a/tensorflow/lite/delegates/gpu/common/model.h +++ b/tensorflow/lite/delegates/gpu/common/model.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/any.h" +#include "absl/types/optional.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -39,6 +40,13 @@ using ValueId = uint32_t; using NodeId = uint32_t; +// Used to emulate quantized behavior. +struct QuantizationParams { + float min = 0; + float max = 0; + float scale = 0; +}; + // Connects tensor's producer and operation that depends on this tensor. 
template struct Value { @@ -47,6 +55,8 @@ struct Value { const ValueId id; TensorType tensor; + + absl::optional quant_params; }; struct Operation { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/BUILD b/tensorflow/lite/delegates/gpu/common/transformations/BUILD index d0411473fae..3fe22f540ad 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/BUILD +++ b/tensorflow/lite/delegates/gpu/common/transformations/BUILD @@ -19,6 +19,37 @@ cc_library( ], ) +cc_library( + name = "add_quant_adjustments", + srcs = ["add_quant_adjustments.cc"], + hdrs = ["add_quant_adjustments.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + ], +) + +cc_test( + name = "add_quant_adjustments_test", + srcs = ["add_quant_adjustments_test.cc"], + deps = [ + ":add_quant_adjustments", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "@com_google_absl//absl/types:any", + "@com_google_absl//absl/types:optional", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "fuse_add_to_conv", srcs = ["fuse_add_to_conv.cc"], diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc new file mode 100644 index 00000000000..872c4bcd903 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { + +class AddQuantAdjustments : public NodeTransformation { + public: + TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final { + if (node->operation.type == + ToString(OperationType::QUANTIZE_AND_DEQUANTIZE)) { + return {TransformStatus::SKIPPED, ""}; + } + + bool transform_applied = false; + auto node_outputs = graph->FindOutputs(node->id); + for (auto output_value : node_outputs) { + // Skip if quantization doesn't apply. 
+ if (!output_value->quant_params) continue; + auto consumers = graph->FindConsumers(output_value->id); + // No need to do anything if this isn't consumed by another node. + if (consumers.empty()) { + continue; + } + + // Add a new QuantizeAndDequantize node. + auto* quant_and_dequant_node = graph->NewNode(); + quant_and_dequant_node->operation.type = + ToString(OperationType::QUANTIZE_AND_DEQUANTIZE); + QuantizeAndDequantizeAttributes attr; + attr.min = output_value->quant_params.value().min; + attr.max = output_value->quant_params.value().max; + attr.scale = output_value->quant_params.value().scale; + quant_and_dequant_node->operation.attributes = attr; + + // Add one output Value for the new node. + // The tensor information should rename the same. + Value>* adjusted_value = graph->NewValue(); + adjusted_value->tensor = output_value->tensor; + Status status = + graph->SetProducer(quant_and_dequant_node->id, adjusted_value->id); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Could not create QuantizeAndDequantize node."}; + } + + // Replace output_value with adjusted_value on all consumers. + for (auto& consumer : consumers) { + status = graph->ReplaceInput(consumer->id, output_value->id, + adjusted_value->id); + if (!status.ok()) { + return {TransformStatus::INVALID, + absl::StrCat( + "Failed to associate quant-adjusted value for consumer: ", + status.message())}; + } + } + + // Add QuantizeAndDequantize node as a consumer of output_value. + status = graph->AddConsumer(quant_and_dequant_node->id, output_value->id); + if (!status.ok()) { + return {TransformStatus::INVALID, + absl::StrCat( + "Could not associate output to QuantizeAndDequantize: ", + status.message())}; + } + + // Remove quant params on output_value, to make the transformation + // idempotent. + output_value->quant_params.reset(); + transform_applied = true; + } + + if (transform_applied) { + return {TransformStatus::APPLIED, ""}; + } + return {TransformStatus::SKIPPED, ""}; + } +}; + +std::unique_ptr NewAddQuantAdjustments() { + return absl::make_unique(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h new file mode 100644 index 00000000000..6eb4aaaf029 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_QUANT_ADJUSTMENTS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_QUANT_ADJUSTMENTS_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +// This pass is used to support inference on quantized models with the GPU +// delegate. 
+// +// When delegating quantized models, we still run float-point inference on GPU +// under-the-hood. This is done by dequantizing inputs (at runtime) & constants +// (during delegation). +// However, intermediate tensors can still deviate from the original quantized +// inference, since activations may not follow the attributes set by the +// original quantizion parameters. +// To prevent this, we add "QuantizeAndDequantize" nodes for each node-output +// that was originally fixed-point: +// op1 -> op2 +// becomes +// op1 -> QuantizeAndDequantize -> op2 +std::unique_ptr NewAddQuantAdjustments(); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_QUANT_ADJUSTMENTS_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc new file mode 100644 index 00000000000..fc0913d2494 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h" + +#include +#include +#include "absl/types/any.h" +#include "absl/types/optional.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +namespace tflite { +namespace gpu { +namespace { + +void AddQuantParams(absl::optional* params, float min, + float max, float scale) { + params->emplace(); + params->value().min = min; + params->value().max = max; + params->value().scale = scale; +} + +// Scenario: +// -> Add -> +// +// Since there is only one node output with no consumers, no new node should be +// added. 
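As background for the quantization parameters exercised in the tests that follow: for 8-bit affine quantization the scale is tied to the real-valued range, roughly scale = (max - min) / 255, which is where values such as 0.004 (range 0..1) and 0.008 (range 0..2 or -1..1) come from. A tiny illustrative helper, not part of the transformation itself (the helper name is made up):

    // Illustrative only: the approximate float step size of an 8-bit
    // quantized range. E.g. min = 0.0f, max = 1.0f gives ~0.0039, which the
    // tests below round to 0.004.
    float ApproxScaleForInt8Range(float min, float max) {
      return (max - min) / 255.0f;
    }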
+TEST(AddQuantAdjustments, OneNode) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 8); + AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/1.0, + /*scale=*/0.004); + + Tensor add_tensor; + add_tensor.shape = Linear(8); + add_tensor.data.resize(8); + AddAttributes add_attr; + add_attr.param = add_tensor; + auto add_node = graph.NewNode(); + add_node->operation.type = ToString(OperationType::ADD); + add_node->operation.attributes = add_attr; + + ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok()); + + Value>* output; + AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/2.0, + /*scale=*/0.008); + ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); + output->tensor.shape = BHWC(1, 4, 4, 8); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + + auto transformation = NewAddQuantAdjustments(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("add_quant_adjustments", transformation.get()); + + EXPECT_EQ(1, graph.nodes().size()); + EXPECT_EQ(2, graph.values().size()); +} + +// Scenario: +// -> Add -> QuantizeAndDequantize -> Add -> +// | ^ +// | | +// ------------------------------ +// +// A new QuantizeAndDequantize should only be added after the left/first 'Add' +// op, and it should connect to both its consumers. +TEST(AddQuantAdjustments, GeneralCase) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 8); + AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/1.0, + /*scale=*/0.004); + + // First Add. + Tensor add_tensor; + add_tensor.shape = Linear(8); + add_tensor.data.resize(8); + AddAttributes add_attr; + add_attr.param = add_tensor; + auto add1_node = graph.NewNode(); + add1_node->operation.type = ToString(OperationType::ADD); + add1_node->operation.attributes = add_attr; + // QuantizeAndDequantize. + QuantizeAndDequantizeAttributes quant_attr; + quant_attr.min = -1.0; + quant_attr.max = 1.0; + quant_attr.scale = 0.008; + auto quant_node = graph.NewNode(); + quant_node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE); + quant_node->operation.attributes = quant_attr; + // Second Add. + auto add2_node = graph.NewNode(); + add2_node->operation.type = ToString(OperationType::ADD); + + // Connections. 
+ ASSERT_TRUE(graph.AddConsumer(add1_node->id, input->id).ok()); + Value>* link1; + ASSERT_TRUE(ConnectTwoNodes(&graph, add1_node, quant_node, &link1).ok()); + AddQuantParams(&link1->quant_params, /*min=*/0.0, /*max=*/2.0, + /*scale=*/0.008); + link1->tensor.shape = BHWC(1, 4, 4, 8); + ASSERT_TRUE(graph.AddConsumer(add2_node->id, link1->id).ok()); + Value>* link2; + ASSERT_TRUE(ConnectTwoNodes(&graph, quant_node, add2_node, &link2).ok()); + AddQuantParams(&link2->quant_params, /*min=*/-1.0, /*max=*/1.0, + /*scale=*/0.008); + link2->tensor.shape = BHWC(1, 4, 4, 8); + Value>* output; + ASSERT_TRUE(AddOutput(&graph, add2_node, &output).ok()); + AddQuantParams(&output->quant_params, /*min=*/-1.0, /*max=*/1.0, + /*scale=*/0.008); + output->tensor.shape = BHWC(1, 4, 4, 8); + + ASSERT_EQ(3, graph.nodes().size()); + ASSERT_EQ(4, graph.values().size()); + + auto transformation = NewAddQuantAdjustments(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("add_quant_adjustments", transformation.get()); + + EXPECT_EQ(4, graph.nodes().size()); + EXPECT_EQ(5, graph.values().size()); + EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[0]->operation.type); + EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE), + graph.nodes()[1]->operation.type); + EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[2]->operation.type); + EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE), + graph.nodes()[3]->operation.type); + auto new_quant_attr = absl::any_cast( + graph.nodes()[3]->operation.attributes); + EXPECT_EQ(0.0, new_quant_attr.min); + EXPECT_EQ(2.0, new_quant_attr.max); + const auto& new_quant_consumers = graph.FindConsumers(graph.values()[4]->id); + EXPECT_EQ(2, new_quant_consumers.size()); + EXPECT_EQ(quant_node, new_quant_consumers[0]); + EXPECT_EQ(add2_node, new_quant_consumers[1]); + + // Transformation should be idempotent. + transformer.Apply("add_quant_adjustments", transformation.get()); + EXPECT_EQ(4, graph.nodes().size()); + EXPECT_EQ(5, graph.values().size()); +} + +} // namespace +} // namespace gpu +} // namespace tflite From e9650ec721ad94f08cf10242a36bd40fa44d1fe4 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Fri, 20 Mar 2020 09:27:53 -0700 Subject: [PATCH 318/492] Temporarily disable elemental_ir_emitter_test on windows PiperOrigin-RevId: 302042189 Change-Id: Ibdc03fd9da39e6d2df0398153953ed6aa9622515 --- tensorflow/compiler/xla/service/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 925afd689f7..6d470149ca8 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3744,6 +3744,9 @@ xla_test( "cpu", "gpu", ], + tags = [ + "no_windows", # TODO(b/152037541) + ], deps = [ ":hlo_parser", "//tensorflow/compiler/xla:execution_options_util", From 5cc23b291b5984098eda6153c58312b2004ce68b Mon Sep 17 00:00:00 2001 From: Cesar Crusius Date: Fri, 20 Mar 2020 09:55:57 -0700 Subject: [PATCH 319/492] Intenral Copybara change. 
PiperOrigin-RevId: 302046979 Change-Id: Ieddde65749676169894c7fa1a01ab21bb779b4e7 --- tensorflow/core/kernels/eigen_contraction_kernel.cc | 10 ++++++++-- tensorflow/core/platform/strcat.h | 2 +- tensorflow/lite/delegates/gpu/gl/kernels/resize.cc | 2 +- tensorflow/lite/delegates/gpu/metal/kernels/resize.cc | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/eigen_contraction_kernel.cc b/tensorflow/core/kernels/eigen_contraction_kernel.cc index aa6cb4b9cb9..4959651569c 100644 --- a/tensorflow/core/kernels/eigen_contraction_kernel.cc +++ b/tensorflow/core/kernels/eigen_contraction_kernel.cc @@ -28,7 +28,9 @@ limitations under the License. // the configuration through the environment variable. // // Example: -// bazel test --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test +// bazel test \ +// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ +// //path/to:test #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) @@ -37,7 +39,11 @@ namespace internal { // TODO(ezhulenev): This is a temporary workaround for disabling custom kernels // at runtime in tests. We should always rely on compile time flags for that. -// Example: ... --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test +// +// Example: +// bazel test \ +// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ +// //path/to:test EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels() { static bool use_custom_contraction_kernel = true; diff --git a/tensorflow/core/platform/strcat.h b/tensorflow/core/platform/strcat.h index 6b435dceca3..640355c9ea5 100644 --- a/tensorflow/core/platform/strcat.h +++ b/tensorflow/core/platform/strcat.h @@ -33,7 +33,7 @@ limitations under the License. // to your function, your callers will automatically convert bools, integers, // and floating point values to strings for you. // -// NOTE: Use of AlphaNum outside of the //strings package is unsupported except +// NOTE: Use of AlphaNum outside of the "strings" package is unsupported except // for the specific case of function parameters of type "AlphaNum" or "const // AlphaNum &". In particular, instantiating AlphaNum directly as a stack // variable is not supported. 
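As a side note on the AlphaNum contract restated in the strcat.h hunk above: callers hand heterogeneous values straight to StrCat, and conversion to text happens through temporary AlphaNum parameters, so no manual number-to-string calls are needed. A minimal sketch of the intended call pattern (illustrative only; the DescribeTensor caller is made up, and only the strcat.h header shown here is assumed):

    #include <string>

    #include "tensorflow/core/platform/strcat.h"

    // Ints and floats are accepted directly as AlphaNum parameters and
    // converted to their textual form, per the header comment above.
    std::string DescribeTensor(int rank, float scale) {
      return tensorflow::strings::StrCat("rank=", rank, ", scale=", scale);
    }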
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc index b8949e41426..33d59518987 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc @@ -93,7 +93,7 @@ class Resize : public NodeShader { st.xy = max(icoord_floor, ivec2(0, 0)); st.zw = min(icoord_floor + ivec2(1, 1), borders); - vec2 t = coord - coord_floor; //interpolating factors + vec2 t = coord - coord_floor; // interpolating factors vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$; vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc index 2ed75ad65b1..24d7bcf13bc 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc @@ -54,7 +54,7 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) { int4 st; st.xy = max(itex_coord_floor, int2(0, 0)); st.zw = min(itex_coord_floor + int2(1, 1), borders); - const float2 t = tex_coord - tex_coord_floor; //interpolating factors + const float2 t = tex_coord - tex_coord_floor; // interpolating factors const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x; const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z; const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x; From 73e780cbc2263130b3fc562e1f0e7bd5695e4936 Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Fri, 20 Mar 2020 10:24:02 -0700 Subject: [PATCH 320/492] Fix inference_interface tests PiperOrigin-RevId: 302052660 Change-Id: Id1a5adb077d1736d8f79b4d9b98ed3e1ff58126d --- tensorflow/tools/android/inference_interface/BUILD | 5 +++++ .../contrib/android/TensorFlowInferenceInterface.java | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/tools/android/inference_interface/BUILD b/tensorflow/tools/android/inference_interface/BUILD index d82d932c664..cbd161f05b3 100644 --- a/tensorflow/tools/android/inference_interface/BUILD +++ b/tensorflow/tools/android/inference_interface/BUILD @@ -87,3 +87,8 @@ cc_binary( LINKER_SCRIPT, ], ) + +cc_library( + name = "android_tensorflow_inference_native", + srcs = if_android([":libtensorflow_inference.so"]), +) diff --git a/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index abddadac5bc..618c772e92d 100644 --- a/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -543,7 +543,8 @@ public class TensorFlowInferenceInterface { } catch (UnsatisfiedLinkError e2) { throw new RuntimeException( "Native TF methods not found; check that the correct native" - + " libraries are present in the APK."); + + " libraries are present in the APK: " + + e2); } } } From ac8be4e7d0dddef822c240f1900df510b5a4d99f Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 20 Mar 2020 10:42:54 -0700 Subject: [PATCH 321/492] Internal change PiperOrigin-RevId: 302056585 Change-Id: I515a51033729ef52f8e47a203f9402769575ec4b --- tensorflow/core/kernels/eigen_contraction_kernel.cc | 10 ++-------- tensorflow/core/platform/strcat.h | 2 +- tensorflow/lite/delegates/gpu/gl/kernels/resize.cc | 2 +- tensorflow/lite/delegates/gpu/metal/kernels/resize.cc | 2 +- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/kernels/eigen_contraction_kernel.cc b/tensorflow/core/kernels/eigen_contraction_kernel.cc index 4959651569c..aa6cb4b9cb9 100644 --- a/tensorflow/core/kernels/eigen_contraction_kernel.cc +++ b/tensorflow/core/kernels/eigen_contraction_kernel.cc @@ -28,9 +28,7 @@ limitations under the License. // the configuration through the environment variable. // // Example: -// bazel test \ -// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ -// //path/to:test +// bazel test --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) @@ -39,11 +37,7 @@ namespace internal { // TODO(ezhulenev): This is a temporary workaround for disabling custom kernels // at runtime in tests. We should always rely on compile time flags for that. -// -// Example: -// bazel test \ -// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ -// //path/to:test +// Example: ... --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels() { static bool use_custom_contraction_kernel = true; diff --git a/tensorflow/core/platform/strcat.h b/tensorflow/core/platform/strcat.h index 640355c9ea5..6b435dceca3 100644 --- a/tensorflow/core/platform/strcat.h +++ b/tensorflow/core/platform/strcat.h @@ -33,7 +33,7 @@ limitations under the License. // to your function, your callers will automatically convert bools, integers, // and floating point values to strings for you. // -// NOTE: Use of AlphaNum outside of the "strings" package is unsupported except +// NOTE: Use of AlphaNum outside of the //strings package is unsupported except // for the specific case of function parameters of type "AlphaNum" or "const // AlphaNum &". In particular, instantiating AlphaNum directly as a stack // variable is not supported. 
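Returning to the TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL comment in eigen_contraction_kernel.cc above: the runtime switch it describes amounts to consulting an environment variable once and caching the answer. A generic sketch of that pattern, using plain std::getenv rather than the exact helper TensorFlow uses internally (the function name is illustrative):

    #include <cstdlib>
    #include <cstring>

    // Illustrative: cache a boolean toggle read from the environment.
    // Defaults to true when the variable is unset, mirroring the
    // use_custom_contraction_kernel default visible above.
    bool CustomContractionKernelsEnabled() {
      static const bool enabled = [] {
        const char* v = std::getenv("TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL");
        if (v == nullptr) return true;
        return std::strcmp(v, "false") != 0 && std::strcmp(v, "0") != 0;
      }();
      return enabled;
    }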
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc index 33d59518987..b8949e41426 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc @@ -93,7 +93,7 @@ class Resize : public NodeShader { st.xy = max(icoord_floor, ivec2(0, 0)); st.zw = min(icoord_floor + ivec2(1, 1), borders); - vec2 t = coord - coord_floor; // interpolating factors + vec2 t = coord - coord_floor; //interpolating factors vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$; vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc index 24d7bcf13bc..2ed75ad65b1 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc @@ -54,7 +54,7 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) { int4 st; st.xy = max(itex_coord_floor, int2(0, 0)); st.zw = min(itex_coord_floor + int2(1, 1), borders); - const float2 t = tex_coord - tex_coord_floor; // interpolating factors + const float2 t = tex_coord - tex_coord_floor; //interpolating factors const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x; const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z; const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x; From f4a7fea8fe276a8b762e6c8a4c361030883f3ac7 Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Fri, 20 Mar 2020 10:57:04 -0700 Subject: [PATCH 322/492] TFLM: Use scratch buffer in FC. This will reduce the inference latency on FullyConnected layers as we no longer need to re-calculate OpData for each inference. In the same time, the arena requirement will only slightly increase. This CL also contains some changes on the testing infra so that the single op test can also take account the kernel memory allocations. This is based a dumb implementation so easy to understand and debug. 
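For readers who do not work in TFLite Micro, the trade-off this message describes is between re-deriving per-layer parameters (OpData) on every Eval call and computing them once and caching them in the arena. A rough sketch of the caching variant, built around the AllocatePersistentBuffer hook visible in the diff below; the OpData fields are placeholders, not the real FullyConnected parameters:

    #include <cstdint>

    #include "tensorflow/lite/c/common.h"

    namespace {
    // Placeholder for whatever per-layer parameters the kernel precomputes.
    struct OpData {
      int32_t output_multiplier;
      int output_shift;
    };
    }  // namespace

    void* Init(TfLiteContext* context, const char* buffer, size_t length) {
      // Allocate OpData once from the arena; it stays valid for the lifetime
      // of the interpreter, so Eval never has to recompute it.
      OpData* data = nullptr;
      if (context->AllocatePersistentBuffer(
              context, sizeof(OpData),
              reinterpret_cast<void**>(&data)) != kTfLiteOk ||
          data == nullptr) {
        return nullptr;
      }
      return data;
    }

    TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      // Reuse the cached parameters instead of re-deriving them per inference.
      OpData* data = static_cast<OpData*>(node->user_data);
      (void)data;  // The real kernel would use data->output_multiplier etc.
      return kTfLiteOk;
    }

The non-caching alternative, which the diff below moves toward, simply declares a local OpData inside Eval and fills it on every call, trading a little latency for a smaller persistent arena footprint.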
PiperOrigin-RevId: 302059607 Change-Id: I0cde36ebb191d6cb17bedece1931521785746e07 --- .../lite/micro/kernels/fully_connected.cc | 25 +---- .../micro/kernels/fully_connected_test.cc | 1 + tensorflow/lite/micro/test_helpers.h | 3 + tensorflow/lite/micro/testing/BUILD | 1 - tensorflow/lite/micro/testing/test_utils.cc | 103 ++---------------- tensorflow/lite/micro/testing/test_utils.h | 63 +---------- 6 files changed, 21 insertions(+), 175 deletions(-) diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index 91df80b328c..64bf788f538 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -71,35 +71,18 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, } // namespace void* Init(TfLiteContext* context, const char* buffer, size_t length) { - OpData* data = nullptr; - TfLiteStatus status = context->AllocatePersistentBuffer( - context, sizeof(OpData), reinterpret_cast(&data)); - if (status != kTfLiteOk || data == nullptr) { - return nullptr; - } - return data; + return nullptr; } void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - OpData* data = reinterpret_cast(node->user_data); - auto* params = - reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, input->type, output->type); TF_LITE_ENSURE_MSG(context, input->type == filter->type, "Hybrid models are not supported on TFLite Micro."); - - TfLiteType data_type = input->type; - TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, - filter, bias, output, data)); - return kTfLiteOk; } @@ -195,7 +178,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - OpData* data = reinterpret_cast(node->user_data); + TfLiteType data_type = input->type; + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); // Checks in Prepare ensure input, output and filter types are all the same. switch (input->type) { diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index 4687ae89108..0859e4af591 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -49,6 +49,7 @@ void TestFullyConnectedFloat( TfLiteContext context; PopulateContext(tensors, tensors_size, micro_test::reporter, &context); + ::tflite::ops::micro::AllOpsResolver resolver; const TfLiteRegistration* registration = resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1); diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index f4e7fa8dfba..010e1f9e336 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -58,6 +58,9 @@ CreateFlatbufferBuffers(); // Performs a simple string comparison without requiring standard C library. 
int TestStrcmp(const char* a, const char* b); +// Wrapper to forward kernel errors to the interpreter's error reporter. +void ReportOpError(struct TfLiteContext* context, const char* format, ...); + void PopulateContext(TfLiteTensor* tensors, int tensors_size, TfLiteContext* context); diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD index 42f25f0e8b0..01bdffc6892 100644 --- a/tensorflow/lite/micro/testing/BUILD +++ b/tensorflow/lite/micro/testing/BUILD @@ -17,7 +17,6 @@ cc_library( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", ], diff --git a/tensorflow/lite/micro/testing/test_utils.cc b/tensorflow/lite/micro/testing/test_utils.cc index 5fd0161d621..9f7803fcf62 100644 --- a/tensorflow/lite/micro/testing/test_utils.cc +++ b/tensorflow/lite/micro/testing/test_utils.cc @@ -15,107 +15,24 @@ limitations under the License. #include "tensorflow/lite/micro/testing/test_utils.h" -#include "tensorflow/lite/kernels/internal/compatibility.h" - namespace tflite { namespace testing { -TfLiteStatus FakeAllocator::AllocatePersistentBuffer(size_t bytes, void** ptr) { - uint8_t* addr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); - *ptr = addr; - return kTfLiteOk; -} - -TfLiteStatus FakeAllocator::RequestScratchBufferInArena(int node_idx, - size_t bytes, - int* buffer_idx) { - if (scratch_buffers_count_ >= max_scratch_buffers_count_) { - return kTfLiteError; - } - uint8_t* ptr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); - scratch_buffers_[scratch_buffers_count_] = ptr; - *buffer_idx = scratch_buffers_count_; - scratch_buffers_count_++; - return kTfLiteOk; -} - -void FakeAllocator::Reset() { - // Get A fresh memory allocator. - memory_allocator_ = CreateInPlaceSimpleMemoryAllocator(arena_, arena_size_); - TFLITE_DCHECK_NE(memory_allocator_, nullptr); - - // Allocate enough space holding pointers to the scrtach buffers. - scratch_buffers_ = - reinterpret_cast(memory_allocator_->AllocateFromTail( - sizeof(uint8_t*) * max_scratch_buffers_count_, alignof(uint8_t*))); - TFLITE_DCHECK_NE(scratch_buffers_, nullptr); - - scratch_buffers_count_ = 0; -} - -void* FakeAllocator::GetScratchBuffer(int buffer_idx) { - if (buffer_idx < 0 || buffer_idx >= scratch_buffers_count_) { - return nullptr; - } - return scratch_buffers_[buffer_idx]; -} - -TfLiteStatus FakeContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx, - size_t bytes, - void** ptr) { - return reinterpret_cast(ctx->impl_) - ->allocator_->AllocatePersistentBuffer(bytes, ptr); -} - -TfLiteStatus FakeContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx, - size_t bytes, - int* buffer_idx) { - FakeContextHelper* helper = reinterpret_cast(ctx->impl_); - // FakeAllocator doesn't do memory reusing so it doesn't need node_idx to - // calculate the lifetime of the scratch buffer. - int node_idx = -1; - return helper->allocator_->RequestScratchBufferInArena(node_idx, bytes, - buffer_idx); -} - -void* FakeContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) { - return reinterpret_cast(ctx->impl_) - ->allocator_->GetScratchBuffer(buffer_idx); -} - -void FakeContextHelper::ReportOpError(struct TfLiteContext* context, - const char* format, ...) 
{ - FakeContextHelper* helper = static_cast(context->impl_); - va_list args; - va_start(args, format); - TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args); - va_end(args); -} - -namespace { -constexpr size_t kArenaSize = 10000; -constexpr int kMaxScratchBufferCount = 32; -uint8_t arena[kArenaSize]; -} // namespace - // TODO(b/141330728): Move this method elsewhere as part clean up. void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context) { - // This should be a large enough arena for each test cases. - static FakeAllocator allocator(arena, kArenaSize, kMaxScratchBufferCount); - static FakeContextHelper helper(error_reporter, &allocator); - // Reset the allocator so that it's ready for another test. - allocator.Reset(); - - *context = {}; - context->recommended_num_threads = 1; context->tensors_size = tensors_size; context->tensors = tensors; - context->impl_ = static_cast(&helper); - context->AllocatePersistentBuffer = helper.AllocatePersistentBuffer; - context->RequestScratchBufferInArena = helper.RequestScratchBufferInArena; - context->GetScratchBuffer = helper.GetScratchBuffer; - context->ReportError = helper.ReportOpError; + context->impl_ = static_cast(error_reporter); + context->GetExecutionPlan = nullptr; + context->ResizeTensor = nullptr; + context->ReportError = ReportOpError; + context->AddTensors = nullptr; + context->GetNodeAndRegistration = nullptr; + context->ReplaceNodeSubsetsWithDelegateKernels = nullptr; + context->recommended_num_threads = 1; + context->GetExternalContext = nullptr; + context->SetExternalContext = nullptr; for (int i = 0; i < tensors_size; ++i) { if (context->tensors[i].is_variable) { diff --git a/tensorflow/lite/micro/testing/test_utils.h b/tensorflow/lite/micro/testing/test_utils.h index f7f5dff6bb1..7aa1e9d488f 100644 --- a/tensorflow/lite/micro/testing/test_utils.h +++ b/tensorflow/lite/micro/testing/test_utils.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/tensor_utils.h" #include "tensorflow/lite/micro/micro_utils.h" -#include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -96,67 +95,7 @@ inline int32_t F2Q32(const float value, const float scale) { return static_cast(quantized); } -// A fake version of MemoryAllocator that allocates everything from the tail -// without static memory planning or reusing. -// TODO(b/150260678): Consider splitting this into its own file and inherit from -// the same public interface as MicroAllocator. -class FakeAllocator { - public: - FakeAllocator(uint8_t* arena, size_t arena_size, - size_t max_scratch_buffers_count) - : arena_(arena), - arena_size_(arena_size), - max_scratch_buffers_count_(max_scratch_buffers_count) { - Reset(); - } - - TfLiteStatus AllocatePersistentBuffer(size_t bytes, void** ptr); - TfLiteStatus RequestScratchBufferInArena(int node_idx, size_t bytes, - int* buffer_idx); - void* GetScratchBuffer(int buffer_idx); - - // Reset the allocator to the intial state. - void Reset(); - - private: - uint8_t* arena_; - size_t arena_size_; - size_t max_scratch_buffers_count_; - - SimpleMemoryAllocator* memory_allocator_; - // An array of buffer pointers. - uint8_t** scratch_buffers_; - size_t scratch_buffers_count_ = 0; - static constexpr size_t kBufferAlignment = 16; -}; - -// A fake implementation of ContextHelper. 
Instead of forwarding requests to -// MicroAllocator, it calls into FakeAllocator. -// PopulateContext will point context->impl_ to an instance of this class. -// TODO(b/150260678): Consider moving this into the same file as FakeAllocator. -class FakeContextHelper { - public: - explicit FakeContextHelper(ErrorReporter* error_reporter, - FakeAllocator* allocator) - : allocator_(allocator), error_reporter_(error_reporter) {} - - static TfLiteStatus AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes, - void** ptr); - - static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx, - size_t bytes, - int* buffer_idx); - - static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx); - - static void ReportOpError(struct TfLiteContext* context, const char* format, - ...); - - private: - FakeAllocator* allocator_; - ErrorReporter* error_reporter_; -}; - +// TODO(b/141330728): Move this method elsewhere as part clean up. void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context); From f695f0e8a72ec52ce99e12002dc50639cc9ca4ba Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Fri, 20 Mar 2020 11:07:54 -0700 Subject: [PATCH 323/492] Temporarily disable some multi gpu testing PiperOrigin-RevId: 302062080 Change-Id: Ib0fa872a93c0bdd363b33656195c7f83c5c07ff2 --- tensorflow/python/distribute/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 0667913bc76..8b5308c4d52 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -938,7 +938,7 @@ distribute_py_test( srcs = ["values_test.py"], main = "values_test.py", tags = [ - "multi_and_single_gpu", + # "multi_and_single_gpu", # b/151865826 ], tpu_tags = [ "no_oss", # Target too big to run serially reliably. @@ -1206,7 +1206,7 @@ cuda_py_test( srcs = ["mirrored_strategy_test.py"], shard_count = 5, tags = [ - "multi_and_single_gpu", + # "multi_and_single_gpu", # b/151862653 "no_windows_gpu", # TODO(b/130551176) ], deps = [ From 7d7a5c9b4feaa2a66f08e1f217d63aba1d23c79c Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Fri, 20 Mar 2020 11:19:03 -0700 Subject: [PATCH 324/492] Avoid creating tensors as global variables in dataset_test Globals and static initializers can be ordered differently, causing failures in tensor creation. PiperOrigin-RevId: 302064361 Change-Id: Icce2e3deadc5518b4a69f624bcecbbd0ccd2507d --- tensorflow/core/framework/dataset_test.cc | 57 +++++++++++++---------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/tensorflow/core/framework/dataset_test.cc b/tensorflow/core/framework/dataset_test.cc index b1e12379538..9dbb3be7faf 100644 --- a/tensorflow/core/framework/dataset_test.cc +++ b/tensorflow/core/framework/dataset_test.cc @@ -37,7 +37,10 @@ enum DataTypeTest { struct DatasetTestParam { const DataTypeTest type; - const std::vector tensor; + // This has to be a function pointer, to make sure the tensors we use as + // parameters of the test case do not become globals. Ordering of static + // initializers and globals can cause errors in the test. 
+ std::function()> tensor_factory; const int64 expected_bytes; }; @@ -48,38 +51,44 @@ TEST_P(DatasetTestTotalBytes, TestTotalBytes) { const DatasetTestParam& test_case = GetParam(); if (test_case.type == _tf_string_) { // TotalBytes() is approximate and gives an upper bound for strings - EXPECT_LE(data::GetTotalBytes(test_case.tensor), test_case.expected_bytes); + EXPECT_LE(data::GetTotalBytes(test_case.tensor_factory()), + test_case.expected_bytes); } else { - EXPECT_EQ(data::GetTotalBytes(test_case.tensor), test_case.expected_bytes); + EXPECT_EQ(data::GetTotalBytes(test_case.tensor_factory()), + test_case.expected_bytes); } } -std::vector tensor_tf_int_32s{test::AsTensor({1, 2, 3, 4, 5}), - test::AsTensor({1, 2, 3, 4})}; +std::vector tensor_tf_int_32s() { + return {test::AsTensor({1, 2, 3, 4, 5}), + test::AsTensor({1, 2, 3, 4})}; +} -std::vector tensor_tf_int_64s{test::AsTensor({1, 2, 3, 4, 5}), - test::AsTensor({10, 12})}; +std::vector tensor_tf_int_64s() { + return {test::AsTensor({1, 2, 3, 4, 5}), + test::AsTensor({10, 12})}; +} -std::vector tensor_tf_float_s{ - test::AsTensor({1.0, 2.0, 3.0, 4.0})}; +std::vector tensor_tf_float_s() { + return {test::AsTensor({1.0, 2.0, 3.0, 4.0})}; +} -std::vector tensor_tf_double_s{ - test::AsTensor({100.0}), test::AsTensor({200.0}), - test::AsTensor({400.0}), test::AsTensor({800.0})}; +std::vector tensor_tf_double_s() { + return {test::AsTensor({100.0}), test::AsTensor({200.0}), + test::AsTensor({400.0}), test::AsTensor({800.0})}; +} const tstring str = "test string"; // NOLINT -std::vector tensor_strs{test::AsTensor({str})}; +std::vector tensor_strs() { return {test::AsTensor({str})}; } -const DatasetTestParam test_cases[] = { - {_tf_int_32, tensor_tf_int_32s, 4 /*bytes*/ * 9 /*elements*/}, - {_tf_int_64, tensor_tf_int_64s, 8 /*bytes*/ * 7 /*elements*/}, - {_tf_float_, tensor_tf_float_s, 4 /*bytes*/ * 4 /*elements*/}, - {_tf_double_, tensor_tf_double_s, 8 /*bytes*/ * 4 /*elements*/}, - {_tf_string_, tensor_strs, - static_cast(sizeof(str) + str.size()) /*bytes*/}, -}; - -INSTANTIATE_TEST_SUITE_P(DatasetTestTotalBytes, DatasetTestTotalBytes, - ::testing::ValuesIn(test_cases)); +INSTANTIATE_TEST_SUITE_P( + DatasetTestTotalBytes, DatasetTestTotalBytes, + ::testing::ValuesIn(std::vector{ + {_tf_int_32, tensor_tf_int_32s, 4 /*bytes*/ * 9 /*elements*/}, + {_tf_int_64, tensor_tf_int_64s, 8 /*bytes*/ * 7 /*elements*/}, + {_tf_float_, tensor_tf_float_s, 4 /*bytes*/ * 4 /*elements*/}, + {_tf_double_, tensor_tf_double_s, 8 /*bytes*/ * 4 /*elements*/}, + {_tf_string_, tensor_strs, + static_cast(sizeof(str) + str.size()) /*bytes*/}})); } // namespace tensorflow From 0ffd38260a885e5d37ae1ea647ad8d7e58e12efb Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Fri, 20 Mar 2020 11:19:45 -0700 Subject: [PATCH 325/492] Extend error message when super new GraphDef import fails PiperOrigin-RevId: 302064508 Change-Id: Ibc3174157ebf91811628c6d49fb225d991d1f6c9 --- tensorflow/core/BUILD | 1 + tensorflow/core/graph/graph_constructor.cc | 43 +++++++++++++++++-- .../core/graph/graph_constructor_test.cc | 32 +++++++++++++- 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7b995af7656..c1b889751d7 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2520,6 +2520,7 @@ tf_cuda_library( "//third_party/eigen3", ] + if_static([ "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ]), ) diff --git a/tensorflow/core/graph/graph_constructor.cc 
b/tensorflow/core/graph/graph_constructor.cc index 39bb0514c34..6bb1772e02d 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -457,6 +458,33 @@ class NodeDefMovingGraphConstructor : public GraphConstructor { std::vector is_consumed_; }; +bool ForwardCompatibilityWindowPassed(const VersionDef& versions) { + // TF_GRAPH_DEF_VERSION is incremented daily. + // TF has a 3 week forward compatibility guarantee. + return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21; +} + +Status MaybeAppendVersionWarning(const VersionDef* versions, + const Status& import_status) { + if (versions && ForwardCompatibilityWindowPassed(*versions)) { + return Status( + import_status.code(), + absl::StrCat( + "Converting GraphDef to Graph has failed. The binary trying to " + "import the GraphDef was built when GraphDef version was ", + TF_GRAPH_DEF_VERSION, + ". The GraphDef was produced by a binary built when GraphDef " + "version was ", + versions->producer(), + ". The difference between these versions is larger than " + "TensorFlow's forward compatibility guarantee. The following error " + "might be due to the binary trying to import the GraphDef being " + "too old: ", + import_status.error_message())); + } + return import_status; +} + /* static */ Status GraphConstructor::Construct( const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, @@ -471,8 +499,11 @@ class NodeDefMovingGraphConstructor : public GraphConstructor { NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g, refiner, return_tensors, return_nodes, missing_unused_input_map_keys); - const Status s = c.TryImport(); - if (!s.ok()) c.Undo(); + Status s = c.TryImport(); + if (!s.ok()) { + c.Undo(); + s = MaybeAppendVersionWarning(versions, s); + } return s; } @@ -484,11 +515,15 @@ class NodeDefMovingGraphConstructor : public GraphConstructor { TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION, TF_GRAPH_DEF_VERSION_MIN_PRODUCER, "GraphDef", "graph")); + VersionDef version_def = graph_def.versions(); NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner, return_tensors, return_nodes, missing_unused_input_map_keys); - const Status s = c.TryImport(); - if (!s.ok()) c.Undo(); + Status s = c.TryImport(); + if (!s.ok()) { + c.Undo(); + s = MaybeAppendVersionWarning(&version_def, s); + } return s; } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index cf740156a40..8475032665e 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_constructor.h" #include + #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb.h" @@ -51,7 +52,8 @@ class GraphConstructorTest : public ::testing::Test { } void ExpectError(const string& gdef_ascii, - const std::vector& expected_error_strs) { + const std::vector& expected_error_strs, + string not_expected_error_str = "") { // Used to verify that errors don't change graph const string original_graph_description = GraphDebugString(); @@ -65,6 +67,13 @@ class GraphConstructorTest : public ::testing::Test { << "Expected to find '" << error << "' in " << status; } + if (!not_expected_error_str.empty()) { + EXPECT_TRUE(status.error_message().find(not_expected_error_str) == + string::npos) + << "Expected not to find '" << not_expected_error_str << "' in " + << status; + } + EXPECT_EQ(original_graph_description, GraphDebugString()); } @@ -825,6 +834,27 @@ TEST_F(GraphConstructorTest, VersionGraph) { ExpectVersions(TF_GRAPH_DEF_VERSION_MIN_CONSUMER, TF_GRAPH_DEF_VERSION); } +TEST_F(GraphConstructorTest, ForwardCompatError) { + ExpectError( + strings::StrCat( + "node { name: 'a:b' op: 'ABC' }\n" // 'a:b' is an invalid name. + "versions { producer: ", + TF_GRAPH_DEF_VERSION + 22, + " min_consumer: ", TF_GRAPH_DEF_VERSION_MIN_CONSUMER, "}"), + {"forward compatibility guarantee"}); +} + +TEST_F(GraphConstructorTest, NoForwardCompatError) { + ExpectError( + strings::StrCat( + "node { name: 'a:b' op: 'ABC' }\n" // 'a:b' is an invalid name. + "versions { producer: ", + TF_GRAPH_DEF_VERSION + 21, + " min_consumer: ", TF_GRAPH_DEF_VERSION_MIN_CONSUMER, "}"), + {"Node name contains invalid characters"}, + "forward compatibility guarantee"); +} + TEST_F(GraphConstructorTest, LowVersion) { ExpectError(strings::StrCat("versions { producer: ", -1, " }"), {strings::StrCat("GraphDef producer version -1 below min " From 488f27046cd07d2825d61bcb9f0df452df867111 Mon Sep 17 00:00:00 2001 From: Jing Pu Date: Fri, 20 Mar 2020 11:22:03 -0700 Subject: [PATCH 326/492] Add CreateLegalizeHloToTfPass to a public header. PiperOrigin-RevId: 302064923 Change-Id: I43d5311816e30dcbb0d1c2fcd7b94d92369d231b --- tensorflow/compiler/mlir/tensorflow/transforms/passes.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 332e181c9ed..92d15e13621 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -117,6 +117,9 @@ std::unique_ptr> CreateStackOpsDecompositionPass(); // Converts tensor list operations into operations on buffers and sizes. Needs // static shapes and known max element count. std::unique_ptr> CreateTensorListOpsDecompositionPass(); + +// Create a pass that legalize HLO to TF dialect. +std::unique_ptr> CreateLegalizeHloToTfPass(); } // namespace TF namespace TFControlFlow { From 2fd08c48a3860e89e31c8d72e8dcd48238b75f22 Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Fri, 20 Mar 2020 11:27:19 -0700 Subject: [PATCH 327/492] Enable bluepill as part of the presubmit checks. 
PiperOrigin-RevId: 302065920 Change-Id: I4ce1aa3752986a3b4ef069923db9aeaf95fe1826 --- tensorflow/lite/micro/tools/ci_build/test_all.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/tools/ci_build/test_all.sh b/tensorflow/lite/micro/tools/ci_build/test_all.sh index e46c041a84e..5172f950eac 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_all.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_all.sh @@ -37,9 +37,8 @@ echo "Starting to run micro tests at `date`" echo "Running Arduino tests at `date`" tensorflow/lite/micro/tools/ci_build/test_arduino.sh -# TODO(b/151695791): reenable once the root cause is fixed. -#echo "Running bluepill tests at `date`" -#tensorflow/lite/micro/tools/ci_build/test_bluepill.sh +echo "Running bluepill tests at `date`" +tensorflow/lite/micro/tools/ci_build/test_bluepill.sh echo "Running mbed tests at `date`" tensorflow/lite/micro/tools/ci_build/test_mbed.sh PRESUBMIT From 16051cb33c4b3b0d93d33390ceb85adb4a74ee1f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 11:32:07 -0700 Subject: [PATCH 328/492] Make elementwise operations with two inputs support all of the cases: elementwise, scalar, broadcast and const vector. PiperOrigin-RevId: 302066884 Change-Id: I94a7497f006b466cc6d7a1b1fdba090b4ef30a00 --- .../delegates/gpu/common/model_builder.cc | 206 +++++---- .../lite/delegates/gpu/gl/kernels/BUILD | 1 + .../delegates/gpu/gl/kernels/elementwise.cc | 183 +++----- .../gpu/gl/kernels/elementwise_test.cc | 427 ++++++++++++++---- 4 files changed, 523 insertions(+), 294 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 2a03ff9ff14..b37c3542413 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -236,6 +236,12 @@ int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context, return number_of_runtime_inputs; } +int GetNumberOfConstInputsForNode(const TfLiteContext* context, + const TfLiteNode* tflite_node) { + return tflite_node->inputs->size - + GetNumberOfRuntimeInputsForNode(context, tflite_node); +} + int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context, const TfLiteNode* tflite_node) { int number_of_runtime_outputs = 0; @@ -258,6 +264,42 @@ Status CheckTensorIsAvailable(const TfLiteContext* context, return OkStatus(); } +Status CheckInputsOutputs(const TfLiteContext* context, + const TfLiteNode* tflite_node, int runtime_inputs, + int outputs) { + int runtime_inputs_from_model = + GetNumberOfRuntimeInputsForNode(context, tflite_node); + if (runtime_inputs_from_model != runtime_inputs) { + return InternalError(absl::StrFormat( + "Expected %d runtime input tensor(s), but node has %d runtime " + "input(s).", + runtime_inputs, runtime_inputs_from_model)); + } + int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node); + if (runtime_outputs != outputs) { + return InternalError( + absl::StrFormat("Expected %d output tensor(s), but node has %d " + "output(s).", + outputs, runtime_outputs)); + } + return OkStatus(); +} + +Status CheckInputsConstsOutputs(const TfLiteContext* context, + const TfLiteNode* tflite_node, + int runtime_inputs, int const_inputs, + int outputs) { + int const_inputs_from_model = + GetNumberOfConstInputsForNode(context, tflite_node); + if (const_inputs_from_model != const_inputs) { + return InternalError(absl::StrFormat( + "Expected %d const input tensor(s), but node has 
%d const " + "input(s).", + const_inputs, const_inputs_from_model)); + } + return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs); +} + class ObjectReader { public: ObjectReader(GraphFloat32* graph, TfLiteContext* context, @@ -367,6 +409,13 @@ class ObjectReader { : nullptr; } + Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node, + int runtime_inputs, int const_inputs, + int outputs) { + return CheckInputsConstsOutputs(context_, tflite_node, runtime_inputs, + const_inputs, outputs); + } + private: GraphFloat32* graph_ = nullptr; const TfLiteContext* context_ = nullptr; @@ -374,59 +423,6 @@ class ObjectReader { std::vector>*>* tensor_to_value_; }; -Status CheckInputsOutputs(const TfLiteContext* context, - const TfLiteNode* tflite_node, int inputs, - int outputs) { - int runtime_inputs = GetNumberOfRuntimeInputsForNode(context, tflite_node); - if (runtime_inputs != inputs) { - return InternalError( - absl::StrFormat("Expected %d input tensor(s), but node has %d runtime " - "input(s).", - inputs, runtime_inputs)); - } - int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node); - if (runtime_outputs != outputs) { - return InternalError( - absl::StrFormat("Expected %d output tensor(s), but node has %d runtime " - "output(s).", - outputs, runtime_outputs)); - } - return OkStatus(); -} - -// The function checks input tensors including 1 constant tensor. -Status CheckInputsOutputsAllowingOneConstInput(const TfLiteContext* context, - const TfLiteNode* tflite_node, - int inputs, int outputs) { - int number_of_const_inputs = 0; - int number_of_runtime_inputs = 0; - for (int i = 0; i < tflite_node->inputs->size; i++) { - if (IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) { - number_of_const_inputs++; - } else { - number_of_runtime_inputs++; - } - } - if (tflite_node->inputs->size != inputs) { - return InternalError(absl::StrFormat( - "Expected %d input tensor(s), but node has %d input(s).", inputs, - tflite_node->inputs->size)); - } - if (number_of_const_inputs > 1) { - return InternalError(absl::StrFormat( - "Expected 1 const input tensor, but node has %d const input(s).", - number_of_const_inputs)); - } - int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node); - if (runtime_outputs != outputs) { - return InternalError( - absl::StrFormat("Expected %d output tensor(s), but node has %d runtime " - "output(s).", - outputs, runtime_outputs)); - } - return OkStatus(); -} - // A parser responsible for parsing TFLite operation and adding it to a graph. 
class TFLiteOperationParser { public: @@ -893,8 +889,8 @@ class Conv2DOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); TfLiteConvParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); @@ -977,8 +973,8 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); TfLiteDepthwiseConvParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); @@ -1095,16 +1091,20 @@ class ElementwiseOperationParser : public TFLiteOperationParser { const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); if (IsOneArgumentOperation()) { - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1, - /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, + /*runtime_inputs=*/1, + /*const_inputs=*/0, + /*outputs=*/1)); } else if (IsTwoArgumentOperation()) { - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/2, - /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, + /*runtime_inputs=*/2, + /*const_inputs=*/0, + /*outputs=*/1)); } else if (IsTwoArgumentOperationWithConst()) { - RETURN_IF_ERROR(CheckInputsOutputsAllowingOneConstInput(context, - tflite_node, - /*inputs=*/2, - /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, + /*runtime_inputs=*/1, + /*const_inputs=*/1, + /*outputs=*/1)); } else { return InvalidArgumentError("Op can only handle 1 or 2 operand(s)."); } @@ -1120,8 +1120,17 @@ class ElementwiseOperationParser : public TFLiteOperationParser { node->operation.type = ToString(operation_type_); if (IsOneArgumentOperation()) { + RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node, + /*runtime_inputs=*/1, + /*const_inputs=*/0, + /*outputs=*/1)); + RETURN_IF_ERROR(reader->AddInput(node, 0)); } else if (IsTwoArgumentOperation()) { + RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node, + /*runtime_inputs=*/2, + /*const_inputs=*/0, + /*outputs=*/1)); if (tflite_node->inputs->size != 2) { return InvalidArgumentError("Applies only two input tensors"); } @@ -1156,14 +1165,12 @@ class ElementwiseOperationParser : public TFLiteOperationParser { MaybeFuseActivationToTheSingleOutput(activation, graph, node)); } } else if (IsTwoArgumentOperationWithConst()) { + RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node, + /*runtime_inputs=*/1, + /*const_inputs=*/1, + /*outputs=*/1)); ElementwiseAttributes attr; RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); - auto const_vector = - absl::get_if<::tflite::gpu::Tensor>( - &attr.param); - if (const_vector) { - return InvalidArgumentError("Constant vector is not supported"); - } 
node->operation.attributes = std::move(attr); } else { return InvalidArgumentError("Incorrect operation type passed"); @@ -1228,6 +1235,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser { switch (operation_type_) { case OperationType::MINIMUM: case OperationType::MAXIMUM: + case OperationType::SUB: return true; default: return false; @@ -1311,7 +1319,7 @@ class HardSwishOperationParser : public TFLiteOperationParser { Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration*) final { - return CheckInputsOutputs(context, tflite_node, /*inputs=*/1, + return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } @@ -1350,7 +1358,8 @@ class LSTMOperationParser : public TFLiteOperationParser { const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2)); // TODO(eignasheva): Fix bad check. - // RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/5, + // RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + // /*runtime_inputs=*/5, // /*outputs=*/4)); TfLiteLSTMParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); @@ -1599,8 +1608,8 @@ class PadOperationParser : public TFLiteOperationParser { } } RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); return OkStatus(); } @@ -1648,11 +1657,13 @@ class Pooling2DOperationParser : public TFLiteOperationParser { TfLitePoolParams* tf_options = nullptr; auto status = RetrieveCustomInitialData(tflite_node, &tf_options); if (status.ok()) { // custom case with indices as a second output - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1, + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/2)); } else { // common pooling with 1 output RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1, + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); } RETURN_IF_ERROR(CheckKernelsAndStrides( @@ -1752,8 +1763,8 @@ class ReshapeOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); // TODO(eignasheva): add shape checking return OkStatus(); } @@ -1786,8 +1797,8 @@ class Resize2DOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node)); bool align_corners; @@ -1974,8 +1985,8 @@ class SoftmaxOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) 
final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); TfLiteSoftmaxParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->beta != 1) { @@ -2018,8 +2029,8 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); // TODO(impjdi): Dims check. TfLiteSpaceToDepthParams* s2d_params = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params)); @@ -2280,8 +2291,8 @@ class TransposeOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); return OkStatus(); } @@ -2317,8 +2328,8 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { TfLitePoolParams* tf_options = nullptr; - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/2, /*outputs=*/1)); RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckKernelsAndStrides( tf_options->filter_height, tf_options->filter_width, @@ -2445,8 +2456,8 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); return OkStatus(); } @@ -2478,8 +2489,8 @@ class TransformTensorOperationParser : public TFLiteOperationParser { Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/2, /*outputs=*/1)); return OkStatus(); } @@ -2515,8 +2526,8 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser { Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - RETURN_IF_ERROR( - CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/2, /*outputs=*/1)); return OkStatus(); } @@ -2549,7 +2560,7 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - return CheckInputsOutputs(context, 
tflite_node, /*inputs=*/1, + return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } @@ -2581,7 +2592,7 @@ class MeanOperationParser : public TFLiteOperationParser { Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - return CheckInputsOutputs(context, tflite_node, /*inputs=*/1, + return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } @@ -2970,7 +2981,6 @@ bool IsAllFloatTensors(const TfLiteContext* context, } return true; } - } // namespace Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD index 30d759df724..d2ef617a8e2 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD @@ -198,6 +198,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/gl:node_shader", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc index 941a32a8769..35b233cbdcc 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -130,89 +131,9 @@ class ElementwiseTwoArguments : public NodeShader { return true; } - Status ImplementElementwise(const GenerationContext& ctx, - GeneratedCode* generated_code) const { - std::string source; - switch (operation_type_) { - case OperationType::SUB: { - source = "value_0 -= value_1;"; - break; - } - case OperationType::DIV: { - source = "value_0 /= value_1;"; - break; - } - case OperationType::MAXIMUM: { - source = "value_0 = max(value_0, value_1);"; - break; - } - case OperationType::MINIMUM: { - source = "value_0 = min(value_0, value_1);"; - break; - } - case OperationType::POW: { - // From documentation : - // The result is undefined if x<0 or if x=0 and y≤0. 
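The rewritten GenerateCode further below builds a single shader line from $0/$1 placeholders with absl::Substitute; a minimal standalone sketch of that substitution (the argument strings are illustrative, not taken verbatim from the patch):

// Shows how the placeholder template expands; '$' inside the substituted
// arguments is kept literally, only $0/$1 in the format string are special.
#include <iostream>
#include <string>
#include "absl/strings/substitute.h"

int main() {
  std::string source = "value_0 = max($0, $1);";
  source = absl::Substitute(source, "value_0", "$const_data[gid.z]$");
  std::cout << source << "\n";  // value_0 = max(value_0, $const_data[gid.z]$);
  return 0;
}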
- source = "value_0 = pow(value_0, value_1);"; - break; - } - case OperationType::SQUARED_DIFF: { - source = "value_0 = (value_0 - value_1) * (value_0 - value_1);"; - break; - } - - default: - return InvalidArgumentError( - "Incorrect elementwise with two arguments operation type."); - } - *generated_code = { - /*parameters=*/{}, - /*objects=*/{}, - /*shared_variables=*/{}, - /*workload=*/uint3(), - /*workgroup=*/uint3(), - /*source_code=*/source, - /*input=*/IOStructure::AUTO, - /*output=*/IOStructure::AUTO, - }; - return OkStatus(); - } - - Status ImplementElementwiseWithScalar(const GenerationContext& ctx, - const float scalar, - GeneratedCode* generated_code) const { - std::string source; - switch (operation_type_) { - case OperationType::MAXIMUM: { - source = "value_0 = max(value_0, $scalar$);"; - break; - } - case OperationType::MINIMUM: { - source = "value_0 = min(value_0, $scalar$);"; - break; - } - - default: - return InvalidArgumentError( - "Incorrect elementwise with scalar operation type."); - } - *generated_code = { - /*parameters=*/{{"scalar", scalar}}, - /*objects=*/{}, - /*shared_variables=*/{}, - /*workload=*/uint3(), - /*workgroup=*/uint3(), - /*source_code=*/source, - /*input=*/IOStructure::AUTO, - /*output=*/IOStructure::AUTO, - }; - return OkStatus(); - } - bool IsSupportedBroadcast(const GenerationContext& ctx) const { auto inputs = ctx.graph->FindInputs(ctx.node->id); auto outputs = ctx.graph->FindOutputs(ctx.node->id); - if (inputs.size() != 2) { return false; } @@ -223,57 +144,87 @@ class ElementwiseTwoArguments : public NodeShader { return true; } - Status ImplementElementwiseBroadcast(const GenerationContext& ctx, - GeneratedCode* generated_code) const { - std::string source; - switch (operation_type_) { - case OperationType::SQUARED_DIFF: { - source = R"( - vec4 diff = $input_data_0[gid.x, gid.y, gid.z]$ - - $input_data_1[0, 0, gid.z]$; - value_0 = diff * diff; - )"; - break; + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + std::vector parameters; + std::vector> objects; + std::string argument0, argument1; + if (IsSupportedElemwise(ctx)) { + argument0 = "value_0"; + argument1 = "value_1"; + } else if (IsSupportedBroadcast(ctx)) { + argument0 = "$input_data_0[gid.x, gid.y, gid.z]$"; + argument1 = "$input_data_1[0, 0, gid.z]$"; + } else { // Scalar of const vector case + const ElementwiseAttributes* attr = absl::any_cast( + &ctx.node->operation.attributes); + if (!attr) { + return InvalidArgumentError( + "Couldn't read attributes for the scalar of const vector case."); + } + auto* tensor = + absl::get_if<::tflite::gpu::Tensor>( + &attr->param); + auto* scalar = absl::get_if(&attr->param); + if (!tensor && !scalar) { + return InvalidArgumentError( + "Couldn't read scalar of const vector data from the attributes."); } + argument0 = "value_0"; + if (tensor) { + argument1 = "$const_data[gid.z]$"; + objects.push_back({"const_data", MakeReadonlyObject(tensor->data)}); + } else { + argument1 = "vec4($const_data$)"; + parameters.push_back({"const_data", *scalar}); + } + } + + std::string source; + switch (operation_type_) { + case OperationType::DIV: { + source = "value_0 = $0/$1;"; + break; + } + case OperationType::MAXIMUM: { + source = "value_0 = max($0, $1);"; + break; + } + case OperationType::MINIMUM: { + source = "value_0 = min($0, $1);"; + break; + } + case OperationType::SQUARED_DIFF: { + source = "value_0 = ($0 - $1) * ($0 - $1);"; + break; + } + case OperationType::SUB: { + source = "value_0 = $0 - $1;"; + 
break; + } + case OperationType::POW: { + source = "value_0 = pow($0, $1);"; + break; + } default: return InvalidArgumentError( - "Incorrect elementwise with two arguments operation type."); + "Incorrect elementwise with scalar operation type."); } + source = absl::Substitute(source, argument0, argument1); *generated_code = { - /*parameters=*/{}, - /*objects=*/{}, + /*parameters=*/std::move(parameters), + /*objects=*/std::move(objects), /*shared_variables=*/{}, /*workload=*/uint3(), /*workgroup=*/uint3(), /*source_code=*/source, - /*input=*/IOStructure::ONLY_DEFINITIONS, + /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; return OkStatus(); } - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { - if (IsSupportedElemwise(ctx)) { - return ImplementElementwise(ctx, generated_code); - } - if (IsSupportedBroadcast(ctx)) { - return ImplementElementwiseBroadcast(ctx, generated_code); - } - const ElementwiseAttributes* attr = - absl::any_cast(&ctx.node->operation.attributes); - if (attr) { - auto scalar = absl::get_if(&attr->param); - if (scalar) { - return ImplementElementwiseWithScalar(ctx, *scalar, generated_code); - } - } - return InvalidArgumentError( - "This case is not supported by elementwise with two arguments " - "operation"); - } - private: OperationType operation_type_; }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc index 3316395f5e3..625a09eebf4 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc @@ -36,7 +36,7 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { return tensor_ref; } -TEST(ElementwiseTest, Abs) { +TEST(ElementwiseOneArgumentTest, Abs) { OperationType op_type = OperationType::ABS; const BHWC shape(1, 2, 2, 1); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, @@ -48,7 +48,7 @@ TEST(ElementwiseTest, Abs) { Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0})); } -TEST(ElementwiseTest, Cos) { +TEST(ElementwiseOneArgumentTest, Cos) { OperationType op_type = OperationType::COS; const BHWC shape(1, 2, 2, 1); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, @@ -60,21 +60,7 @@ TEST(ElementwiseTest, Cos) { Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302})); } -TEST(ElementwiseTest, Div) { - OperationType op_type = OperationType::DIV; - const BHWC shape(1, 2, 2, 1); - SingleOpModel model( - {/*type=*/ToString(op_type), /*attributes=*/{}}, - /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)}, - /*outputs=*/{GetTensorRef(2, shape)}); - ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); - ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0})); - ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); - EXPECT_THAT(model.GetOutput(0), - Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0})); -} - -TEST(ElementwiseTest, Exp) { +TEST(ElementwiseOneArgumentTest, Exp) { OperationType op_type = OperationType::EXP; const BHWC shape(1, 1, 1, 7); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, @@ -90,7 +76,7 @@ TEST(ElementwiseTest, Exp) { std::exp(-0.01f)})); } -TEST(ElementwiseTest, HardSwish) { +TEST(ElementwiseOneArgumentTest, HardSwish) { OperationType op_type = OperationType::HARD_SWISH; const BHWC shape(1, 1, 1, 7); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, @@ -104,7 +90,7 @@ TEST(ElementwiseTest, HardSwish) { {0.0f, 0.0f, -0.375f, 
0.0f, 1.125f, 3.f, 4.5f})); } -TEST(ElementwiseTest, Log) { +TEST(ElementwiseOneArgumentTest, Log) { OperationType op_type = OperationType::LOG; const BHWC shape(1, 2, 2, 1); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, @@ -116,7 +102,142 @@ TEST(ElementwiseTest, Log) { Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0})); } -TEST(ElementwiseTest, Maximum) { +TEST(ElementwiseOneArgumentTest, Rsqrt) { + OperationType op_type = OperationType::RSQRT; + const BHWC shape(1, 2, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape)}, + /*outputs=*/{GetTensorRef(1, shape)}); + ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333})); +} + +TEST(ElementwiseOneArgumentTest, Sigmoid) { + OperationType op_type = OperationType::SIGMOID; + const BHWC shape(1, 2, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape)}, + /*outputs=*/{GetTensorRef(1, shape)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014})); +} + +TEST(ElementwiseOneArgumentTest, Sin) { + OperationType op_type = OperationType::SIN; + const BHWC shape(1, 2, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape)}, + /*outputs=*/{GetTensorRef(1, shape)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471})); +} + +TEST(ElementwiseOneArgumentTest, Sqrt) { + OperationType op_type = OperationType::SQRT; + const BHWC shape(1, 2, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape)}, + /*outputs=*/{GetTensorRef(1, shape)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0})); +} + +TEST(ElementwiseOneArgumentTest, Square) { + OperationType op_type = OperationType::SQUARE; + const BHWC shape(1, 2, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape)}, + /*outputs=*/{GetTensorRef(1, shape)}); + ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0})); +} + +TEST(ElementwiseOneArgumentTest, Tanh) { + OperationType op_type = OperationType::TANH; + const BHWC shape(1, 2, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape)}, + /*outputs=*/{GetTensorRef(1, shape)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329})); +} + +TEST(ElementwiseTwoArgumentsTest, DivElementwise) { + OperationType op_type = OperationType::DIV; + const BHWC shape(1, 2, 2, 1); + SingleOpModel model( + 
{/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)}, + /*outputs=*/{GetTensorRef(2, shape)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); + ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0})); +} + +TEST(ElementwiseTwoArgumentsTest, DivBroadcast) { + OperationType op_type = OperationType::DIV; + const BHWC shape0(1, 2, 1, 2); + const BHWC shape1(1, 1, 1, 2); + SingleOpModel model( + {/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 5.0, 4.0, 15.0})); +} + +TEST(ElementwiseTwoArgumentsTest, DivScalar) { + OperationType op_type = OperationType::DIV; + const BHWC shape0(1, 2, 1, 2); + ElementwiseAttributes attr; + attr.param = static_cast(0.5); + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 2.0, 4.0, 6.0})); +} + +TEST(ElementwiseTwoArgumentsTest, DivConstVector) { + OperationType op_type = OperationType::DIV; + const BHWC shape0(1, 2, 1, 2); + + ElementwiseAttributes attr; + Tensor param; + param.shape = Linear(2); + param.id = 1; + param.data = {0.4, 0.5}; + attr.param = std::move(param); + + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 2.0, 5.0, 6.0})); +} + +TEST(ElementwiseTwoArgumentsTest, MaximumElementwise) { OperationType op_type = OperationType::MAXIMUM; const BHWC shape(1, 2, 2, 1); SingleOpModel model( @@ -130,7 +251,22 @@ TEST(ElementwiseTest, Maximum) { Pointwise(FloatNear(1e-6), {1.0, 2.0, 3.0, -2.0})); } -TEST(ElementwiseTest, MaximumWithScalar) { +TEST(ElementwiseTwoArgumentsTest, MaximumBroadcast) { + OperationType op_type = OperationType::MAXIMUM; + const BHWC shape0(1, 2, 1, 2); + const BHWC shape1(1, 1, 1, 2); + SingleOpModel model( + {/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.5, 1.0, 2.0, 3.0})); +} + +TEST(ElementwiseTwoArgumentsTest, MaximumScalar) { OperationType op_type = OperationType::MAXIMUM; const BHWC shape(1, 2, 2, 1); ElementwiseAttributes attr; @@ -145,7 +281,27 @@ TEST(ElementwiseTest, MaximumWithScalar) { Pointwise(FloatNear(1e-6), {0.0, -1.0, 2.0, -1.0})); } -TEST(ElementwiseTest, Minimum) { +TEST(ElementwiseTwoArgumentsTest, MaximumConstVector) { + OperationType 
op_type = OperationType::MAXIMUM; + const BHWC shape0(1, 2, 1, 2); + + ElementwiseAttributes attr; + Tensor param; + param.shape = Linear(2); + param.id = 1; + param.data = {0.4, 0.5}; + attr.param = std::move(param); + + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.4, 1.0, 2.0, 3.0})); +} + +TEST(ElementwiseTwoArgumentsTest, MinimumElementwise) { OperationType op_type = OperationType::MINIMUM; const BHWC shape(1, 2, 2, 1); SingleOpModel model( @@ -159,7 +315,22 @@ TEST(ElementwiseTest, Minimum) { Pointwise(FloatNear(1e-6), {0.0, -6.2, 2.0, -3.0})); } -TEST(ElementwiseTest, MinimumWithScalar) { +TEST(ElementwiseTwoArgumentsTest, MinimumBroadcast) { + OperationType op_type = OperationType::MINIMUM; + const BHWC shape0(1, 2, 1, 2); + const BHWC shape1(1, 1, 1, 2); + SingleOpModel model( + {/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 0.2, 0.5, 0.2})); +} + +TEST(ElementwiseTwoArgumentsTest, MinimumScalar) { OperationType op_type = OperationType::MINIMUM; const BHWC shape(1, 2, 2, 1); ElementwiseAttributes attr; @@ -174,7 +345,27 @@ TEST(ElementwiseTest, MinimumWithScalar) { Pointwise(FloatNear(1e-6), {-1.0, -6.2, -1.0, -3.0})); } -TEST(ElementwiseTest, Pow) { +TEST(ElementwiseTwoArgumentsTest, MinimumConstVector) { + OperationType op_type = OperationType::MINIMUM; + const BHWC shape0(1, 2, 1, 2); + + ElementwiseAttributes attr; + Tensor param; + param.shape = Linear(2); + param.id = 1; + param.data = {0.5, 0.2}; + attr.param = std::move(param); + + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 0.2, 0.5, 0.2})); +} + +TEST(ElementwiseTwoArgumentsTest, PowElementwise) { OperationType op_type = OperationType::POW; const BHWC shape(1, 2, 2, 1); SingleOpModel model( @@ -188,67 +379,57 @@ TEST(ElementwiseTest, Pow) { Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0})); } -TEST(ElementwiseTest, Rsqrt) { - OperationType op_type = OperationType::RSQRT; - const BHWC shape(1, 2, 2, 1); - SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, - /*inputs=*/{GetTensorRef(0, shape)}, - /*outputs=*/{GetTensorRef(1, shape)}); - ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0})); +TEST(ElementwiseTwoArgumentsTest, PowBroadcast) { + OperationType op_type = OperationType::POW; + const BHWC shape0(1, 2, 1, 2); + const BHWC shape1(1, 1, 1, 2); + SingleOpModel model( + {/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); + ASSERT_TRUE(model.PopulateTensor(1, {2.0, 0.5})); 
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); EXPECT_THAT(model.GetOutput(0), - Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333})); + Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 2.0})); } -TEST(ElementwiseTest, Sigmoid) { - OperationType op_type = OperationType::SIGMOID; +TEST(ElementwiseTwoArgumentsTest, PowScalar) { + OperationType op_type = OperationType::POW; const BHWC shape(1, 2, 2, 1); - SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, - /*inputs=*/{GetTensorRef(0, shape)}, - /*outputs=*/{GetTensorRef(1, shape)}); - ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); - ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); - EXPECT_THAT(model.GetOutput(0), - Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014})); -} - -TEST(ElementwiseTest, Sin) { - OperationType op_type = OperationType::SIN; - const BHWC shape(1, 2, 2, 1); - SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, - /*inputs=*/{GetTensorRef(0, shape)}, - /*outputs=*/{GetTensorRef(1, shape)}); - ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0})); - ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); - EXPECT_THAT(model.GetOutput(0), - Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471})); -} - -TEST(ElementwiseTest, Sqrt) { - OperationType op_type = OperationType::SQRT; - const BHWC shape(1, 2, 2, 1); - SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, - /*inputs=*/{GetTensorRef(0, shape)}, - /*outputs=*/{GetTensorRef(1, shape)}); + ElementwiseAttributes attr; + attr.param = 2.0f; + SingleOpModel model( + {/*type=*/ToString(op_type), /*attributes=*/std::move(attr)}, + /*inputs=*/{GetTensorRef(0, shape)}, + /*outputs=*/{GetTensorRef(2, shape)}); ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); EXPECT_THAT(model.GetOutput(0), - Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0})); + Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 16.0})); } -TEST(ElementwiseTest, Square) { - OperationType op_type = OperationType::SQUARE; - const BHWC shape(1, 2, 2, 1); - SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, - /*inputs=*/{GetTensorRef(0, shape)}, - /*outputs=*/{GetTensorRef(1, shape)}); - ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0})); +TEST(ElementwiseTwoArgumentsTest, PowConstVector) { + OperationType op_type = OperationType::POW; + const BHWC shape0(1, 2, 1, 2); + + ElementwiseAttributes attr; + Tensor param; + param.shape = Linear(2); + param.id = 1; + param.data = {2.0, 0.5}; + attr.param = std::move(param); + + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); EXPECT_THAT(model.GetOutput(0), - Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0})); + Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 2.0})); } -TEST(ElementwiseTest, SquaredDiff) { +TEST(ElementwiseTwoArgumentsTest, SquaredDiffElementwise) { OperationType op_type = OperationType::SQUARED_DIFF; const BHWC shape(1, 2, 2, 1); SingleOpModel model( @@ -262,7 +443,56 @@ TEST(ElementwiseTest, SquaredDiff) { Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0})); } -TEST(ElementwiseTest, Sub) { +TEST(ElementwiseTwoArgumentsTest, SquaredDiffBroadcast) { + OperationType op_type = OperationType::SQUARED_DIFF; + const BHWC shape0(1, 2, 1, 2); + 
const BHWC shape1(1, 1, 1, 2); + SingleOpModel model( + {/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_TRUE(model.PopulateTensor(1, {-1.0, 5.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {1.0, 16.0, 9.0, 4.0})); +} + +TEST(ElementwiseTwoArgumentsTest, SquaredDiffScalar) { + OperationType op_type = OperationType::SQUARED_DIFF; + const BHWC shape0(1, 2, 1, 2); + ElementwiseAttributes attr; + attr.param = static_cast(5.0); + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {25.0, 16.0, 9.0, 4.0})); +} + +TEST(ElementwiseTwoArgumentsTest, SquaredDiffConstVector) { + OperationType op_type = OperationType::SQUARED_DIFF; + const BHWC shape0(1, 2, 1, 2); + + ElementwiseAttributes attr; + Tensor param; + param.shape = Linear(2); + param.id = 1; + param.data = {-1.0, 5.0}; + attr.param = std::move(param); + + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {1.0, 16.0, 9.0, 4.0})); +} + +TEST(ElementwiseTwoArgumentsTest, SubElementwise) { OperationType op_type = OperationType::SUB; const BHWC shape(1, 2, 2, 1); SingleOpModel model( @@ -276,16 +506,53 @@ TEST(ElementwiseTest, Sub) { Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0})); } -TEST(ElementwiseTest, Tanh) { - OperationType op_type = OperationType::TANH; - const BHWC shape(1, 2, 2, 1); - SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, - /*inputs=*/{GetTensorRef(0, shape)}, - /*outputs=*/{GetTensorRef(1, shape)}); - ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); +TEST(ElementwiseTwoArgumentsTest, SubBroadcast) { + OperationType op_type = OperationType::SUB; + const BHWC shape0(1, 2, 1, 2); + const BHWC shape1(1, 1, 1, 2); + SingleOpModel model( + {/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_TRUE(model.PopulateTensor(1, {0.3, 0.2})); ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); EXPECT_THAT(model.GetOutput(0), - Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329})); + Pointwise(FloatNear(1e-6), {-0.3, 0.8, 1.7, 2.8})); +} + +TEST(ElementwiseTwoArgumentsTest, SubScalar) { + OperationType op_type = OperationType::SUB; + const BHWC shape0(1, 2, 1, 2); + ElementwiseAttributes attr; + attr.param = static_cast(0.5); + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {-0.5, 0.5, 1.5, 2.5})); +} + +TEST(ElementwiseTwoArgumentsTest, 
SubConstVector) { + OperationType op_type = OperationType::SUB; + const BHWC shape0(1, 2, 1, 2); + + ElementwiseAttributes attr; + Tensor param; + param.shape = Linear(2); + param.id = 1; + param.data = {0.3, 0.2}; + attr.param = std::move(param); + + SingleOpModel model({/*type=*/ToString(op_type), attr}, + /*inputs=*/{GetTensorRef(0, shape0)}, + /*outputs=*/{GetTensorRef(2, shape0)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {-0.3, 0.8, 1.7, 2.8})); } } // namespace From f8da7c2b15751c5b6b9b8fd9a445100b16f30c73 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Fri, 20 Mar 2020 11:44:45 -0700 Subject: [PATCH 329/492] [tf.data service] Add framework for pluggable credentials. PiperOrigin-RevId: 302069447 Change-Id: Ifd01c60826967b03693395668672bf836dd40f61 --- tensorflow/core/data/service/BUILD | 23 ++++ .../core/data/service/credentials_factory.cc | 111 ++++++++++++++++++ .../core/data/service/credentials_factory.h | 69 +++++++++++ .../data/service/credentials_factory_test.cc | 91 ++++++++++++++ 4 files changed, 294 insertions(+) create mode 100644 tensorflow/core/data/service/credentials_factory.cc create mode 100644 tensorflow/core/data/service/credentials_factory.h create mode 100644 tensorflow/core/data/service/credentials_factory_test.cc diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 68c0f2d47d7..d23791e510a 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -78,6 +78,29 @@ tf_cc_test( ], ) +cc_library( + name = "credentials_factory", + srcs = ["credentials_factory.cc"], + hdrs = ["credentials_factory.h"], + deps = [ + "//tensorflow:grpc++", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "credentials_factory_test", + srcs = ["credentials_factory_test.cc"], + deps = [ + ":credentials_factory", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_grpc_library( name = "master_cc_grpc_proto", srcs = [":master_proto"], diff --git a/tensorflow/core/data/service/credentials_factory.cc b/tensorflow/core/data/service/credentials_factory.cc new file mode 100644 index 00000000000..88b0073ae26 --- /dev/null +++ b/tensorflow/core/data/service/credentials_factory.cc @@ -0,0 +1,111 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/data/service/credentials_factory.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace data { + +namespace { +mutex* get_lock() { + static mutex lock(LINKER_INITIALIZED); + return &lock; +} + +using CredentialsFactories = + std::unordered_map; +CredentialsFactories& credentials_factories() { + static auto& factories = *new CredentialsFactories(); + return factories; +} +} // namespace + +void CredentialsFactory::Register(CredentialsFactory* factory) { + mutex_lock l(*get_lock()); + if (!credentials_factories().insert({factory->Protocol(), factory}).second) { + LOG(ERROR) + << "Two credentials factories are being registered with protocol " + << factory->Protocol() << ". Which one gets used is undefined."; + } +} + +Status CredentialsFactory::Get(absl::string_view protocol, + CredentialsFactory** out) { + mutex_lock l(*get_lock()); + auto it = credentials_factories().find(std::string(protocol)); + if (it != credentials_factories().end()) { + *out = it->second; + return Status::OK(); + } + + std::vector available_types; + for (const auto& factory : credentials_factories()) { + available_types.push_back(factory.first); + } + + return errors::NotFound("No credentials factory has been registered for ", + "protocol ", protocol, + ". The available types are: [ ", + absl::StrJoin(available_types, ", "), " ]"); +} + +Status CredentialsFactory::CreateServerCredentials( + absl::string_view protocol, std::shared_ptr* out) { + CredentialsFactory* factory; + TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory)); + TF_RETURN_IF_ERROR(factory->CreateServerCredentials(out)); + return Status::OK(); +} + +Status CredentialsFactory::CreateClientCredentials( + absl::string_view protocol, + std::shared_ptr* out) { + CredentialsFactory* factory; + TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory)); + TF_RETURN_IF_ERROR(factory->CreateClientCredentials(out)); + return Status::OK(); +} + +class InsecureCredentialsFactory : public CredentialsFactory { + public: + std::string Protocol() override { return "grpc"; } + + Status CreateServerCredentials( + std::shared_ptr* out) override { + *out = grpc::InsecureServerCredentials(); + return Status::OK(); + } + + Status CreateClientCredentials( + std::shared_ptr* out) override { + *out = grpc::InsecureChannelCredentials(); + return Status::OK(); + } +}; + +class InsecureCredentialsRegistrar { + public: + InsecureCredentialsRegistrar() { + auto factory = new InsecureCredentialsFactory(); + CredentialsFactory::Register(factory); + } +}; +static InsecureCredentialsRegistrar registrar; + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/credentials_factory.h b/tensorflow/core/data/service/credentials_factory.h new file mode 100644 index 00000000000..a93b9411ec0 --- /dev/null +++ b/tensorflow/core/data/service/credentials_factory.h @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_ + +#include "grpcpp/grpcpp.h" +#include "grpcpp/security/credentials.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace data { + +// Credential factory implementations should be threadsafe since all callers +// to `GetCredentials` will get the same instance of `CredentialsFactory`. +class CredentialsFactory { + public: + virtual ~CredentialsFactory() = default; + + // Returns a protocol name for the credentials factory. This is the string to + // look up with `GetCredentials` to find the registered credentials factory. + virtual std::string Protocol() = 0; + + // Stores server credentials to `*out`. + virtual Status CreateServerCredentials( + std::shared_ptr* out) = 0; + + // Stores client credentials to `*out`. + virtual Status CreateClientCredentials( + std::shared_ptr* out) = 0; + + // Registers a credentials factory. + static void Register(CredentialsFactory* factory); + + // Creates server credentials using the credentials factory registered as + // `protocol`, and stores them to `*out`. + static Status CreateServerCredentials( + absl::string_view protocol, + std::shared_ptr* out); + + // Creates client credentials using the credentials factory registered as + // `protocol`, and stores them to `*out`. + static Status CreateClientCredentials( + absl::string_view protocol, + std::shared_ptr* out); + + private: + // Gets the credentials factory registered via `Register` for the specified + // protocol, and stores it to `*out`. + static Status Get(const absl::string_view protocol, CredentialsFactory** out); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_ diff --git a/tensorflow/core/data/service/credentials_factory_test.cc b/tensorflow/core/data/service/credentials_factory_test.cc new file mode 100644 index 00000000000..507c553963a --- /dev/null +++ b/tensorflow/core/data/service/credentials_factory_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
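To illustrate how the new factory interface above is meant to be extended, here is a hypothetical additional protocol registered the same way as InsecureCredentialsFactory; the "grpc+tls" name, the insecure stand-in credentials, and the reconstructed shared_ptr template parameters are assumptions for illustration, not part of the patch:

// A sketch of a pluggable credentials factory; a real TLS factory would build
// grpc::SslServerCredentials / grpc::SslCredentials instead of the insecure ones.
#include "tensorflow/core/data/service/credentials_factory.h"

namespace tensorflow {
namespace data {

class MyTlsCredentialsFactory : public CredentialsFactory {
 public:
  std::string Protocol() override { return "grpc+tls"; }  // hypothetical protocol name

  Status CreateServerCredentials(
      std::shared_ptr<grpc::ServerCredentials>* out) override {
    *out = grpc::InsecureServerCredentials();  // stand-in only
    return Status::OK();
  }

  Status CreateClientCredentials(
      std::shared_ptr<grpc::ChannelCredentials>* out) override {
    *out = grpc::InsecureChannelCredentials();  // stand-in only
    return Status::OK();
  }
};

// Registration at static-initialization time, mirroring the
// InsecureCredentialsRegistrar pattern in credentials_factory.cc.
static bool tls_factory_registered = [] {
  CredentialsFactory::Register(new MyTlsCredentialsFactory());
  return true;
}();

}  // namespace data
}  // namespace tensorflow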
+==============================================================================*/ +#include "tensorflow/core/data/service/credentials_factory.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { + +namespace { +constexpr char kFailedToCreateServerCredentials[] = + "Failed to create server credentials."; +constexpr char kFailedToCreateClientCredentials[] = + "Failed to create client credentials."; + +class TestCredentialsFactory : public CredentialsFactory { + public: + std::string Protocol() override { return "test"; } + + Status CreateServerCredentials( + std::shared_ptr* out) override { + return errors::Internal(kFailedToCreateServerCredentials); + } + + Status CreateClientCredentials( + std::shared_ptr* out) override { + return errors::Internal(kFailedToCreateClientCredentials); + } +}; +} // namespace + +TEST(CredentialsFactory, Register) { + TestCredentialsFactory test_factory; + CredentialsFactory::Register(&test_factory); + std::shared_ptr server_credentials; + ASSERT_EQ(errors::Internal(kFailedToCreateServerCredentials), + CredentialsFactory::CreateServerCredentials(test_factory.Protocol(), + &server_credentials)); + std::shared_ptr client_credentials; + ASSERT_EQ(errors::Internal(kFailedToCreateClientCredentials), + CredentialsFactory::CreateClientCredentials(test_factory.Protocol(), + &client_credentials)); +} + +TEST(CredentialsFactory, DefaultGrpcProtocol) { + std::shared_ptr server_credentials; + TF_ASSERT_OK( + CredentialsFactory::CreateServerCredentials("grpc", &server_credentials)); + std::shared_ptr client_credentials; + TF_ASSERT_OK( + CredentialsFactory::CreateClientCredentials("grpc", &client_credentials)); +} + +TEST(CredentialsFactory, MissingServerProtocol) { + std::shared_ptr server_credentials; + Status s = CredentialsFactory::CreateServerCredentials("unknown_protocol", + &server_credentials); + ASSERT_EQ(error::Code::NOT_FOUND, s.code()); + ASSERT_TRUE( + absl::StrContains(s.ToString(), + "No credentials factory has been registered for " + "protocol unknown_protocol")); +} + +TEST(CredentialsFactory, MissingClientProtocol) { + std::shared_ptr client_credentials; + Status s = CredentialsFactory::CreateClientCredentials("unknown_protocol", + &client_credentials); + ASSERT_EQ(error::Code::NOT_FOUND, s.code()); + ASSERT_TRUE( + absl::StrContains(s.ToString(), + "No credentials factory has been registered for " + "protocol unknown_protocol")); +} + +} // namespace data +} // namespace tensorflow From 0ad3c881ff7724cf6275276670f2b7efe9dab382 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 11:46:35 -0700 Subject: [PATCH 330/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302069758 Change-Id: I9d75916f416c3417d334c3b002c6f23e5a6812bc --- tensorflow/go/op/wrappers.go | 48 ++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index b8b73bc472d..75d86f71b78 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. 
-// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -13925,7 +13925,7 @@ func DebugNumericSummaryGatedGrpc(value bool) DebugNumericSummaryAttr { // Provide a basic summary of numeric value types, range and distribution. // // output: A double tensor of shape [14 + nDimensions], where nDimensions is the -// the number of dimensions of the tensor's shape. The elements of output are: +// number of dimensions of the tensor's shape. The elements of output are: // [0]: is initialized (1.0) or not (0.0). // [1]: total number of elements // [2]: NaN element count @@ -13935,7 +13935,7 @@ func DebugNumericSummaryGatedGrpc(value bool) DebugNumericSummaryAttr { // -inf. Otherwise, this is the count of elements > lower_bound and < 0. // [5]: zero element count // [6]: positive element count (excluding +inf), if upper_bound is the default -// -inf. Otherwise, this is the count of elements < upper_bound and > 0. +// +inf. Otherwise, this is the count of elements < upper_bound and > 0. // [7]: generalized +inf count, elements >= upper_bound. upper_bound is +inf by // default. // Output elements [1:8] are all zero, if the tensor is uninitialized. @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From d0e21cd46860370f6146bf1ffdc07620744f2bbd Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Fri, 20 Mar 2020 11:47:16 -0700 Subject: [PATCH 331/492] Use correct variable _device attribute in Keras optimizer_v2. PiperOrigin-RevId: 302069884 Change-Id: I32ff43f146c6f60d462d2713908c3cf258ace3de --- tensorflow/python/keras/optimizer_v2/optimizer_v2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 2a4d4cf86e8..d9f090c0a60 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -699,8 +699,10 @@ class OptimizerV2(trackable.Trackable): def _prepare(self, var_list): keys = set() for var in var_list: - var_devices = (getattr(var, "devices", None) or # Distributed - [var.device]) # Regular + if isinstance(var, ds_values.DistributedValues): + var_devices = var._devices # pylint: disable=protected-access + else: + var_devices = [var.device] var_dtype = var.dtype.base_dtype for var_device in var_devices: keys.add((var_device, var_dtype)) From a794c690b48d86a1e02e9de4c313606c9245a2b8 Mon Sep 17 00:00:00 2001 From: Robert David Date: Fri, 20 Mar 2020 11:51:05 -0700 Subject: [PATCH 332/492] Move the templatized integer implementation of Softmax to reference/softmax.h, make it work with uint8_t to replace the uint8 version there. 
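In rough terms, the uint8 reference path and the separate reference_integer_ops
variant collapse into one template; an illustrative sketch of the resulting
call sites follows (buffer names are placeholders, and in the real kernels the
element types are deduced from the tensor data pointers rather than spelled
out):

  // Hypothetical, already-allocated buffers whose sizes match `shape`;
  // `params` is a populated tflite::SoftmaxParams as in the kernels below.
  tflite::reference_ops::Softmax<uint8_t, uint8_t>(params, shape, u8_in, shape, u8_out);
  tflite::reference_ops::Softmax<int8_t, int8_t>(params, shape, i8_in, shape, i8_out);
  tflite::reference_ops::Softmax<int8_t, int16_t>(params, shape, i8_in, shape, i16_out);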
PiperOrigin-RevId: 302070491 Change-Id: I3a1148604fbeae271891d2cc202e765aed5b5b9f --- tensorflow/lite/kernels/internal/BUILD | 1 - .../internal/reference/integer_ops/softmax.h | 107 ------------------ .../lite/kernels/internal/reference/softmax.h | 25 ++-- .../internal/softmax_quantized_test.cc | 1 - tensorflow/lite/micro/kernels/softmax.cc | 29 +++-- .../micro/kernels/xtensa_hifimini/softmax.cc | 1 - tensorflow/lite/micro/tools/make/Makefile | 1 - 7 files changed, 31 insertions(+), 134 deletions(-) delete mode 100644 tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 376c3d5d34b..c9e6c082b53 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -438,7 +438,6 @@ cc_library( "reference/integer_ops/mean.h", "reference/integer_ops/mul.h", "reference/integer_ops/pooling.h", - "reference/integer_ops/softmax.h", "reference/integer_ops/tanh.h", "reference/integer_ops/transpose_conv.h", "reference/logistic.h", diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h b/tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h deleted file mode 100644 index 28dfc047533..00000000000 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_ -#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_ - -#include "tensorflow/lite/kernels/internal/common.h" - -namespace tflite { -namespace reference_integer_ops { - -// Quantized softmax with int8 input and int8/int16 output. -template -inline void Softmax(const SoftmaxParams& params, - const RuntimeShape& input_shape, const int8* input_data, - const RuntimeShape& output_shape, OutputT* output_data) { - const int32_t input_beta_multiplier = params.input_multiplier; - const int32_t input_beta_left_shift = params.input_left_shift; - const int diff_min = params.diff_min; - // The representation chosen for the input to the exp() function is Q5.26. - // We need to leave extra space since values that we skip might be as large as - // -32 before multiplying by input_beta_multiplier, and therefore as large as - // -16 afterwards. Note that exp(-8) is definitely not insignificant to - // accumulation, but exp(-16) definitely is. 
- static const int kScaledDiffIntegerBits = 5; - static const int kAccumulationIntegerBits = 12; - using FixedPointScaledDiff = - gemmlowp::FixedPoint; - using FixedPointAccum = - gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; - - const int trailing_dim = input_shape.DimensionsCount() - 1; - const int outer_size = - MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); - const int depth = - MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - - for (int i = 0; i < outer_size; ++i) { - int8 max_in_row = -128; - for (int c = 0; c < depth; ++c) { - max_in_row = std::max(max_in_row, input_data[i * depth + c]); - } - - FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); - for (int c = 0; c < depth; ++c) { - int32_t input_diff = - static_cast(input_data[i * depth + c]) - max_in_row; - if (input_diff >= diff_min) { - const int32_t input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - sum_of_exps = sum_of_exps + gemmlowp::Rescale( - exp_on_negative_values(scaled_diff_f8)); - } - } - - int num_bits_over_unit; - FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal( - sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit)); - - for (int c = 0; c < depth; ++c) { - int32_t input_diff = - static_cast(input_data[i * depth + c]) - max_in_row; - if (input_diff >= diff_min) { - const int32_t input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - - FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); - const int32_t unsat_output = gemmlowp::RoundingDivideByPOT( - (shifted_scale * exp_in_0).raw(), - num_bits_over_unit + 31 - (sizeof(OutputT) * 8)); - // TODO(b/148494470): Handle int32 shifts properly: - const int32_t shifted_output = - unsat_output - - (static_cast(std::numeric_limits::max()) + 1); - output_data[i * depth + c] = static_cast(std::max( - std::min(shifted_output, - static_cast(std::numeric_limits::max())), - static_cast(std::numeric_limits::min()))); - } else { - output_data[i * depth + c] = std::numeric_limits::min(); - } - } - } -} - -} // namespace reference_integer_ops -} // namespace tflite - -#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_ diff --git a/tensorflow/lite/kernels/internal/reference/softmax.h b/tensorflow/lite/kernels/internal/reference/softmax.h index ac06d49000e..26e402db3da 100644 --- a/tensorflow/lite/kernels/internal/reference/softmax.h +++ b/tensorflow/lite/kernels/internal/reference/softmax.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_ +#include + #include "fixedpoint/fixedpoint.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" @@ -59,9 +61,11 @@ inline void Softmax(const SoftmaxParams& params, } } +// Quantized softmax with int8/uint8 input and int8/uint8/int16 output. 
+template inline void Softmax(const SoftmaxParams& params, - const RuntimeShape& input_shape, const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { + const RuntimeShape& input_shape, const InputT* input_data, + const RuntimeShape& output_shape, OutputT* output_data) { const int32 input_beta_multiplier = params.input_multiplier; const int32 input_beta_left_shift = params.input_left_shift; const int diff_min = params.diff_min; @@ -84,7 +88,7 @@ inline void Softmax(const SoftmaxParams& params, MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { - uint8 max_in_row = 0; + InputT max_in_row = std::numeric_limits::min(); for (int c = 0; c < depth; ++c) { max_in_row = std::max(max_in_row, input_data[i * depth + c]); } @@ -120,14 +124,19 @@ inline void Softmax(const SoftmaxParams& params, FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); int32 unsat_output = gemmlowp::RoundingDivideByPOT( - (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + (shifted_scale * exp_in_0).raw(), + num_bits_over_unit + 31 - (sizeof(OutputT) * 8)); - output_data[i * depth + c] = static_cast( - std::max(std::min(unsat_output, static_cast(255)), - static_cast(0))); + const int32 shifted_output = + unsat_output + + static_cast(std::numeric_limits::min()); + output_data[i * depth + c] = static_cast(std::max( + std::min(shifted_output, + static_cast(std::numeric_limits::max())), + static_cast(std::numeric_limits::min()))); } else { - output_data[i * depth + c] = 0; + output_data[i * depth + c] = std::numeric_limits::min(); } } } diff --git a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc index 0ca030eda80..4f9c6471a1c 100644 --- a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/test_util.h" #include "tensorflow/lite/string_type.h" diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc index fe2bfce5c7a..85952de9d50 100644 --- a/tensorflow/lite/micro/kernels/softmax.cc +++ b/tensorflow/lite/micro/kernels/softmax.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" @@ -116,13 +115,13 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, GetTensorData(output)); } else { if (output->type == kTfLiteInt16) { - tflite::reference_integer_ops::Softmax( - op_params, shape, GetTensorData(input), shape, - GetTensorData(output)); + tflite::reference_ops::Softmax(op_params, shape, + GetTensorData(input), shape, + GetTensorData(output)); } else { - tflite::reference_integer_ops::Softmax( - op_params, shape, GetTensorData(input), shape, - GetTensorData(output)); + tflite::reference_ops::Softmax(op_params, shape, + GetTensorData(input), shape, + GetTensorData(output)); } } } @@ -147,13 +146,13 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, GetTensorData(output)); } else { if (output->type == kTfLiteInt16) { - tflite::reference_integer_ops::Softmax( - op_params, shape, GetTensorData(input), shape, - GetTensorData(output)); + tflite::reference_ops::Softmax(op_params, shape, + GetTensorData(input), shape, + GetTensorData(output)); } else { - tflite::reference_integer_ops::Softmax( - op_params, shape, GetTensorData(input), shape, - GetTensorData(output)); + tflite::reference_ops::Softmax(op_params, shape, + GetTensorData(input), shape, + GetTensorData(output)); } } } @@ -180,11 +179,11 @@ void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, GetTensorShape(output), GetTensorData(output)); } else { if (output->type == kTfLiteInt16) { - tflite::reference_integer_ops::Softmax( + tflite::reference_ops::Softmax( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { - tflite::reference_integer_ops::Softmax( + tflite::reference_ops::Softmax( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc index 54191b56eaa..4631791fede 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 0038e6813a2..e78979032a2 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -149,7 +149,6 @@ tensorflow/lite/kernels/internal/reference/integer_ops/conv.h \ tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h \ tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h \ tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \ -tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h \ tensorflow/lite/kernels/internal/reference/maximum_minimum.h \ tensorflow/lite/kernels/internal/reference/mul.h \ tensorflow/lite/kernels/internal/reference/neg.h \ From 289af050cc8f85f782af7c525977689d34f96f61 Mon Sep 17 00:00:00 2001 From: Cesar Crusius Date: Fri, 20 Mar 2020 11:57:09 -0700 Subject: [PATCH 333/492] Internal Copybara change. PiperOrigin-RevId: 302071704 Change-Id: If250e336a66aa462b09881cf9a8f6eb4e6ccc3a7 --- tensorflow/core/kernels/eigen_contraction_kernel.cc | 10 ++++++++-- tensorflow/core/platform/strcat.h | 2 +- tensorflow/lite/delegates/gpu/gl/kernels/resize.cc | 2 +- tensorflow/lite/delegates/gpu/metal/kernels/resize.cc | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/eigen_contraction_kernel.cc b/tensorflow/core/kernels/eigen_contraction_kernel.cc index aa6cb4b9cb9..4959651569c 100644 --- a/tensorflow/core/kernels/eigen_contraction_kernel.cc +++ b/tensorflow/core/kernels/eigen_contraction_kernel.cc @@ -28,7 +28,9 @@ limitations under the License. // the configuration through the environment variable. // // Example: -// bazel test --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test +// bazel test \ +// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ +// //path/to:test #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) @@ -37,7 +39,11 @@ namespace internal { // TODO(ezhulenev): This is a temporary workaround for disabling custom kernels // at runtime in tests. We should always rely on compile time flags for that. -// Example: ... --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test +// +// Example: +// bazel test \ +// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \ +// //path/to:test EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels() { static bool use_custom_contraction_kernel = true; diff --git a/tensorflow/core/platform/strcat.h b/tensorflow/core/platform/strcat.h index 6b435dceca3..640355c9ea5 100644 --- a/tensorflow/core/platform/strcat.h +++ b/tensorflow/core/platform/strcat.h @@ -33,7 +33,7 @@ limitations under the License. // to your function, your callers will automatically convert bools, integers, // and floating point values to strings for you. // -// NOTE: Use of AlphaNum outside of the //strings package is unsupported except +// NOTE: Use of AlphaNum outside of the "strings" package is unsupported except // for the specific case of function parameters of type "AlphaNum" or "const // AlphaNum &". 
In particular, instantiating AlphaNum directly as a stack // variable is not supported. diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc index b8949e41426..33d59518987 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc @@ -93,7 +93,7 @@ class Resize : public NodeShader { st.xy = max(icoord_floor, ivec2(0, 0)); st.zw = min(icoord_floor + ivec2(1, 1), borders); - vec2 t = coord - coord_floor; //interpolating factors + vec2 t = coord - coord_floor; // interpolating factors vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$; vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc index 2ed75ad65b1..24d7bcf13bc 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize.cc @@ -54,7 +54,7 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) { int4 st; st.xy = max(itex_coord_floor, int2(0, 0)); st.zw = min(itex_coord_floor + int2(1, 1), borders); - const float2 t = tex_coord - tex_coord_floor; //interpolating factors + const float2 t = tex_coord - tex_coord_floor; // interpolating factors const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x; const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z; const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x; From 9772f7ed65285c062499e02035e941287b0d34ab Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 11:57:23 -0700 Subject: [PATCH 334/492] Fix empty hlo category to show as unknown. PiperOrigin-RevId: 302071742 Change-Id: I6a0f04daf7c3ff4ab8b9b0702b29193c5096adc4 --- tensorflow/core/profiler/utils/op_utils.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc index a4051bfac31..58893557e5b 100644 --- a/tensorflow/core/profiler/utils/op_utils.cc +++ b/tensorflow/core/profiler/utils/op_utils.cc @@ -64,9 +64,10 @@ void DeviceOpMetricsDbBuilder::EnterOp( DCHECK_GE(time_ps, self_time_ps); OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, name); if (op_metrics->category().empty()) - op_metrics->set_category(std::string(category)); + op_metrics->set_category(category == kUnknownOp ? "unknown" + : string(category)); if (op_metrics->provenance().empty()) - op_metrics->set_provenance(std::string(provenance)); + op_metrics->set_provenance(string(provenance)); op_metrics->set_occurrences(op_metrics->occurrences() + occurrences); op_metrics->set_time_ps(op_metrics->time_ps() + time_ps); op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_time_ps); From 6429fed042a8462e6502548945bed9fc99f29d2f Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Fri, 20 Mar 2020 12:01:57 -0700 Subject: [PATCH 335/492] Fix GCS filesystem on windows by using platform independent primitives in GetMatchingPaths This makes all GcsfileSystem tests pass on windows. 
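Sketched, and only paraphrasing the hunk below: path splitting, joining, and
glob matching inside GetMatchingPaths now go through the FileSystem's own
virtual helpers instead of the io:: and Env::Default() ones, which follow
host-OS path conventions and are the presumed source of the Windows breakage
for gs:// object names:

  // Inside GcsFileSystem::GetMatchingPaths, roughly:
  const string dir(this->Dirname(fixed_prefix));       // was io::Dirname(...)
  const string full_path = this->JoinPath(dir, path);  // was io::JoinPath(...)
  if (this->Match(full_path, pattern)) {               // was Env::Default()->MatchPath(...)
    results->push_back(full_path);
  }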
PiperOrigin-RevId: 302072552 Change-Id: I03d1ab506860e9ce82a1981cd88df3caa25d3324 --- tensorflow/core/platform/cloud/gcs_file_system.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index c30762b96e9..bc23cf14b03 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -1331,7 +1331,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, // Find the fixed prefix by looking for the first wildcard. const string& fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); - const string dir(io::Dirname(fixed_prefix)); + const string dir(this->Dirname(fixed_prefix)); if (dir.empty()) { return errors::InvalidArgument( "A GCS pattern doesn't have a bucket name: ", pattern); @@ -1345,8 +1345,8 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, // Match all obtained paths to the input pattern. for (const auto& path : files_and_folders) { - const string& full_path = io::JoinPath(dir, path); - if (Env::Default()->MatchPath(full_path, pattern)) { + const string& full_path = this->JoinPath(dir, path); + if (this->Match(full_path, pattern)) { results->push_back(full_path); } } From b513093daef99c39f8951e522329f94451b27758 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 12:29:29 -0700 Subject: [PATCH 336/492] [TF:XLA] Support setting up aliasing of tuple args when converting a MLIR module in HLO dialect into a HloModuleProto Extend ConvertToHloModule to handle tuple arguments by marking the corresponding element of an input parameter as an alias. Add a test case. PiperOrigin-RevId: 302077866 Change-Id: I1332ca1fc751da603b82f49c5c719a18826565da --- tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc | 6 +++++- .../mlir/xla/tests/translate/input_output_aliasing.mlir | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 92614755ec3..550d151d968 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -987,7 +987,11 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { // into the resulting HloModule. 
auto aliasing_output = f.getArgAttrOfType(i, "tf.aliasing_output"); - if (aliasing_output) { + if (!aliasing_output) continue; + if (use_tuple_args_) { + builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()}, + /*param_number=*/0, /*param_index=*/{i}); + } else { builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()}, /*param_number=*/i, /*param_index=*/{}); } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/input_output_aliasing.mlir b/tensorflow/compiler/mlir/xla/tests/translate/input_output_aliasing.mlir index 3ad781b6bbb..a0dc1798dc6 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/input_output_aliasing.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/input_output_aliasing.mlir @@ -1,7 +1,10 @@ // RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-return-tuple %s | FileCheck %s +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-return-tuple %s | FileCheck %s --check-prefix=TUPLE-ARG // CHECK-LABEL: ENTRY %main // CHECK: // OutputIndex {0} aliases with input 0 at {} +// TUPLE-ARG-LABEL: ENTRY %main +// TUPLE-ARG: // OutputIndex {0} aliases with input 0 at {0} func @main(%arg0: tensor<1xf32> {tf.aliasing_output = 0 : i64}) -> (tensor<1xf32>) { %0 = xla_hlo.constant dense<4.200000e+01> : tensor<1xf32> %1 = xla_hlo.add %arg0, %0 : tensor<1xf32> From 975e8b94de2fb875f443bb1f2c279ee1093ac91b Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Fri, 20 Mar 2020 12:37:14 -0700 Subject: [PATCH 337/492] [tf.data] Avoid contention in bytes-read metric collection. This change caches the `CounterCell*` for each source dataset type that records the number of bytes read. This avoids contention in the thread-safe map lookup each time a record is read. PiperOrigin-RevId: 302079362 Change-Id: I4efaeff8e0ffb3df2311228489ba107ce378ce91 --- tensorflow/core/common_runtime/metrics.cc | 4 ++-- tensorflow/core/common_runtime/metrics.h | 7 ++++--- .../kernels/data/fixed_length_record_dataset_op.cc | 13 +++++++------ .../core/kernels/data/text_line_dataset_op.cc | 9 +++++---- .../core/kernels/data/tf_record_dataset_op.cc | 6 ++++-- 5 files changed, 22 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/common_runtime/metrics.cc b/tensorflow/core/common_runtime/metrics.cc index f05f9312b50..a26a678af13 100644 --- a/tensorflow/core/common_runtime/metrics.cc +++ b/tensorflow/core/common_runtime/metrics.cc @@ -132,8 +132,8 @@ void RecordTFDataAutotune(const string& name) { tf_data_autotune_counter->GetCell(name)->IncrementBy(1); } -void RecordTFDataBytesRead(const string& name, int64 num_bytes) { - tf_data_bytes_read_counter->GetCell(name)->IncrementBy(num_bytes); +monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name) { + return tf_data_bytes_read_counter->GetCell(name); } void RecordTFDataBytesFetched(int64 num_bytes) { diff --git a/tensorflow/core/common_runtime/metrics.h b/tensorflow/core/common_runtime/metrics.h index a5d43da539b..e95e0495c04 100644 --- a/tensorflow/core/common_runtime/metrics.h +++ b/tensorflow/core/common_runtime/metrics.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_ +#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -26,11 +27,11 @@ namespace metrics { // The `name` argument identifies the Dataset type (e.g. "ParallelMap"). 
void RecordTFDataAutotune(const string& name); -// Records the number of bytes read from the filesystem by a tf.data.Dataset -// source. +// Returns a counter than can be used to record the number of bytes read from +// the filesystem by a tf.data.Dataset source. // // The `name` argument identifies the Dataset type (e.g. "TFRecordDataset"). -void RecordTFDataBytesRead(const string& name, int64 num_bytes); +monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name); // Records the number of bytes fetched from tf.data.Dataset iterator. void RecordTFDataBytesFetched(int64 num_bytes); diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index 15bfeb01a65..468a22261d5 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc @@ -138,8 +138,9 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { string record; TF_RETURN_IF_ERROR( input_buffer_->ReadNBytes(dataset()->record_bytes_, &record)); - metrics::RecordTFDataBytesRead(kDatasetType, - dataset()->record_bytes_); + static monitoring::CounterCell* bytes_counter = + metrics::GetTFDataBytesReadCounter(kDatasetType); + bytes_counter->IncrementBy(dataset()->record_bytes_); // Produce the record as output. Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); @@ -251,6 +252,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { + static monitoring::CounterCell* bytes_counter = + metrics::GetTFDataBytesReadCounter(kDatasetType); mutex_lock l(mu_); do { // We are currently processing a file, so try to read the next record. @@ -262,8 +265,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { tstring record; TF_RETURN_IF_ERROR(buffered_input_stream_->ReadNBytes( dataset()->record_bytes_, &record)); - metrics::RecordTFDataBytesRead(kDatasetType, - dataset()->record_bytes_); + bytes_counter->IncrementBy(dataset()->record_bytes_); // Produce the record as output. Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); @@ -277,8 +279,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { Status s = buffered_input_stream_->ReadNBytes( dataset()->record_bytes_, &record); if (s.ok()) { - metrics::RecordTFDataBytesRead(kDatasetType, - dataset()->record_bytes_); + bytes_counter->IncrementBy(dataset()->record_bytes_); lookahead_cache_.append(record); StringPiece lookahead_cache_view(lookahead_cache_); record = tstring( diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc index c2c3190bd7f..dc193f53a8d 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc @@ -105,12 +105,13 @@ class TextLineDatasetOp::Dataset : public DatasetBase { if (s.ok()) { // Produce the line as output. 
- metrics::RecordTFDataBytesRead( - name_utils::OpName(TextLineDatasetOp::kDatasetType), - line_contents.size()); + static monitoring::CounterCell* bytes_counter = + metrics::GetTFDataBytesReadCounter( + name_utils::OpName(TextLineDatasetOp::kDatasetType)); + bytes_counter->IncrementBy(line_contents.size()); out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = std::move(line_contents); + out_tensors->back().scalar()() = line_contents; *end_of_sequence = false; return Status::OK(); } else if (!errors::IsOutOfRange(s)) { diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc index a72d05c5155..94d523b5bfb 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc @@ -119,8 +119,10 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { Status s = reader_->ReadRecord(&out_tensors->back().scalar()()); if (s.ok()) { - metrics::RecordTFDataBytesRead( - kDatasetType, out_tensors->back().scalar()().size()); + static monitoring::CounterCell* bytes_counter = + metrics::GetTFDataBytesReadCounter(kDatasetType); + bytes_counter->IncrementBy( + out_tensors->back().scalar()().size()); *end_of_sequence = false; return Status::OK(); } From a23fe18cb3286b4557947d5c0dbd547bc89ee066 Mon Sep 17 00:00:00 2001 From: Amy Skerry-Ryan Date: Fri, 20 Mar 2020 12:39:52 -0700 Subject: [PATCH 338/492] Wire training argument through feature column helpers / DenseFeatures so that feature columns can be training-mode-aware. V2 Feature Columns allow the construction of user-defined feature columns than can wrap arbitrary tensorflow subgraphs for transforming a feature. In some cases, these sub-graphs need to know about the mode (training vs inference) in which a transform is being applied. PiperOrigin-RevId: 302079903 Change-Id: Ia20100de4229ee4532cdd8c52569f931e2b8179f --- .../python/feature_column/dense_features.py | 32 ++++++++++++++-- .../feature_column/dense_features_v2_test.py | 37 +++++++++++++++++++ .../feature_column/feature_column_v2.py | 29 +++++++++++---- .../feature_column/sequence_feature_column.py | 30 ++++++++++++--- ...eras.experimental.-sequence-features.pbtxt | 2 +- ...sorflow.keras.layers.-dense-features.pbtxt | 2 +- ...eras.experimental.-sequence-features.pbtxt | 2 +- ...sorflow.keras.layers.-dense-features.pbtxt | 2 +- 8 files changed, 114 insertions(+), 22 deletions(-) diff --git a/tensorflow/python/feature_column/dense_features.py b/tensorflow/python/feature_column/dense_features.py index 3bc93b377a1..820f1a6b1b7 100644 --- a/tensorflow/python/feature_column/dense_features.py +++ b/tensorflow/python/feature_column/dense_features.py @@ -22,6 +22,7 @@ import json from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.framework import ops +from tensorflow.python.keras import backend from tensorflow.python.util import serialization from tensorflow.python.util.tf_export import keras_export @@ -49,7 +50,7 @@ class DenseFeatures(fc._BaseFeaturesLayer): # pylint: disable=protected-access price = tf.feature_column.numeric_column('price') keywords_embedded = tf.feature_column.embedding_column( tf.feature_column.categorical_column_with_hash_bucket("keywords", 10K), - dimensions=16) + dimension=16) columns = [price, keywords_embedded, ...] 
partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=4) feature_layer = tf.compat.v1.keras.layers.DenseFeatures( @@ -115,9 +116,19 @@ class DenseFeatures(fc._BaseFeaturesLayer): # pylint: disable=protected-access def _target_shape(self, input_shape, total_elements): return (input_shape[0], total_elements) - def call(self, features, cols_to_output_tensors=None): + def call(self, features, cols_to_output_tensors=None, training=None): """Returns a dense tensor corresponding to the `feature_columns`. + Example usage: + + >>> t1 = tf.feature_column.embedding_column( + ... tf.feature_column.categorical_column_with_hash_bucket("t1", 2), + ... dimension=8) + >>> t2 = tf.feature_column.numeric_column('t2') + >>> feature_layer = tf.compat.v1.keras.layers.DenseFeatures([t1, t2]) + >>> features = {"t1": tf.constant(["a", "b"]), "t2": tf.constant([1, 2])} + >>> dense_tensor = feature_layer(features, training=True) + Args: features: A mapping from key to tensors. `FeatureColumn`s look up via these keys. For example `numeric_column('price')` will look at 'price' @@ -125,6 +136,13 @@ class DenseFeatures(fc._BaseFeaturesLayer): # pylint: disable=protected-access on corresponding `FeatureColumn`. cols_to_output_tensors: If not `None`, this will be filled with a dict mapping feature columns to output tensors created. + training: Python boolean or None, indicating whether to the layer is being + run in training mode. This argument is passed to the call method of any + `FeatureColumn` that takes a `training` argument. For example, if a + `FeatureColumn` performed dropout, the column could expose a `training` + argument to control whether the dropout should be applied. If `None`, + defaults to `tf.keras.backend.learning_phase()`. + Returns: A `Tensor` which represents input layer of a model. Its shape @@ -134,6 +152,8 @@ class DenseFeatures(fc._BaseFeaturesLayer): # pylint: disable=protected-access Raises: ValueError: If features are not a dictionary. """ + if training is None: + training = backend.learning_phase() if not isinstance(features, dict): raise ValueError('We expected a dictionary here. Instead we got: ', features) @@ -141,8 +161,12 @@ class DenseFeatures(fc._BaseFeaturesLayer): # pylint: disable=protected-access output_tensors = [] for column in self._feature_columns: with ops.name_scope(column.name): - tensor = column.get_dense_tensor(transformation_cache, - self._state_manager) + try: + tensor = column.get_dense_tensor( + transformation_cache, self._state_manager, training=training) + except TypeError: + tensor = column.get_dense_tensor(transformation_cache, + self._state_manager) processed_tensors = self._process_dense_tensor(column, tensor) if cols_to_output_tensors is not None: cols_to_output_tensors[column] = processed_tensors diff --git a/tensorflow/python/feature_column/dense_features_v2_test.py b/tensorflow/python/feature_column/dense_features_v2_test.py index d5a96081f55..71cb163a7d9 100644 --- a/tensorflow/python/feature_column/dense_features_v2_test.py +++ b/tensorflow/python/feature_column/dense_features_v2_test.py @@ -144,6 +144,43 @@ class DenseFeaturesTest(test.TestCase): self.assertAllEqual([0, 1, 2], indexed_slice.indices) self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient) + def test_dense_feature_with_training_arg(self): + price1 = fc.numeric_column('price1', shape=2) + price2 = fc.numeric_column('price2') + + # Monkey patch the second numeric column to simulate a column that has + # different behavior by mode. 
+ def training_aware_get_dense_tensor(transformation_cache, + state_manager, + training=None): + return transformation_cache.get(price2, state_manager, training=training) + + def training_aware_transform_feature(transformation_cache, + state_manager, + training=None): + input_tensor = transformation_cache.get( + price2.key, state_manager, training=training) + if training: + return input_tensor * 10.0 + else: + return input_tensor * 20.0 + + price2.get_dense_tensor = training_aware_get_dense_tensor + price2.transform_feature = training_aware_transform_feature + with ops.Graph().as_default(): + features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]} + train_mode = df.DenseFeatures([price1, price2])(features, training=True) + predict_mode = df.DenseFeatures([price1, price2 + ])(features, training=False) + + self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllClose([[1., 2., 30.], [5., 6., 40.]], + self.evaluate(train_mode)) + self.assertAllClose([[1., 2., 60.], [5., 6., 80.]], + self.evaluate(predict_mode)) + def test_raises_if_empty_feature_columns(self): with self.assertRaisesRegexp(ValueError, 'feature_columns must not be empty'): diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index f117d0ed5ef..4003b8e1093 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -130,8 +130,8 @@ from __future__ import print_function import abc import collections import math - import re + import numpy as np import six @@ -145,9 +145,9 @@ from tensorflow.python.framework import tensor_shape # TODO(b/118385027): Dependency on keras can be problematic if Keras moves out # of the main repo. from tensorflow.python.keras import initializers -from tensorflow.python.keras.utils import generic_utils -from tensorflow.python.keras.engine import training +from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -164,13 +164,13 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import tracking -from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import deprecation from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.tf_export import tf_export _FEATURE_COLUMN_DEPRECATION_DATE = None @@ -618,7 +618,7 @@ class _LinearModelLayer(Layer): # TODO(tanzheny): Cleanup it with respect to Premade model b/132690565. -class LinearModel(training.Model): +class LinearModel(keras_training.Model): """Produces a linear prediction `Tensor` based on given `feature_columns`. This layer generates a weighted sum based on output dimension `units`. 
@@ -2659,7 +2659,7 @@ class FeatureTransformationCache(object): self._features = features.copy() self._feature_tensors = {} - def get(self, key, state_manager): + def get(self, key, state_manager, training=None): """Returns a `Tensor` for the given key. A `str` key is used to access a base feature (not-transformed). When a @@ -2670,6 +2670,11 @@ class FeatureTransformationCache(object): Args: key: a `str` or a `FeatureColumn`. state_manager: A StateManager object that holds the FeatureColumn state. + training: Boolean indicating whether to the column is being used in + training mode. This argument is passed to the transform_feature method + of any `FeatureColumn` that takes a `training` argument. For example, if + a `FeatureColumn` performed dropout, it could expose a `training` + argument to control whether the dropout should be applied. Returns: The transformed `Tensor` corresponding to the `key`. @@ -2696,7 +2701,15 @@ class FeatureTransformationCache(object): column = key logging.debug('Transforming feature_column %s.', column) - transformed = column.transform_feature(self, state_manager) + + # Some columns may need information about whether the transformation is + # happening in training or prediction mode, but not all columns expose this + # argument. + try: + transformed = column.transform_feature( + self, state_manager, training=training) + except TypeError: + transformed = column.transform_feature(self, state_manager) if transformed is None: raise ValueError('Column {} is not supported.'.format(column.name)) self._feature_tensors[column] = transformed diff --git a/tensorflow/python/feature_column/sequence_feature_column.py b/tensorflow/python/feature_column/sequence_feature_column.py index 0efff8ba386..25f2021e7e7 100644 --- a/tensorflow/python/feature_column/sequence_feature_column.py +++ b/tensorflow/python/feature_column/sequence_feature_column.py @@ -30,6 +30,7 @@ from tensorflow.python.feature_column import utils as fc_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import backend from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import parsing_ops @@ -59,6 +60,9 @@ class SequenceFeatures(fc._BaseFeaturesLayer): Example: ```python + # Behavior of some cells or feature columns may depend on whether we are in + # training or inference mode, e.g. applying dropout. 
+ training = True rating = sequence_numeric_column('rating') watches = sequence_categorical_column_with_identity( 'watches', num_buckets=1000) @@ -68,11 +72,12 @@ class SequenceFeatures(fc._BaseFeaturesLayer): sequence_input_layer = SequenceFeatures(columns) features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) - sequence_input, sequence_length = sequence_input_layer(features) + sequence_input, sequence_length = sequence_input_layer( + features, training=training) sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) - rnn_layer = tf.keras.layers.RNN(rnn_cell) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size, training=training) + rnn_layer = tf.keras.layers.RNN(rnn_cell, training=training) outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` """ @@ -112,11 +117,18 @@ class SequenceFeatures(fc._BaseFeaturesLayer): def _target_shape(self, input_shape, total_elements): return (input_shape[0], input_shape[1], total_elements) - def call(self, features): + def call(self, features, training=None): """Returns sequence input corresponding to the `feature_columns`. Args: features: A dict mapping keys to tensors. + training: Python boolean or None, indicating whether to the layer is being + run in training mode. This argument is passed to the call method of any + `FeatureColumn` that takes a `training` argument. For example, if a + `FeatureColumn` performed dropout, the column could expose a `training` + argument to control whether the dropout should be applied. If `None`, + defaults to `tf.keras.backend.learning_phase()`. + Returns: An `(input_layer, sequence_length)` tuple where: @@ -133,14 +145,20 @@ class SequenceFeatures(fc._BaseFeaturesLayer): if not isinstance(features, dict): raise ValueError('We expected a dictionary here. Instead we got: ', features) + if training is None: + training = backend.learning_phase() transformation_cache = fc.FeatureTransformationCache(features) output_tensors = [] sequence_lengths = [] for column in self._feature_columns: with ops.name_scope(column.name): - dense_tensor, sequence_length = column.get_sequence_dense_tensor( - transformation_cache, self._state_manager) + try: + dense_tensor, sequence_length = column.get_sequence_dense_tensor( + transformation_cache, self._state_manager, training=training) + except TypeError: + dense_tensor, sequence_length = column.get_sequence_dense_tensor( + transformation_cache, self._state_manager) # Flattens the final dimension to produce a 3D Tensor. 
output_tensors.append(self._process_dense_tensor(column, dense_tensor)) sequence_lengths.append(sequence_length) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt index ef1f523ecab..0b94554a7bb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt @@ -146,7 +146,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'features\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'features\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt index 260fa5cfcd6..8550078d41e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt @@ -146,7 +146,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'features\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'features\', \'cols_to_output_tensors\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt index ef1f523ecab..0b94554a7bb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt @@ -146,7 +146,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'features\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'features\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt index 380304c7777..f010cd09dfb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt @@ -147,7 +147,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'features\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'features\', \'cols_to_output_tensors\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compute_mask" From 9df9b1a1f332a5c4654050834d4ddddf34209a63 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 12:49:05 -0700 Subject: [PATCH 339/492] Use static_cast to convert bfloat16 to float instead of relying on implicit casting in argument to float ctor. 
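For reference, the two spellings side by side (variable names are
illustrative; the conversion behavior is as described above):

  tensorflow::bfloat16 b = tensorflow::bfloat16::round_to_bfloat16(0.5f);
  float f_old = float{b};               // list-initialization, relies on the implicit bfloat16 -> float conversion
  float f_new = static_cast<float>(b);  // explicit conversion, the form the updated test assertions use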
PiperOrigin-RevId: 302081696
Change-Id: I089523d7e55cf8ede4297393b701377ebf6b3e73
---
 tensorflow/core/framework/bfloat16_test.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index 8e780251498..db8590fef58 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -36,14 +36,14 @@ TEST(Bfloat16Test, FlushDenormalsToZero) {
        denorm < std::numeric_limits<float>::denorm_min();
        denorm = std::nextafterf(denorm, 1.0f)) {
     bfloat16 bf_trunc = bfloat16::truncate_to_bfloat16(denorm);
-    ASSERT_EQ(float{bf_trunc}, 0.0f);
+    ASSERT_EQ(static_cast<float>(bf_trunc), 0.0f);
     if (std::signbit(denorm)) {
       ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
     } else {
       ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
     }
     bfloat16 bf_round = bfloat16::round_to_bfloat16(denorm);
-    ASSERT_EQ(float{bf_round}, 0.0f);
+    ASSERT_EQ(static_cast<float>(bf_round), 0.0f);
     if (std::signbit(denorm)) {
       ASSERT_EQ(bf_round.value, 0x8000) << denorm;
     } else {

From 38ca061ed3c2c6e9e32dec0654c77ca890a1d734 Mon Sep 17 00:00:00 2001
From: Bixia Zheng
Date: Fri, 20 Mar 2020 12:58:57 -0700
Subject: [PATCH 340/492] Fix a check in tfcompile codegen.

Previously, tfcompile expected the compiler to translate each resource
variable into a function argument. The MLIR bridge does not generate a
function argument for an unused resource variable.

Modify an existing test to cover this situation.

PiperOrigin-RevId: 302083482
Change-Id: I08301e594422f655b8d4ba4bb66d69103764cc7f
---
 tensorflow/compiler/aot/codegen.cc                            | 4 +++-
 tensorflow/compiler/aot/tests/make_test_graphs.py             | 1 +
 .../aot/tests/test_graph_tfvariable_readonly.config.pbtxt     | 8 ++++++++
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 4a4fec5a386..c9a36b88795 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -170,7 +170,9 @@ Status GenArgMethods(const tf2xla::Config& config,
                      const xla::ProgramShapeProto& ps,
                      const CompileResult& compile_result, string* methods) {
   size_t num_args = ps.parameters_size();
-  if (config.feed_size() + config.variable_size() != num_args) {
+  // feed_size() + variable_size() is the maximum number of args as an
+  // implementation may not create an argument for an unused variable.
+ if (config.feed_size() + config.variable_size() < num_args) { return errors::InvalidArgument( "mismatch between feed_size(", config.feed_size(), ")+variable_size(", config.variable_size(), ") and num_args(", num_args, ")"); diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 629239d6e4a..532d64c5a3e 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -157,6 +157,7 @@ def tftop_k(_): def tfvariable_readonly(_): x = variables.Variable(1000.0, name='x') + unused_y = variables.Variable(1000.0, name='y') old_x = x.value() with ops.control_dependencies([old_x]): new_value = math_ops.add(old_x, 42.0) diff --git a/tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.config.pbtxt index b615b8f1522..dd2d0399451 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.config.pbtxt @@ -10,3 +10,11 @@ variable { type: DT_FLOAT readonly: true } + +variable { + node_name: "y" + shape { + } + type: DT_FLOAT + readonly: true +} From f023e26d903eef3f9814a5b5bd93d2e5ed77ede9 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Fri, 20 Mar 2020 13:11:59 -0700 Subject: [PATCH 341/492] [tf.data service] Add test utils for testing against dataset GraphDefs. PiperOrigin-RevId: 302086003 Change-Id: I9e8ffcabe7d3fab87559deaceb2d795bed336585 --- tensorflow/core/data/BUILD | 2 +- tensorflow/core/data/service/BUILD | 30 +++ tensorflow/core/data/service/test_util.cc | 58 +++++ tensorflow/core/data/service/test_util.h | 44 ++++ .../core/data/service/test_util_test.cc | 57 +++++ .../data/service/testdata/map_graph_def.pbtxt | 225 ++++++++++++++++++ 6 files changed, 415 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/data/service/test_util.cc create mode 100644 tensorflow/core/data/service/test_util.h create mode 100644 tensorflow/core/data/service/test_util_test.cc create mode 100644 tensorflow/core/data/service/testdata/map_graph_def.pbtxt diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 5170bb27498..9c58be108fc 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -15,6 +15,7 @@ cc_library( srcs = ["standalone.cc"], hdrs = ["standalone.h"], deps = [ + "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -28,7 +29,6 @@ tf_cc_test( srcs = ["standalone_test.cc"], deps = [ ":standalone", - "//tensorflow/core:all_kernels", "//tensorflow/core:test", "//tensorflow/core:test_main", ] + tf_protos_all(), diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index d23791e510a..b597fd70add 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -101,6 +101,36 @@ tf_cc_test( ], ) +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = [ + "test_util.h", + ], + data = glob(["testdata/*.pbtxt"]), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/framework:protos_all_cc", + "//tensorflow/core/kernels/data:dataset_test_base", + ], +) + +tf_cc_test( + name = "test_util_test", + srcs = ["test_util_test.cc"], + deps = [ + ":test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + 
"//tensorflow/core:testlib", + "//tensorflow/core/data:standalone", + "//tensorflow/core/kernels/data:dataset_test_base", + ], +) + cc_grpc_library( name = "master_cc_grpc_proto", srcs = [":master_proto"], diff --git a/tensorflow/core/data/service/test_util.cc b/tensorflow/core/data/service/test_util.cc new file mode 100644 index 00000000000..1c8c3c21827 --- /dev/null +++ b/tensorflow/core/data/service/test_util.cc @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/test_util.h" + +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" + +namespace tensorflow { +namespace data { +namespace test_util { + +namespace { +constexpr char kTestdataDir[] = + "tensorflow/core/data/service/testdata"; + +// Proto content generated by +// +// import tensorflow.compat.v2 as tf +// tf.enable_v2_behavior() +// +// ds = tf.data.Dataset.range(10) +// ds = ds.map(lambda x: x*x) +// g = tf.compat.v1.GraphDef() +// g.ParseFromString(ds._as_serialized_graph().numpy()) +// print(g) +constexpr char kMapGraphDefFile[] = "map_graph_def.pbtxt"; +} // namespace + +Status map_test_case(GraphDefTestCase* test_case) { + std::string filepath = io::JoinPath(kTestdataDir, kMapGraphDefFile); + GraphDef graph_def; + TF_RETURN_IF_ERROR(ReadTextProto(Env::Default(), filepath, &graph_def)); + int num_elements = 10; + std::vector> outputs(num_elements); + for (int i = 0; i < num_elements; ++i) { + outputs[i] = CreateTensors(TensorShape{}, {{i * i}}); + } + *test_case = {"MapGraph", graph_def, outputs}; + return Status::OK(); +} + +} // namespace test_util +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/test_util.h b/tensorflow/core/data/service/test_util.h new file mode 100644 index 00000000000..a6b4514dd01 --- /dev/null +++ b/tensorflow/core/data/service/test_util.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { +namespace test_util { + +struct GraphDefTestCase { + // Name for the test case. + string name; + // A dataset graph. + GraphDef graph_def; + // The expected output from iterating over the dataset represented by the + // graph. + std::vector> output; +}; + +// Fills in the input test_case pointer with test case data representing the +// dataset tf.data.Dataset.range(10).map(lambda x: x*x). Useful for testing +// dataset graph execution. +Status map_test_case(GraphDefTestCase* test_case); + +} // namespace test_util +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ diff --git a/tensorflow/core/data/service/test_util_test.cc b/tensorflow/core/data/service/test_util_test.cc new file mode 100644 index 00000000000..1bd5ab66afa --- /dev/null +++ b/tensorflow/core/data/service/test_util_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/data/service/test_util.h" + +#include "tensorflow/core/data/standalone.h" +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace test_util { + +TEST(TestUtil, MapTestCase) { + GraphDefTestCase test_case; + TF_ASSERT_OK(map_test_case(&test_case)); + standalone::Dataset::Params params; + std::unique_ptr dataset; + TF_ASSERT_OK( + standalone::Dataset::FromGraph(params, test_case.graph_def, &dataset)); + + std::unique_ptr iterator; + TF_ASSERT_OK(dataset->MakeIterator(&iterator)); + + bool end_of_input = false; + + std::vector> result; + while (!end_of_input) { + std::vector outputs; + TF_ASSERT_OK(iterator->GetNext(&outputs, &end_of_input)); + if (!end_of_input) { + result.push_back(outputs); + } + } + ASSERT_EQ(result.size(), test_case.output.size()); + for (int i = 0; i < result.size(); ++i) { + TF_EXPECT_OK(DatasetOpsTestBase::ExpectEqual(result[i], test_case.output[i], + /*compare_order=*/true)); + } +} + +} // namespace test_util +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/testdata/map_graph_def.pbtxt b/tensorflow/core/data/service/testdata/map_graph_def.pbtxt new file mode 100644 index 00000000000..6bd813febd4 --- /dev/null +++ b/tensorflow/core/data/service/testdata/map_graph_def.pbtxt @@ -0,0 +1,225 @@ +node { + name: "Const/_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } +} +node { + name: "Const/_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 10 + } + } + } +} +node { + name: "Const/_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } +} +node { + name: "RangeDataset/_3" + op: "RangeDataset" + input: "Const/_0" + input: "Const/_1" + input: "Const/_2" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } +} +node { + name: "MapDataset/_4" + op: "MapDataset" + input: "RangeDataset/_3" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "__inference_Dataset_map_lambda_9" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "preserve_cardinality" + value { + b: true + } + } + attr { + key: "use_inter_op_parallelism" + value { + b: true + } + } +} +node { + name: "dataset" + op: "_Retval" + input: "MapDataset/_4" + attr { + key: "T" + value { + type: DT_VARIANT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { + function { + signature { + name: "__inference_Dataset_map_lambda_9" + input_arg { + name: "args_0" + type: DT_INT64 + } + output_arg { + name: "identity" + type: DT_INT64 + } + } + node_def { + name: "mul" + op: "Mul" + input: "args_0" + input: "args_0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + 
experimental_debug_info { + original_node_names: "mul" + } + } + node_def { + name: "Identity" + op: "Identity" + input: "mul:z:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + experimental_debug_info { + original_node_names: "Identity" + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "args_0" + } + } + } + } + } +} +versions { + producer: 341 + min_consumer: 12 +} From 16b5232a803abcb55a30fa8326edbe279aca2a1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 13:14:36 -0700 Subject: [PATCH 342/492] Qualify uses of std::string PiperOrigin-RevId: 302086438 Change-Id: Idf09a24e44bda8f986f437f3dab4181042969796 --- tensorflow/core/platform/env.h | 34 +++++++++--------- tensorflow/core/platform/errors.h | 27 +++++++------- tensorflow/core/platform/file_system.h | 17 ++++----- tensorflow/core/platform/logging.h | 2 +- tensorflow/core/platform/numbers.h | 10 +++--- tensorflow/core/platform/protobuf.h | 14 +++++--- tensorflow/core/platform/status.h | 8 ++--- tensorflow/core/platform/str_util.h | 22 ++++++------ tensorflow/core/platform/strcat.h | 36 +++++++++---------- tensorflow/core/platform/tensor_coding.h | 16 ++++----- tensorflow/stream_executor/cuda/cuda_blas.cc | 4 +-- .../stream_executor/cuda/cuda_diagnostics.cc | 28 +++++++-------- .../stream_executor/cuda/cuda_diagnostics.h | 6 ++-- tensorflow/stream_executor/cuda/cuda_dnn.cc | 4 +-- .../stream_executor/cuda/cuda_driver.cc | 19 +++++----- .../stream_executor/cuda/cuda_gpu_executor.cc | 35 +++++++++--------- .../stream_executor/cuda/cuda_platform.cc | 2 +- .../stream_executor/cuda/cuda_platform.h | 4 +-- .../stream_executor/gpu/asm_compiler.cc | 24 ++++++------- .../stream_executor/gpu/gpu_diagnostics.h | 4 +-- tensorflow/stream_executor/gpu/gpu_driver.h | 6 ++-- tensorflow/stream_executor/gpu/gpu_executor.h | 8 ++--- tensorflow/stream_executor/gpu/gpu_rng.h | 10 +++--- .../stream_executor/gpu/redzone_allocator.h | 2 +- 24 files changed, 177 insertions(+), 165 deletions(-) diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index ab1598b22f1..7b617c0231f 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -173,7 +173,8 @@ class Env { /// \brief Returns true if the path matches the given pattern. The wildcards /// allowed in pattern are described in FileSystem::GetMatchingPaths. - virtual bool MatchPath(const string& path, const string& pattern) = 0; + virtual bool MatchPath(const std::string& path, + const std::string& pattern) = 0; /// \brief Given a pattern, stores in *results the set of paths that matches /// that pattern. *results is cleared. @@ -264,18 +265,18 @@ class Env { /// \brief Returns the absolute path of the current executable. It resolves /// symlinks if there is any. - string GetExecutablePath(); + std::string GetExecutablePath(); /// Creates a local unique temporary file name. Returns true if success. - bool LocalTempFilename(string* filename); + bool LocalTempFilename(std::string* filename); /// Creates a local unique file name that starts with |prefix| and ends with /// |suffix|. Returns true if success. - bool CreateUniqueFileName(string* prefix, const string& suffix); + bool CreateUniqueFileName(std::string* prefix, const std::string& suffix); /// \brief Return the runfiles directory if running under bazel. 
Returns /// the directory the executable is located in if not running under bazel. - virtual string GetRunfilesDir() = 0; + virtual std::string GetRunfilesDir() = 0; // TODO(jeff,sanjay): Add back thread/thread-pool support if needed. // TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or @@ -299,7 +300,7 @@ class Env { /// Caller takes ownership of the result and must delete it eventually /// (the deletion will block until fn() stops running). virtual Thread* StartThread(const ThreadOptions& thread_options, - const string& name, + const std::string& name, std::function fn) TF_MUST_USE_RESULT = 0; // Returns the thread id of calling thread. @@ -309,7 +310,7 @@ class Env { virtual int32 GetCurrentThreadId() = 0; // Copies current thread name to "name". Returns true if success. - virtual bool GetCurrentThreadName(string* name) = 0; + virtual bool GetCurrentThreadName(std::string* name) = 0; // \brief Schedules the given closure on a thread-pool. // @@ -349,8 +350,8 @@ class Env { // "name" should be name of the library. // "version" should be the version of the library or NULL // returns the name that LoadLibrary() can use - virtual string FormatLibraryFileName(const string& name, - const string& version) = 0; + virtual std::string FormatLibraryFileName(const std::string& name, + const std::string& version) = 0; // Returns a possible list of local temporary directories. virtual void GetLocalTempDirectories(std::vector* list) = 0; @@ -387,7 +388,7 @@ class EnvWrapper : public Env { return target_->RegisterFileSystem(scheme, factory); } - bool MatchPath(const string& path, const string& pattern) override { + bool MatchPath(const std::string& path, const std::string& pattern) override { return target_->MatchPath(path, pattern); } @@ -395,12 +396,13 @@ class EnvWrapper : public Env { void SleepForMicroseconds(int64 micros) override { target_->SleepForMicroseconds(micros); } - Thread* StartThread(const ThreadOptions& thread_options, const string& name, + Thread* StartThread(const ThreadOptions& thread_options, + const std::string& name, std::function fn) override { return target_->StartThread(thread_options, name, fn); } int32 GetCurrentThreadId() override { return target_->GetCurrentThreadId(); } - bool GetCurrentThreadName(string* name) override { + bool GetCurrentThreadName(std::string* name) override { return target_->GetCurrentThreadName(name); } void SchedClosure(std::function closure) override { @@ -416,12 +418,12 @@ class EnvWrapper : public Env { void** symbol) override { return target_->GetSymbolFromLibrary(handle, symbol_name, symbol); } - string FormatLibraryFileName(const string& name, - const string& version) override { + std::string FormatLibraryFileName(const std::string& name, + const std::string& version) override { return target_->FormatLibraryFileName(name, version); } - string GetRunfilesDir() override { return target_->GetRunfilesDir(); } + std::string GetRunfilesDir() override { return target_->GetRunfilesDir(); } private: void GetLocalTempDirectories(std::vector* list) override { @@ -520,7 +522,7 @@ namespace register_file_system { template struct Register { - Register(Env* env, const string& scheme) { + Register(Env* env, const std::string& scheme) { // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! 
env->RegisterFileSystem(scheme, []() -> FileSystem* { return new Factory; }) .IgnoreError(); diff --git a/tensorflow/core/platform/errors.h b/tensorflow/core/platform/errors.h index 3250a2f762f..3f1ff477655 100644 --- a/tensorflow/core/platform/errors.h +++ b/tensorflow/core/platform/errors.h @@ -45,7 +45,7 @@ namespace internal { // able to completely remove PrepareForStrCat(). template typename std::enable_if::value, - string>::type + std::string>::type PrepareForStrCat(const T& t) { std::stringstream ss; ss << t; @@ -126,29 +126,32 @@ DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED) // Note: The pattern below determines the regex _NODEDEF_NAME_RE in the file // tensorflow/python/client/session.py // LINT.IfChange -inline string FormatNodeNameForError(const string& name) { +inline std::string FormatNodeNameForError(const std::string& name) { return strings::StrCat("{{node ", name, "}}"); } // LINT.ThenChange(//tensorflow/python/client/session.py) template -string FormatNodeNamesForError(const T& names) { - return absl::StrJoin(names, ", ", [](string* output, const string& s) { - ::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s)); - }); +std::string FormatNodeNamesForError(const T& names) { + return absl::StrJoin( + names, ", ", [](std::string* output, const std::string& s) { + ::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s)); + }); } // LINT.IfChange -inline string FormatColocationNodeForError(const string& name) { +inline std::string FormatColocationNodeForError(const std::string& name) { return strings::StrCat("{{colocation_node ", name, "}}"); } // LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py) template -string FormatColocationNodeForError(const T& names) { - return absl::StrJoin(names, ", ", [](string* output, const string& s) { - ::tensorflow::strings::StrAppend(output, FormatColocationNodeForError(s)); - }); +std::string FormatColocationNodeForError(const T& names) { + return absl::StrJoin(names, ", ", + [](std::string* output, const std::string& s) { + ::tensorflow::strings::StrAppend( + output, FormatColocationNodeForError(s)); + }); } -inline string FormatFunctionForError(const string& name) { +inline std::string FormatFunctionForError(const std::string& name) { return strings::StrCat("{{function_node ", name, "}}"); } diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index 3ab9618e371..640d3b3c027 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -151,7 +151,7 @@ class FileSystem { /// This function provides the equivalent of posix fnmatch, however it is /// implemented without fnmatch to ensure that this can be used for cloud /// filesystems on windows. For windows filesystems, it uses PathMatchSpec. - virtual bool Match(const string& filename, const string& pattern); + virtual bool Match(const std::string& filename, const std::string& pattern); /// \brief Obtains statistics for the given path. virtual tensorflow::Status Stat(const string& fname, @@ -225,7 +225,7 @@ class FileSystem { /// invoke any system calls (getcwd(2)) in order to resolve relative /// paths with respect to the actual working directory. That is, this is /// purely string manipulation, completely independent of process state. - virtual string TranslateName(const string& name) const; + virtual std::string TranslateName(const std::string& name) const; /// \brief Returns whether the given path is a directory or not. 
/// @@ -288,16 +288,16 @@ class FileSystem { /// invoke any system calls (getcwd(2)) in order to resolve relative /// paths with respect to the actual working directory. That is, this is /// purely string manipulation, completely independent of process state. - string CleanPath(StringPiece path) const; + std::string CleanPath(StringPiece path) const; /// \brief Creates a URI from a scheme, host, and path. /// /// If the scheme is empty, we just return the path. - string CreateURI(StringPiece scheme, StringPiece host, - StringPiece path) const; + std::string CreateURI(StringPiece scheme, StringPiece host, + StringPiece path) const; /// \brief Creates a temporary file name with an extension. - string GetTempFilename(const string& extension) const; + std::string GetTempFilename(const std::string& extension) const; /// \brief Return true if path is absolute. bool IsAbsolutePath(tensorflow::StringPiece path) const; @@ -319,12 +319,13 @@ class FileSystem { /// string path = io::JoinPath(FLAGS_test_srcdir, filename); /// string path = io::JoinPath("/full", "path", "to", "filename"); template - string JoinPath(const T&... args) { + std::string JoinPath(const T&... args) { return JoinPathImpl({args...}); } #endif /* SWIG */ - string JoinPathImpl(std::initializer_list paths); + std::string JoinPathImpl( + std::initializer_list paths); /// \brief Populates the scheme, host, and path from a URI. /// diff --git a/tensorflow/core/platform/logging.h b/tensorflow/core/platform/logging.h index 6fa50d3caa2..20f9cd9bdbf 100644 --- a/tensorflow/core/platform/logging.h +++ b/tensorflow/core/platform/logging.h @@ -33,7 +33,7 @@ namespace internal { // Emit "message" as a log message to the log for the specified // "severity" as if it came from a LOG call at "fname:line" void LogString(const char* fname, int line, int severity, - const string& message); + const std::string& message); } // namespace internal } // namespace tensorflow diff --git a/tensorflow/core/platform/numbers.h b/tensorflow/core/platform/numbers.h index 9d16dc554fa..880e786e8b8 100644 --- a/tensorflow/core/platform/numbers.h +++ b/tensorflow/core/platform/numbers.h @@ -74,12 +74,12 @@ size_t DoubleToBuffer(double value, char* buffer); size_t FloatToBuffer(float value, char* buffer); // Convert a 64-bit fingerprint value to an ASCII representation. -string FpToString(Fprint fp); +std::string FpToString(Fprint fp); // Attempt to parse a fingerprint in the form encoded by FpToString. If // successful, stores the fingerprint in *fp and returns true. Otherwise, // returns false. -bool StringToFp(const string& s, Fprint* fp); +bool StringToFp(const std::string& s, Fprint* fp); // Convert a 64-bit fingerprint value to an ASCII representation that // is terminated by a '\0'. @@ -157,12 +157,12 @@ bool SafeStringToNumeric(StringPiece s, T* value) { // Converts from an int64 to a human readable string representing the // same number, using decimal powers. e.g. 1200000 -> "1.20M". -string HumanReadableNum(int64 value); +std::string HumanReadableNum(int64 value); // Converts from an int64 representing a number of bytes to a // human readable string representing the same number. // e.g. 12345678 -> "11.77MiB". -string HumanReadableNumBytes(int64 num_bytes); +std::string HumanReadableNumBytes(int64 num_bytes); // Converts a time interval as double to a human readable // string. 
For example: @@ -171,7 +171,7 @@ string HumanReadableNumBytes(int64 num_bytes); // 933120.0 -> "10.8 days" // 39420000.0 -> "1.25 years" // -10 -> "-10 s" -string HumanReadableElapsedTime(double seconds); +std::string HumanReadableElapsedTime(double seconds); } // namespace strings } // namespace tensorflow diff --git a/tensorflow/core/platform/protobuf.h b/tensorflow/core/platform/protobuf.h index 8c972d3aee0..2422aacd5f6 100644 --- a/tensorflow/core/platform/protobuf.h +++ b/tensorflow/core/platform/protobuf.h @@ -55,7 +55,7 @@ extern const char* kProtobufUint64Typename; // this function has no size restrictions on the total size of the encoded // protocol buffer. bool ParseProtoUnlimited(protobuf::MessageLite* proto, - const string& serialized); + const std::string& serialized); bool ParseProtoUnlimited(protobuf::MessageLite* proto, const void* serialized, size_t size); inline bool ParseProtoUnlimited(protobuf::MessageLite* proto, @@ -64,11 +64,13 @@ inline bool ParseProtoUnlimited(protobuf::MessageLite* proto, } // Returns the string value for the value of a string or bytes protobuf field. -inline const string& ProtobufStringToString(const string& s) { return s; } +inline const std::string& ProtobufStringToString(const std::string& s) { + return s; +} // Set to . Swapping is allowed, as does not need to be // preserved. -inline void SetProtobufStringSwapAllowed(string* src, string* dest) { +inline void SetProtobufStringSwapAllowed(std::string* src, std::string* dest) { *dest = std::move(*src); } @@ -77,8 +79,10 @@ inline void SetProtobufStringSwapAllowed(string* src, string* dest) { // tools/proto_text's generated code. They have the same name as the versions // in core/platform/protobuf.h, so the generation code doesn't need to determine // if the type is Cord or string at generation time. -inline string ProtobufStringToString(const Cord& s) { return s.ToString(); } -inline void SetProtobufStringSwapAllowed(string* src, Cord* dest) { +inline std::string ProtobufStringToString(const Cord& s) { + return s.ToString(); +} +inline void SetProtobufStringSwapAllowed(std::string* src, Cord* dest) { dest->CopyFrom(*src); } #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) diff --git a/tensorflow/core/platform/status.h b/tensorflow/core/platform/status.h index b9763a1dc71..c3ce61d37bb 100644 --- a/tensorflow/core/platform/status.h +++ b/tensorflow/core/platform/status.h @@ -62,7 +62,7 @@ class Status { return ok() ? tensorflow::error::OK : state_->code; } - const string& error_message() const { + const std::string& error_message() const { return ok() ? empty_string() : state_->msg; } @@ -82,7 +82,7 @@ class Status { /// \brief Return a string representation of this status suitable for /// printing. Returns the string `"OK"` for success. - string ToString() const; + std::string ToString() const; // Ignores any errors. This method does nothing except potentially suppress // complaints from any tools that are checking that errors are not dropped on @@ -90,10 +90,10 @@ class Status { void IgnoreError() const; private: - static const string& empty_string(); + static const std::string& empty_string(); struct State { tensorflow::error::Code code; - string msg; + std::string msg; }; // OK status has a `NULL` state_. 
Otherwise, `state_` points to // a `State` structure containing the error code and message(s) diff --git a/tensorflow/core/platform/str_util.h b/tensorflow/core/platform/str_util.h index 84ced441199..56d020f52e2 100644 --- a/tensorflow/core/platform/str_util.h +++ b/tensorflow/core/platform/str_util.h @@ -31,7 +31,7 @@ namespace str_util { // Returns a version of 'src' where unprintable characters have been // escaped using C-style escape sequences. -string CEscape(StringPiece src); +std::string CEscape(StringPiece src); // Copies "source" to "dest", rewriting C-style escape sequences -- // '\n', '\r', '\\', '\ooo', etc -- to their ASCII equivalents. @@ -40,10 +40,10 @@ string CEscape(StringPiece src); // 'error'. To disable error reporting, set 'error' to NULL. // // NOTE: Does not support \u or \U! -bool CUnescape(StringPiece source, string* dest, string* error); +bool CUnescape(StringPiece source, std::string* dest, std::string* error); // Removes any trailing whitespace from "*s". -void StripTrailingWhitespace(string* s); +void StripTrailingWhitespace(std::string* s); // Removes leading ascii_isspace() characters. // Returns number of characters removed. @@ -87,23 +87,23 @@ TF_MUST_USE_RESULT StringPiece StripPrefix(StringPiece s, StringPiece expected); TF_MUST_USE_RESULT StringPiece StripSuffix(StringPiece s, StringPiece expected); // Return lower-cased version of s. -string Lowercase(StringPiece s); +std::string Lowercase(StringPiece s); // Return upper-cased version of s. -string Uppercase(StringPiece s); +std::string Uppercase(StringPiece s); // Capitalize first character of each word in "*s". "delimiters" is a // set of characters that can be used as word boundaries. -void TitlecaseString(string* s, StringPiece delimiters); +void TitlecaseString(std::string* s, StringPiece delimiters); // Replaces the first occurrence (if replace_all is false) or all occurrences // (if replace_all is true) of oldsub in s with newsub. -string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub, - bool replace_all); +std::string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub, + bool replace_all); // Join functionality template -string Join(const T& s, const char* sep) { +std::string Join(const T& s, const char* sep) { return absl::StrJoin(s, sep); } @@ -111,7 +111,7 @@ string Join(const T& s, const char* sep) { // is invoked (f is often constructed with a lambda of the form: // [](string* result, ElemType elem) template -string Join(const T& s, const char* sep, Formatter f) { +std::string Join(const T& s, const char* sep, Formatter f) { return absl::StrJoin(s, sep, f); } @@ -179,7 +179,7 @@ size_t Strnlen(const char* str, const size_t string_max_len); // This method is useful for producing strings matching "[a-z][a-z0-9_]*" // as required by OpDef.ArgDef.name. The resulting string is either empty or // matches this regex. -string ArgDefCase(StringPiece s); +std::string ArgDefCase(StringPiece s); } // namespace str_util } // namespace tensorflow diff --git a/tensorflow/core/platform/strcat.h b/tensorflow/core/platform/strcat.h index 640355c9ea5..3569a86ab33 100644 --- a/tensorflow/core/platform/strcat.h +++ b/tensorflow/core/platform/strcat.h @@ -168,30 +168,30 @@ class AlphaNum { // ---------------------------------------------------------------------- // For performance reasons, we have specializations for <= 4 args. 
-string StrCat(const AlphaNum &a) TF_MUST_USE_RESULT; -string StrCat(const AlphaNum &a, const AlphaNum &b) TF_MUST_USE_RESULT; -string StrCat(const AlphaNum &a, const AlphaNum &b, - const AlphaNum &c) TF_MUST_USE_RESULT; -string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, - const AlphaNum &d) TF_MUST_USE_RESULT; +std::string StrCat(const AlphaNum &a) TF_MUST_USE_RESULT; +std::string StrCat(const AlphaNum &a, const AlphaNum &b) TF_MUST_USE_RESULT; +std::string StrCat(const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c) TF_MUST_USE_RESULT; +std::string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d) TF_MUST_USE_RESULT; namespace internal { // Do not call directly - this is not part of the public API. -string CatPieces(std::initializer_list pieces); -void AppendPieces(string *dest, std::initializer_list pieces); +std::string CatPieces(std::initializer_list pieces); +void AppendPieces(std::string *dest, std::initializer_list pieces); } // namespace internal // Support 5 or more arguments template -string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, - const AlphaNum &d, const AlphaNum &e, - const AV &... args) TF_MUST_USE_RESULT; +std::string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d, const AlphaNum &e, + const AV &... args) TF_MUST_USE_RESULT; template -string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, - const AlphaNum &d, const AlphaNum &e, const AV &... args) { +std::string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d, const AlphaNum &e, const AV &... args) { return internal::CatPieces({a.Piece(), b.Piece(), c.Piece(), d.Piece(), e.Piece(), static_cast(args).Piece()...}); @@ -218,16 +218,16 @@ string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, // worked around as consecutive calls to StrAppend are quite efficient. // ---------------------------------------------------------------------- -void StrAppend(string *dest, const AlphaNum &a); -void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b); -void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, +void StrAppend(std::string *dest, const AlphaNum &a); +void StrAppend(std::string *dest, const AlphaNum &a, const AlphaNum &b); +void StrAppend(std::string *dest, const AlphaNum &a, const AlphaNum &b, const AlphaNum &c); -void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, +void StrAppend(std::string *dest, const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, const AlphaNum &d); // Support 5 or more arguments template -inline void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, +inline void StrAppend(std::string *dest, const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, const AlphaNum &d, const AlphaNum &e, const AV &... args) { internal::AppendPieces(dest, diff --git a/tensorflow/core/platform/tensor_coding.h b/tensorflow/core/platform/tensor_coding.h index fcfa5469e18..010f9f11de7 100644 --- a/tensorflow/core/platform/tensor_coding.h +++ b/tensorflow/core/platform/tensor_coding.h @@ -31,31 +31,31 @@ namespace port { // Store src contents in *out. If backing memory for src is shared with *out, // will ref obj during the call and will arrange to unref obj when no // longer needed. -void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out); +void AssignRefCounted(StringPiece src, core::RefCounted* obj, std::string* out); // Copy contents of src to dst[0,src.size()-1]. 
-inline void CopyToArray(const string& src, char* dst) { +inline void CopyToArray(const std::string& src, char* dst) { memcpy(dst, src.data(), src.size()); } // Copy subrange [pos:(pos + n)) from src to dst. If pos >= src.size() the // result is empty. If pos + n > src.size() the subrange [pos, size()) is // copied. -inline void CopySubrangeToArray(const string& src, size_t pos, size_t n, +inline void CopySubrangeToArray(const std::string& src, size_t pos, size_t n, char* dst) { if (pos >= src.size()) return; memcpy(dst, src.data() + pos, std::min(n, src.size() - pos)); } // Store encoding of strings[0..n-1] in *out. -void EncodeStringList(const tstring* strings, int64 n, string* out); +void EncodeStringList(const tstring* strings, int64 n, std::string* out); // Decode n strings from src and store in strings[0..n-1]. // Returns true if successful, false on parse error. -bool DecodeStringList(const string& src, tstring* strings, int64 n); +bool DecodeStringList(const std::string& src, tstring* strings, int64 n); // Assigns base[0..bytes-1] to *s -void CopyFromArray(string* s, const char* base, size_t bytes); +void CopyFromArray(std::string* s, const char* base, size_t bytes); // Encodes sequences of strings and serialized protocol buffers into a string. // Normal usage consists of zero or more calls to Append() and a single call to @@ -68,7 +68,7 @@ class StringListEncoder { virtual void Append(const protobuf::MessageLite& m) = 0; // Encodes the given string. This may not be called after Finalize(). - virtual void Append(const string& s) = 0; + virtual void Append(const std::string& s) = 0; // Signals end of the encoding process. No other calls are allowed after this. virtual void Finalize() = 0; @@ -117,7 +117,7 @@ void EncodeStringList(const tstring* strings, int64 n, Cord* out); // Decode n strings from src and store in strings[0..n-1]. // Returns true if successful, false on parse error. -bool DecodeStringList(const Cord& src, string* strings, int64 n); +bool DecodeStringList(const Cord& src, std::string* strings, int64 n); bool DecodeStringList(const Cord& src, tstring* strings, int64 n); // Assigns base[0..bytes-1] to *c diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 94ddaec03ac..8fb8d032198 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -74,7 +74,7 @@ namespace gpu { PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin); -static string ToString(cublasStatus_t status) { +static std::string ToString(cublasStatus_t status) { switch (status) { case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; @@ -2803,7 +2803,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, GpuComplex(GpuMemoryMutable(b)), ldb); } -port::Status CUDABlas::GetVersion(string *version) { +port::Status CUDABlas::GetVersion(std::string *version) { absl::MutexLock lock(&mu_); int v; diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index e8ff7caa679..e2923bad9fb 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -53,12 +53,12 @@ limitations under the License. 
namespace stream_executor { namespace cuda { -string DriverVersionToString(DriverVersion version) { +std::string DriverVersionToString(DriverVersion version) { return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version), std::get<2>(version)); } -string DriverVersionStatusToString(port::StatusOr version) { +std::string DriverVersionStatusToString(port::StatusOr version) { if (!version.ok()) { return version.status().ToString(); } @@ -66,8 +66,8 @@ string DriverVersionStatusToString(port::StatusOr version) { return DriverVersionToString(version.ValueOrDie()); } -port::StatusOr StringToDriverVersion(const string &value) { - std::vector pieces = absl::StrSplit(value, '.'); +port::StatusOr StringToDriverVersion(const std::string &value) { + std::vector pieces = absl::StrSplit(value, '.'); if (pieces.size() < 2 || pieces.size() > 4) { return port::Status( port::error::INVALID_ARGUMENT, @@ -122,7 +122,7 @@ static const char *kDriverVersionPath = "/proc/driver/nvidia/version"; // -- class Diagnostician -string Diagnostician::GetDevNodePath(int dev_node_ordinal) { +std::string Diagnostician::GetDevNodePath(int dev_node_ordinal) { return absl::StrCat("/dev/nvidia", dev_node_ordinal); } @@ -177,10 +177,10 @@ void Diagnostician::LogDiagnosticInformation() { #ifndef PLATFORM_WINDOWS if (VLOG_IS_ON(1)) { const char *value = getenv("LD_LIBRARY_PATH"); - string library_path = value == nullptr ? "" : value; + std::string library_path = value == nullptr ? "" : value; VLOG(1) << "LD_LIBRARY_PATH is: \"" << library_path << "\""; - std::vector pieces = absl::StrSplit(library_path, ':'); + std::vector pieces = absl::StrSplit(library_path, ':'); for (const auto &piece : pieces) { if (piece.empty()) { continue; @@ -264,11 +264,11 @@ port::StatusOr Diagnostician::FindDsoVersion() { if (dot == nullptr) { return 0; } - string dso_version = dot + strlen(so_suffix); + std::string dso_version = dot + strlen(so_suffix); // TODO(b/22689637): Eliminate the explicit namespace if possible. auto stripped_dso_version = absl::StripSuffix(dso_version, ".ld64"); auto result = static_cast *>(data); - *result = cuda::StringToDriverVersion(string(stripped_dso_version)); + *result = cuda::StringToDriverVersion(std::string(stripped_dso_version)); return 1; } return 0; @@ -282,10 +282,10 @@ port::StatusOr Diagnostician::FindDsoVersion() { } port::StatusOr Diagnostician::FindKernelModuleVersion( - const string &driver_version_file_contents) { + const std::string &driver_version_file_contents) { static const char *kDriverFilePrelude = "Kernel Module "; size_t offset = driver_version_file_contents.find(kDriverFilePrelude); - if (offset == string::npos) { + if (offset == std::string::npos) { return port::Status( port::error::NOT_FOUND, absl::StrCat("could not find kernel module information in " @@ -293,13 +293,13 @@ port::StatusOr Diagnostician::FindKernelModuleVersion( driver_version_file_contents, "\"")); } - string version_and_rest = driver_version_file_contents.substr( - offset + strlen(kDriverFilePrelude), string::npos); + std::string version_and_rest = driver_version_file_contents.substr( + offset + strlen(kDriverFilePrelude), std::string::npos); size_t space_index = version_and_rest.find(" "); auto kernel_version = version_and_rest.substr(0, space_index); // TODO(b/22689637): Eliminate the explicit namespace if possible. 
auto stripped_kernel_version = absl::StripSuffix(kernel_version, ".ld64"); - return cuda::StringToDriverVersion(string(stripped_kernel_version)); + return cuda::StringToDriverVersion(std::string(stripped_kernel_version)); } void Diagnostician::WarnOnDsoKernelMismatch( diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.h b/tensorflow/stream_executor/cuda/cuda_diagnostics.h index 0837e136fd4..d5c3194fe05 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.h +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.h @@ -25,13 +25,13 @@ namespace cuda { using DriverVersion = gpu::DriverVersion; // Converts a parsed driver version to string form. -string DriverVersionToString(DriverVersion version); +std::string DriverVersionToString(DriverVersion version); // Converts a parsed driver version or status value to natural string form. -string DriverVersionStatusToString(port::StatusOr version); +std::string DriverVersionStatusToString(port::StatusOr version); // Converts a string of a form like "331.79" to a DriverVersion{331, 79}. -port::StatusOr StringToDriverVersion(const string& value); +port::StatusOr StringToDriverVersion(const std::string& value); using Diagnostician = gpu::Diagnostician; diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index dd511f7a976..5fec1b5990e 100755 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -88,7 +88,7 @@ NarrowT CheckedNarrowing(const WideT& wide) { return narrow; } -string ToString(cudnnStatus_t status) { +std::string ToString(cudnnStatus_t status) { switch (status) { case CUDNN_STATUS_SUCCESS: return "CUDNN_STATUS_SUCCESS"; @@ -307,7 +307,7 @@ port::Status CudnnSupport::Init() { CudnnVersion loaded_version; TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version)); if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) { - const string error = absl::StrCat( + const std::string error = absl::StrCat( "Loaded runtime CuDNN library: ", loaded_version.ToString(), " but source was compiled with: ", source_version.ToString(), ". CuDNN library major and minor version needs to match or have " diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index c4c314b6b55..210c5436fad 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -127,7 +127,7 @@ class CreatedContexts { /* static */ int64 CreatedContexts::next_id_ = 1; // 0 means "no context" // Formats CUresult to output prettified values into a log stream. -string ToString(CUresult result) { +std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { return absl::StrCat("UNKNOWN ERROR (", static_cast(result), ")"); @@ -167,7 +167,7 @@ port::ThreadPool* GetDriverExecutor() { } // namespace -string MemorySpaceString(MemorySpace memory_space) { +std::string MemorySpaceString(MemorySpace memory_space) { switch (memory_space) { case MemorySpace::kHost: return "host"; @@ -252,7 +252,7 @@ namespace { // Returns a stringified device number associated with pointer, primarily for // logging purposes. Returns "?" if the device could not be successfully // queried. 
-string CUDAPointerToDeviceString(CUdeviceptr pointer) { +std::string CUDAPointerToDeviceString(CUdeviceptr pointer) { auto value = GpuDriver::GetPointerDevice(pointer); if (value.ok()) { return absl::StrCat(value.ValueOrDie()); @@ -264,7 +264,7 @@ string CUDAPointerToDeviceString(CUdeviceptr pointer) { // Returns a stringified memory space associated with pointer, primarily for // logging purposes. Returns "?" if the memory space could not be successfully // queried. -string CUDAPointerToMemorySpaceString(CUdeviceptr pointer) { +std::string CUDAPointerToMemorySpaceString(CUdeviceptr pointer) { auto value = GpuDriver::GetPointerMemorySpace(pointer); if (value.ok()) { return MemorySpaceString(value.ValueOrDie()); @@ -277,7 +277,7 @@ string CUDAPointerToMemorySpaceString(CUdeviceptr pointer) { // permitted between the "from" and "to" pointers' associated contexts, // primarily for logging purposes. Returns "error" if an error is encountered // in the process of querying. -string CUDAPointersToCanAccessString(CUdeviceptr from, CUdeviceptr to) { +std::string CUDAPointersToCanAccessString(CUdeviceptr from, CUdeviceptr to) { auto from_context = GpuDriver::GetPointerContext(from); if (!from_context.ok()) { LOG(ERROR) << "could not retrieve source pointer's context: " @@ -335,7 +335,7 @@ static port::Status InternalInit() { } /* static */ port::Status GpuDriver::GetDeviceName(CUdevice device, - string* device_name) { + std::string* device_name) { static const size_t kCharLimit = 64; absl::InlinedVector chars(kCharLimit); RETURN_IF_CUDA_RES_ERROR( @@ -434,7 +434,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, return port::Status::OK(); } - string message = "failed call to cuDevicePrimaryCtxRetain: " + ToString(res); + std::string message = + "failed call to cuDevicePrimaryCtxRetain: " + ToString(res); if (res == CUDA_ERROR_OUT_OF_MEMORY) { uint64 total_memory; if (GetDeviceTotalMemory(device, &total_memory)) { @@ -1391,8 +1392,8 @@ static port::StatusOr GetSimpleAttribute(CUdevice device, return true; } -/* static */ string GpuDriver::GetPCIBusID(CUdevice device) { - string pci_bus_id; +/* static */ std::string GpuDriver::GetPCIBusID(CUdevice device) { + std::string pci_bus_id; static const int kBufferSize = 64; absl::InlinedVector chars(kBufferSize); chars[kBufferSize - 1] = '\0'; diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 44bb359d6d0..79a027f1255 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -81,7 +81,7 @@ namespace gpu { // // As this is an implementation-detail workaround, the usage is to declare this // variable with extern linkage and populate it from another translation unit. 
-std::function g_cubinate; +std::function g_cubinate; static GpuEvent* AsGpuEvent(Event* event) { DCHECK(event != nullptr); @@ -152,12 +152,12 @@ port::Status GpuExecutor::Init(int device_ordinal, bool GpuExecutor::FindOnDiskForComputeCapability( absl::string_view filename, absl::string_view canonical_suffix, - string* found_filename) const { + std::string* found_filename) const { if (cc_major_ == 0 && cc_minor_ == 0) { return false; } - string cc_specific = + std::string cc_specific = absl::StrCat(filename, ".cc", cc_major_, cc_minor_, canonical_suffix); if (port::FileExists(cc_specific).ok()) { VLOG(2) << "found compute-capability-specific file, using that: " @@ -168,8 +168,8 @@ bool GpuExecutor::FindOnDiskForComputeCapability( VLOG(2) << "could not find compute-capability specific file at: " << cc_specific; - if (port::FileExists(string(filename)).ok()) { - *found_filename = string(filename); + if (port::FileExists(std::string(filename)).ok()) { + *found_filename = std::string(filename); return true; } @@ -178,7 +178,7 @@ bool GpuExecutor::FindOnDiskForComputeCapability( bool GpuExecutor::FindOnDiskForISAVersion(absl::string_view filename, absl::string_view canonical_suffix, - string* found_filename) const { + std::string* found_filename) const { LOG(ERROR) << "Feature not supported on CUDA platform (FindOnDiskForISAVersion)"; return false; @@ -188,7 +188,7 @@ bool GpuExecutor::FindOnDiskForISAVersion(absl::string_view filename, // Arg: strip_exe: if true, remove the name of the executable itself from the // returned string. Example: calling this from /usr/bin/foo // would return /usr/bin. -static string GetBinaryDir(bool strip_exe) { +static std::string GetBinaryDir(bool strip_exe) { char exe_path[PATH_MAX] = {0}; #if defined(__APPLE__) uint32_t buffer_size = 0U; @@ -209,8 +209,8 @@ static string GetBinaryDir(bool strip_exe) { if (strip_exe) { // The exe is the last component of the path, so remove one component. - string ret = exe_path; - std::vector components = absl::StrSplit(exe_path, '/'); + std::string ret = exe_path; + std::vector components = absl::StrSplit(exe_path, '/'); components.pop_back(); return absl::StrJoin(components, "/"); } @@ -264,7 +264,7 @@ port::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, KernelBase* kernel) { GpuKernel* cuda_kernel = AsGpuKernel(kernel); CUmodule module; - const string *kernelname; + const std::string* kernelname; VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); @@ -857,7 +857,7 @@ bool GpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const { return GpuDriver::GetDeviceMemoryInfo(context_, free, total); } -bool GpuExecutor::GetSymbol(const string& symbol_name, +bool GpuExecutor::GetSymbol(const std::string& symbol_name, ModuleHandle module_handle, void** mem, size_t* bytes) { auto lookup_in_module = [&](CUmodule module) { @@ -937,7 +937,8 @@ GpuContext* GpuExecutor::gpu_context() { return context_; } // // For anything more complicated/prod-focused than this, you'll likely want to // turn to gsys' topology modeling. 
-static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) { +static int TryToReadNumaNode(const std::string& pci_bus_id, + int device_ordinal) { #if defined(__APPLE__) LOG(INFO) << "OS X does not support NUMA - returning NUMA node zero"; return 0; @@ -956,7 +957,7 @@ static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) { return kUnknownNumaNode; } - string filename = + std::string filename = absl::StrFormat("/sys/bus/pci/devices/%s/numa_node", pci_bus_id); // We have to use fopen/fread here so that the device properties can be @@ -969,7 +970,7 @@ static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) { return kUnknownNumaNode; } - string content; + std::string content; char buf[32]; size_t did_read = fread(buf, sizeof(buf[0]), sizeof(buf) - 1, file); buf[did_read] = '\0'; @@ -1017,14 +1018,14 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { { int driver_version = 0; (void)GpuDriver::GetDriverVersion(&driver_version); - string augmented_driver_version = absl::StrFormat( + std::string augmented_driver_version = absl::StrFormat( "%d (%s)", driver_version, cuda::DriverVersionStatusToString(Diagnostician::FindDsoVersion())); builder.set_driver_version(augmented_driver_version); } { - string pci_bus_id = GpuDriver::GetPCIBusID(device); + std::string pci_bus_id = GpuDriver::GetPCIBusID(device); // Lower the hex characters to match sysfs. pci_bus_id = absl::AsciiStrToLower(pci_bus_id); @@ -1090,7 +1091,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { } { - string device_name; + std::string device_name; TF_RETURN_IF_ERROR(GpuDriver::GetDeviceName(device, &device_name)); builder.set_name(device_name); } diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc index 92170b30129..c9474b7caae 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.cc +++ b/tensorflow/stream_executor/cuda/cuda_platform.cc @@ -139,7 +139,7 @@ int CudaPlatform::VisibleDeviceCount() const { return GpuDriver::GetDeviceCount(); } -const string& CudaPlatform::Name() const { return name_; } +const std::string& CudaPlatform::Name() const { return name_; } port::StatusOr> CudaPlatform::DescriptionForDevice(int ordinal) const { diff --git a/tensorflow/stream_executor/cuda/cuda_platform.h b/tensorflow/stream_executor/cuda/cuda_platform.h index 2d06ece9076..7d1c67c20be 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.h +++ b/tensorflow/stream_executor/cuda/cuda_platform.h @@ -62,7 +62,7 @@ class CudaPlatform : public Platform { // Returns -1 as a sentinel on internal failure (and logs the error). int VisibleDeviceCount() const override; - const string& Name() const override; + const std::string& Name() const override; port::StatusOr> DescriptionForDevice( int ordinal) const override; @@ -87,7 +87,7 @@ class CudaPlatform : public Platform { void InspectNumaNodes(); // This platform's name. - string name_; + std::string name_; // Cache of created executors. ExecutorCache executor_cache_; diff --git a/tensorflow/stream_executor/gpu/asm_compiler.cc b/tensorflow/stream_executor/gpu/asm_compiler.cc index 20ed4732039..6122ec85c63 100644 --- a/tensorflow/stream_executor/gpu/asm_compiler.cc +++ b/tensorflow/stream_executor/gpu/asm_compiler.cc @@ -55,10 +55,10 @@ port::StatusOr> CompileGpuAsmOrGetCached( // ptxas_path. // // Locks on entry. 
-static void WarnIfBadPtxasVersion(const string& ptxas_path) { +static void WarnIfBadPtxasVersion(const std::string& ptxas_path) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static std::unordered_set* seen_ptxas_paths TF_GUARDED_BY(mu) = - new std::unordered_set(); + static std::unordered_set* seen_ptxas_paths TF_GUARDED_BY(mu) = + new std::unordered_set(); tensorflow::mutex_lock lock(mu); if (!seen_ptxas_paths->insert(ptxas_path).second) { @@ -74,7 +74,7 @@ static void WarnIfBadPtxasVersion(const string& ptxas_path) { return; } - string out; + std::string out; int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out, /*stderr_output=*/nullptr); if (exit_code != 0) { @@ -84,7 +84,7 @@ static void WarnIfBadPtxasVersion(const string& ptxas_path) { } int64 vmaj, vmin, vdot; - string vmaj_str, vmin_str, vdot_str; + std::string vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, &vmin_str, &vdot_str) || !absl::SimpleAtoi(vmaj_str, &vmaj) || @@ -161,9 +161,9 @@ port::StatusOr> CompileGpuAsm(int device_ordinal, port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options) { - string ptxas_path; + std::string ptxas_path; auto env = tensorflow::Env::Default(); - for (const string& cuda_root : + for (const std::string& cuda_root : tensorflow::CandidateCudaRoots(options.preferred_cuda_dir)) { ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); VLOG(2) << "Looking for ptxas at " << ptxas_path; @@ -180,7 +180,7 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, WarnIfBadPtxasVersion(ptxas_path); // Write ptx into a temporary file. - string ptx_path; + std::string ptx_path; if (!env->LocalTempFilename(&ptx_path)) { return port::InternalError("couldn't get temp PTX file name"); } @@ -193,7 +193,7 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, }); // Invoke ptxas and collect its output. - string cubin_path; + std::string cubin_path; if (!env->LocalTempFilename(&cubin_path)) { return port::InternalError("couldn't get temp CUBIN file name"); } @@ -203,7 +203,7 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); }); tensorflow::SubProcess ptxas_info_dumper; - std::vector ptxas_args = { + std::vector ptxas_args = { ptxas_path, ptx_path, "-o", cubin_path, absl::StrCat("-arch=sm_", cc_major, cc_minor)}; if (VLOG_IS_ON(2)) { @@ -220,7 +220,7 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, if (!ptxas_info_dumper.Start()) { return port::InternalError("Failed to launch ptxas"); } - string stderr_output; + std::string stderr_output; int exit_status = ptxas_info_dumper.Communicate( /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); if (exit_status != 0) { @@ -230,7 +230,7 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, } // Read in the result of compilation and return it as a byte vector. 
- string cubin; + std::string cubin; TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), cubin_path, &cubin)); std::vector cubin_vector(cubin.begin(), cubin.end()); diff --git a/tensorflow/stream_executor/gpu/gpu_diagnostics.h b/tensorflow/stream_executor/gpu/gpu_diagnostics.h index 71642109b57..a893cb48da1 100644 --- a/tensorflow/stream_executor/gpu/gpu_diagnostics.h +++ b/tensorflow/stream_executor/gpu/gpu_diagnostics.h @@ -60,7 +60,7 @@ class Diagnostician { // This is solely used for more informative log messages when the user is // running on a machine that happens to have a libcuda/kernel driver mismatch. static port::StatusOr FindKernelModuleVersion( - const string& driver_version_file_contents); + const std::string& driver_version_file_contents); // Extracts the kernel driver version from the current host. static port::StatusOr FindKernelDriverVersion(); @@ -88,7 +88,7 @@ class Diagnostician { // existence, permissions, accessibility from this uid/gid. static void LogDevNodeDiagnosticInformation(); - static string GetDevNodePath(int dev_node_ordinal); + static std::string GetDevNodePath(int dev_node_ordinal); SE_DISALLOW_COPY_AND_ASSIGN(Diagnostician); }; diff --git a/tensorflow/stream_executor/gpu/gpu_driver.h b/tensorflow/stream_executor/gpu/gpu_driver.h index 948bee31cde..f72c9a129cf 100644 --- a/tensorflow/stream_executor/gpu/gpu_driver.h +++ b/tensorflow/stream_executor/gpu/gpu_driver.h @@ -34,7 +34,7 @@ namespace gpu { enum class MemorySpace { kHost, kDevice }; // Returns a casual string, such as "host" for the provided memory space. -string MemorySpaceString(MemorySpace memory_space); +std::string MemorySpaceString(MemorySpace memory_space); class GpuContext; @@ -149,7 +149,7 @@ class GpuDriver { // Given a device handle, returns the name reported by the driver for the // device. static port::Status GetDeviceName(GpuDeviceHandle device, - string* device_name); + std::string* device_name); // Given a device to create a context for, returns a context handle into the // context outparam, which must not be null. @@ -469,7 +469,7 @@ class GpuDriver { // Returns a PCI bus id string for the device. // [domain]:[bus]:[device].[function] // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g85295e7d9745ab8f0aa80dd1e172acfc - static string GetPCIBusID(GpuDeviceHandle device); + static std::string GetPCIBusID(GpuDeviceHandle device); // -- Context- and device-independent calls. diff --git a/tensorflow/stream_executor/gpu/gpu_executor.h b/tensorflow/stream_executor/gpu/gpu_executor.h index 47c6b85b9c6..fc4ea0e0ab2 100644 --- a/tensorflow/stream_executor/gpu/gpu_executor.h +++ b/tensorflow/stream_executor/gpu/gpu_executor.h @@ -196,7 +196,7 @@ class GpuExecutor : public internal::StreamExecutorInterface { // Search for the symbol and returns a device pointer and size. // Returns false if symbol does not exist. 
-  bool GetSymbol(const string& symbol_name, ModuleHandle module_handle,
+  bool GetSymbol(const std::string& symbol_name, ModuleHandle module_handle,
                  void** mem, size_t* bytes) override;

   port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
@@ -245,7 +245,7 @@ class GpuExecutor : public internal::StreamExecutorInterface {
   // (supported on CUDA only)
   bool FindOnDiskForComputeCapability(absl::string_view filename,
                                       absl::string_view canonical_suffix,
-                                      string* found_filename) const;
+                                      std::string* found_filename) const;

   // Attempts to find a more specific version of the file indicated by
   // filename by looking for AMDGPU ISA-specific suffixed versions.

   bool FindOnDiskForISAVersion(absl::string_view filename,
                                absl::string_view canonical_suffix,
-                               string* found_filename) const;
+                               std::string* found_filename) const;

   // Host callback landing routine invoked by CUDA.
   // data: User-provided callback provided to HostCallback() above, captured
@@ -294,7 +294,7 @@ class GpuExecutor : public internal::StreamExecutorInterface {
   // Multiple GPUFunctionHandle are usually obtained from a single
   // GPUModuleHandle so we attempt to hit in this mapping first, before
   // retrieving it.
-  std::map disk_modules_
+  std::map disk_modules_
       TF_GUARDED_BY(disk_modules_mu_);

   // Guards the in-memory-module mapping.
diff --git a/tensorflow/stream_executor/gpu/gpu_rng.h b/tensorflow/stream_executor/gpu/gpu_rng.h
index 8dbe2961fff..a3464f48aeb 100644
--- a/tensorflow/stream_executor/gpu/gpu_rng.h
+++ b/tensorflow/stream_executor/gpu/gpu_rng.h
@@ -96,25 +96,25 @@ class GpuRng : public rng::RngSupport {
 };

 template <typename T>
-string TypeString();
+std::string TypeString();

 template <>
-string TypeString<float>() {
+std::string TypeString<float>() {
   return "float";
 }

 template <>
-string TypeString<double>() {
+std::string TypeString<double>() {
   return "double";
 }

 template <>
-string TypeString<std::complex<float>>() {
+std::string TypeString<std::complex<float>>() {
   return "std::complex<float>";
 }

 template <>
-string TypeString<std::complex<double>>() {
+std::string TypeString<std::complex<double>>() {
   return "std::complex<double>";
 }

diff --git a/tensorflow/stream_executor/gpu/redzone_allocator.h b/tensorflow/stream_executor/gpu/redzone_allocator.h
index 2fbaad32baf..77755ccd3c6 100644
--- a/tensorflow/stream_executor/gpu/redzone_allocator.h
+++ b/tensorflow/stream_executor/gpu/redzone_allocator.h
@@ -77,7 +77,7 @@ class RedzoneAllocator : public ScratchAllocator {

     std::string RedzoneFailureMsg() const;

-    string buffer_name = {};
+    std::string buffer_name = {};
     void* user_buffer_address = nullptr;
     int64 offset = 0;
     uint64 expected_value = 0;

From fe4ef23c9e8730477b209cdc3d6d335edeab2dec Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 20 Mar 2020 13:24:27 -0700
Subject: [PATCH 343/492] Internal visibility change.
PiperOrigin-RevId: 302088095 Change-Id: I1aeb7f0105a3c9ae143e38d2ab00851c26333abf --- tensorflow/compiler/mlir/xla/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 72126a7ef8f..2a76a75da50 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -15,6 +15,7 @@ package_group( "//learning/brain/experimental/swift_mlir/...", "//learning/brain/google/xla/kernels/...", "//learning/brain/swift/swift_mlir/...", + "//learning/pathways/data_parallel/tf2xla/...", "//platforms/xla/...", "//tensorflow/compiler/mlir/...", "//tensorflow/compiler/tf2xla/...", From 20d58a1916a348696dd452eacaa54075f48ab491 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Fri, 20 Mar 2020 13:38:09 -0700 Subject: [PATCH 344/492] [tf.data service] Add util for handling grpc errors. PiperOrigin-RevId: 302090444 Change-Id: I8aa9c1427c55e8e956225ab0a7f97c97bb96fc32 --- tensorflow/core/data/service/BUILD | 24 ++++++++++++ tensorflow/core/data/service/grpc_util.cc | 37 ++++++++++++++++++ tensorflow/core/data/service/grpc_util.h | 33 ++++++++++++++++ .../core/data/service/grpc_util_test.cc | 39 +++++++++++++++++++ 4 files changed, 133 insertions(+) create mode 100644 tensorflow/core/data/service/grpc_util.cc create mode 100644 tensorflow/core/data/service/grpc_util.h create mode 100644 tensorflow/core/data/service/grpc_util_test.cc diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index b597fd70add..6c8116a6de8 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -45,6 +45,30 @@ tf_proto_library( ], ) +cc_library( + name = "grpc_util", + srcs = ["grpc_util.cc"], + hdrs = [ + "grpc_util.h", + ], + deps = [ + "//tensorflow:grpc++", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "grpc_util_test", + srcs = ["grpc_util_test.cc"], + deps = [ + ":grpc_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "compression_utils", srcs = ["compression_utils.cc"], diff --git a/tensorflow/core/data/service/grpc_util.cc b/tensorflow/core/data/service/grpc_util.cc new file mode 100644 index 00000000000..40950c51efe --- /dev/null +++ b/tensorflow/core/data/service/grpc_util.cc @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/grpc_util.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { +namespace grpc_util { + +Status WrapError(const std::string& message, const grpc::Status& status) { + if (status.ok()) { + return errors::Internal("Expected a non-ok grpc status. 
Wrapping message: ", + message); + } else { + return Status(static_cast(status.error_code()), + absl::StrCat(message, ": ", status.error_message())); + } +} + +} // namespace grpc_util +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/grpc_util.h b/tensorflow/core/data/service/grpc_util.h new file mode 100644 index 00000000000..60ea10669a5 --- /dev/null +++ b/tensorflow/core/data/service/grpc_util.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_UTIL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_UTIL_H_ + +#include "grpcpp/grpcpp.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { +namespace grpc_util { + +// Wraps a grpc::Status in a tensorflow::Status with the given message. +Status WrapError(const std::string& message, const grpc::Status& status); + +} // namespace grpc_util +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_UTIL_H_ diff --git a/tensorflow/core/data/service/grpc_util_test.cc b/tensorflow/core/data/service/grpc_util_test.cc new file mode 100644 index 00000000000..47a0b4c4d89 --- /dev/null +++ b/tensorflow/core/data/service/grpc_util_test.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/data/service/grpc_util.h" + +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace grpc_util { + +TEST(GrpcUtil, WrapInvalidArgument) { + grpc::Status s(grpc::StatusCode::INVALID_ARGUMENT, "test message"); + Status wrapped = WrapError("wrapping message", s); + ASSERT_EQ(wrapped, errors::InvalidArgument("wrapping message: test message")); +} + +TEST(GrpcUtil, WrapOk) { + grpc::Status s; + Status wrapped = WrapError("wrapping message", s); + ASSERT_EQ(wrapped, errors::Internal("Expected a non-ok grpc status. Wrapping " + "message: wrapping message")); +} + +} // namespace grpc_util +} // namespace data +} // namespace tensorflow From 74354e8fddf8aee65fc8bd208b3478f7b9c4c37b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 13:40:50 -0700 Subject: [PATCH 345/492] move python hooks to OSS for code reuse. 
PiperOrigin-RevId: 302090903 Change-Id: Ic7b5a56583672cb7b089574e6df0e64f6ef8a144 --- tensorflow/python/profiler/internal/BUILD | 16 ++ .../python/profiler/internal/python_hooks.cc | 185 ++++++++++++++++++ .../python/profiler/internal/python_hooks.h | 52 +++++ 3 files changed, 253 insertions(+) create mode 100644 tensorflow/python/profiler/internal/python_hooks.cc create mode 100644 tensorflow/python/profiler/internal/python_hooks.h diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 3dccf7144e4..29d2701faba 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -128,3 +128,19 @@ tf_python_pybind_extension( "@pybind11", ], ) + +cc_library( + name = "python_hooks", + srcs = ["python_hooks.cc"], + hdrs = ["python_hooks.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform:path", + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/strings", + "@pybind11", + ], + alwayslink = True, +) diff --git a/tensorflow/python/profiler/internal/python_hooks.cc b/tensorflow/python/profiler/internal/python_hooks.cc new file mode 100644 index 00000000000..be4afe71349 --- /dev/null +++ b/tensorflow/python/profiler/internal/python_hooks.cc @@ -0,0 +1,185 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/python/profiler/internal/python_hooks.h" + +#include "absl/strings/strip.h" +#include "tensorflow/core/platform/path.h" + +namespace tensorflow { +namespace profiler { + +namespace py = ::pybind11; +using tensorflow::profiler::TraceMe; + +template +int ProfileFunction(PyObject* obj, PyFrameObject* frame, int what, + PyObject* arg) { + T::GetSingleton()->ProfileFast(frame, what, arg); + return 0; +} + +void SysSetProfileNone() { + py::object setprofile = py::module::import("sys").attr("setprofile"); + setprofile(py::none()); +} + +void ThreadingSetProfile(const py::object& callback) { + py::object setprofile = py::module::import("threading").attr("setprofile"); + setprofile(callback); +} + +PythonHooks* PythonHooks::GetSingleton() { + static PythonHooks* singleton = new PythonHooks; + return singleton; +} + +void PythonHooks::Start() { + PyGILState_STATE gil_state = PyGILState_Ensure(); + SetProfilerInAllThreads(); + PyGILState_Release(gil_state); +} + +void PythonHooks::Stop() { + PyGILState_STATE gil_state = PyGILState_Ensure(); + ClearProfilerInAllThreads(); + PyGILState_Release(gil_state); +} + +void PythonHooks::Finalize() { tracemes_.clear(); } + +void PythonHooks::ProfileSlow(const py::object& frame, const string& event, + const py::object& arg) { + int what; + absl::string_view event_name(event); + + if (absl::ConsumePrefix(&event_name, "c_")) { + if (event_name == "call") { + what = PyTrace_C_CALL; + } else if (event_name == "return") { + what = PyTrace_C_RETURN; + } else if (event_name == "exception") { + what = PyTrace_C_EXCEPTION; + } else { + return; + } + } else { + if (event_name == "call") { + what = PyTrace_CALL; + } else if (event_name == "return") { + what = PyTrace_RETURN; + } else if (event_name == "exception") { + what = PyTrace_EXCEPTION; + } else { + return; + } + } + + ProfileFast(reinterpret_cast(frame.ptr()), what, arg.ptr()); +} + +void PythonHooks::ProfileFast(PyFrameObject* frame, int what, PyObject* arg) { + const int64 thread_id = PyThread_get_thread_ident(); + + if (what == PyTrace_CALL) { + PyCodeObject* f_code = frame->f_code; + string filename(py::reinterpret_borrow(f_code->co_filename)); + int line_no = frame->f_lineno; + + string function; + if (f_code->co_name == nullptr) { + function = ""; + } else { + function = py::reinterpret_borrow(f_code->co_name); + } + + tracemes_[thread_id].push_back(absl::make_unique(absl::StrCat( + "$", io::Basename(filename), ":", line_no, " ", function))); + } else if (what == PyTrace_C_CALL && PyCFunction_Check(arg)) { + // Python stack does not have a filename/line_no for native calls. 
+ auto* func = reinterpret_cast(arg); + PyObject* module = func->m_module; + string filename; + bool filename_ok; +#if PY_MAJOR_VERSION < 3 + filename_ok = (module != nullptr && PyString_Check(module)); +#else + filename_ok = (module != nullptr && PyUnicode_Check(module)); +#endif + if (filename_ok) { + filename = py::reinterpret_borrow(module); + } else { + filename = ""; + } + + string function(func->m_ml->ml_name); + tracemes_[thread_id].push_back(absl::make_unique( + absl::StrCat(filename, " ", func->m_ml->ml_name))); + } else if (what == PyTrace_RETURN || what == PyTrace_C_RETURN || + what == PyTrace_EXCEPTION || what == PyTrace_C_EXCEPTION) { + auto& thread_tracemes = tracemes_[thread_id]; + if (!thread_tracemes.empty()) { + thread_tracemes.pop_back(); + } + } +} + +void PythonHooks::SetProfilerInAllThreads() { + // We also want any new threads started to use our profiler. + // NOTE: threading does not provide a C API equivalent to + // `threading.setprofile` so we are forced to go via Python to setup the + // profile when a new thread is created. After the first callback in that + // thread we unregister the Python profile function and use + // `PyEval_SetProfile` to register a C profiler which has significantly less + // overhead (>2x faster). + py::cpp_function callback = + py::cpp_function([this](const py::object& frame, const string& event, + const py::object& arg) { + ProfileSlow(frame, event, arg); + SysSetProfileNone(); + PyEval_SetProfile(ProfileFunction, nullptr); + }); + + ThreadingSetProfile(callback); + + // NOTE: This must be after `threading.setprofile` otherwise we + // end up recording that in our trace. + PyThreadState* curr_thread = PyThreadState_Get(); + PyThreadState* next_thread = curr_thread; + while (next_thread != nullptr) { + VLOG(1) << "Setting profiler in " << next_thread->thread_id; + PyThreadState_Swap(next_thread); + PyEval_SetProfile(ProfileFunction, nullptr); + next_thread = next_thread->next; + } + PyThreadState_Swap(curr_thread); +} + +void PythonHooks::ClearProfilerInAllThreads() { + PyThreadState* curr_thread = PyThreadState_Get(); + PyThreadState* next_thread = curr_thread; + while (next_thread != nullptr) { + VLOG(1) << "Clearing profiler in " << next_thread->thread_id; + PyThreadState_Swap(next_thread); + PyEval_SetProfile(nullptr, nullptr); + next_thread = next_thread->next; + } + PyThreadState_Swap(curr_thread); + + // And notify the threading library that we're done. + ThreadingSetProfile(py::none()); +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/python/profiler/internal/python_hooks.h b/tensorflow/python/profiler/internal/python_hooks.h new file mode 100644 index 00000000000..e2b0544ebf4 --- /dev/null +++ b/tensorflow/python/profiler/internal/python_hooks.h @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_PYTHON_PROFILER_INTERNAL_PYTHON_HOOKS_H_ +#define TENSORFLOW_PYTHON_PROFILER_INTERNAL_PYTHON_HOOKS_H_ + +#include "absl/container/flat_hash_map.h" +#include "include/pybind11/cast.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +namespace tensorflow { +namespace profiler { + +namespace py = ::pybind11; + +// Singleton for tracing python function calls. +class PythonHooks { + public: + static PythonHooks* GetSingleton(); + + void Start(); + void Stop(); + void Finalize(); + void ProfileSlow(const py::object& frame, const string& event, + const py::object& arg); + void ProfileFast(PyFrameObject* frame, int what, PyObject* arg); + + private: + void SetProfilerInAllThreads(); + void ClearProfilerInAllThreads(); + + absl::flat_hash_map>> tracemes_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_PROFILER_INTERNAL_PYTHON_HOOKS_H_ From 50480faea75f56def464b84f251b4aee388dfce9 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Fri, 20 Mar 2020 13:48:55 -0700 Subject: [PATCH 346/492] Make TensorBoard Callback non-blocking. PiperOrigin-RevId: 302092248 Change-Id: I4f03ad95291fec85fc315d59720b2468d9d3e588 --- tensorflow/python/keras/callbacks.py | 104 +++++++++++++++------- tensorflow/python/keras/callbacks_test.py | 37 ++++++++ 2 files changed, 110 insertions(+), 31 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 3f9a3fd684b..c68f58c0747 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -257,12 +257,6 @@ class CallbackList(object): self._delta_ts = collections.defaultdict( lambda: collections.deque([], maxlen=self._queue_length)) - def _process_logs(self, logs): - """Turns tensors into numpy arrays or Python scalars.""" - if logs: - return tf_utils.to_numpy_or_python_type(logs) - return {} - def append(self, callback): self.callbacks.append(callback) @@ -291,9 +285,15 @@ class CallbackList(object): logs = logs or {} t_before_callbacks = time.time() + numpy_logs = None for callback in self.callbacks: batch_hook = getattr(callback, hook_name) - batch_hook(batch, logs) + if getattr(callback, '_supports_tf_logs', False): + batch_hook(batch, logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + batch_hook(batch, numpy_logs) self._delta_ts[hook_name].append(time.time() - t_before_callbacks) delta_t_median = np.median(self._delta_ts[hook_name]) @@ -324,12 +324,10 @@ class CallbackList(object): def on_batch_begin(self, batch, logs=None): if self._should_call_train_batch_hooks: - logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs) def on_batch_end(self, batch, logs=None): if self._should_call_train_batch_hooks: - logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) def on_epoch_begin(self, epoch, logs=None): @@ -342,9 +340,15 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. 
""" - logs = self._process_logs(logs) + logs = logs or {} + numpy_logs = None for callback in self.callbacks: - callback.on_epoch_begin(epoch, logs) + if getattr(callback, '_supports_tf_logs', False): + callback.on_epoch_begin(epoch, logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + callback.on_epoch_begin(epoch, numpy_logs) self._reset_batch_timing() def on_epoch_end(self, epoch, logs=None): @@ -358,9 +362,15 @@ class CallbackList(object): validation epoch if validation is performed. Validation result keys are prefixed with `val_`. """ - logs = self._process_logs(logs) + logs = logs or {} + numpy_logs = None for callback in self.callbacks: - callback.on_epoch_end(epoch, logs) + if getattr(callback, '_supports_tf_logs', False): + callback.on_epoch_end(epoch, logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + callback.on_epoch_end(epoch, numpy_logs) def on_train_batch_begin(self, batch, logs=None): """Calls the `on_train_batch_begin` methods of its callbacks. @@ -373,7 +383,6 @@ class CallbackList(object): # TODO(b/150629188): Make ProgBarLogger callback not use batch hooks # when verbose != 1 if self._should_call_train_batch_hooks: - logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs) def on_train_batch_end(self, batch, logs=None): @@ -384,7 +393,6 @@ class CallbackList(object): logs: dict. Metric results for this batch. """ if self._should_call_train_batch_hooks: - logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) def on_test_batch_begin(self, batch, logs=None): @@ -396,7 +404,6 @@ class CallbackList(object): number and the size of the batch. """ if self._should_call_test_batch_hooks: - logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs) def on_test_batch_end(self, batch, logs=None): @@ -407,7 +414,6 @@ class CallbackList(object): logs: dict. Metric results for this batch. """ if self._should_call_test_batch_hooks: - logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs) def on_predict_batch_begin(self, batch, logs=None): @@ -419,7 +425,6 @@ class CallbackList(object): number and the size of the batch. """ if self._should_call_predict_batch_hooks: - logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs) def on_predict_batch_end(self, batch, logs=None): @@ -430,7 +435,6 @@ class CallbackList(object): logs: dict. Metric results for this batch. """ if self._should_call_predict_batch_hooks: - logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs) def on_train_begin(self, logs=None): @@ -440,9 +444,15 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ - logs = self._process_logs(logs) + logs = logs or {} + numpy_logs = None for callback in self.callbacks: - callback.on_train_begin(logs) + if getattr(callback, '_supports_tf_logs', False): + callback.on_train_begin(logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + callback.on_train_begin(numpy_logs) def on_train_end(self, logs=None): """Calls the `on_train_end` methods of its callbacks. @@ -451,9 +461,15 @@ class CallbackList(object): logs: dict. 
Currently no data is passed to this argument for this method but that may change in the future. """ - logs = self._process_logs(logs) + logs = logs or {} + numpy_logs = None for callback in self.callbacks: - callback.on_train_end(logs) + if getattr(callback, '_supports_tf_logs', False): + callback.on_train_end(logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + callback.on_train_end(numpy_logs) def on_test_begin(self, logs=None): """Calls the `on_test_begin` methods of its callbacks. @@ -462,9 +478,15 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ - logs = self._process_logs(logs) + logs = logs or {} + numpy_logs = None for callback in self.callbacks: - callback.on_test_begin(logs) + if getattr(callback, '_supports_tf_logs', False): + callback.on_test_begin(logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + callback.on_test_begin(numpy_logs) def on_test_end(self, logs=None): """Calls the `on_test_end` methods of its callbacks. @@ -473,9 +495,15 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ - logs = self._process_logs(logs) + logs = logs or {} + numpy_logs = None for callback in self.callbacks: - callback.on_test_end(logs) + if getattr(callback, '_supports_tf_logs', False): + callback.on_test_end(logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + callback.on_test_end(numpy_logs) def on_predict_begin(self, logs=None): """Calls the 'on_predict_begin` methods of its callbacks. @@ -484,9 +512,15 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ - logs = self._process_logs(logs) + logs = logs or {} + numpy_logs = None for callback in self.callbacks: - callback.on_predict_begin(logs) + if getattr(callback, '_supports_tf_logs', False): + callback.on_predict_begin(logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + callback.on_predict_begin(numpy_logs) def on_predict_end(self, logs=None): """Calls the `on_predict_end` methods of its callbacks. @@ -495,9 +529,15 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ - logs = self._process_logs(logs) + logs = logs or {} + numpy_logs = None for callback in self.callbacks: - callback.on_predict_end(logs) + if getattr(callback, '_supports_tf_logs', False): + callback.on_predict_end(logs) + else: + if numpy_logs is None: # Only convert once. + numpy_logs = tf_utils.to_numpy_or_python_type(logs) + callback.on_predict_end(numpy_logs) def __iter__(self): return iter(self.callbacks) @@ -539,6 +579,7 @@ class Callback(object): # Multi-Worker setting. # TODO(omalleyt): Make this attr public once solution is stable. 
self._chief_worker_only = None + self._supports_tf_logs = False def set_params(self, params): self.params = params @@ -1718,6 +1759,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): embeddings_metadata=None, **kwargs): super(TensorBoard, self).__init__() + self._supports_tf_logs = True self._validate_kwargs(kwargs) self.log_dir = log_dir diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index e488835a6c5..2b7f7c038c6 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -1916,6 +1916,43 @@ class TestTensorBoardV2(keras_parameterized.TestCase): with self.assertRaisesRegexp(ValueError, 'Unrecognized arguments'): keras.callbacks.TensorBoard(wwrite_images=True) + def test_TensorBoard_non_blocking(self): + model = keras.Sequential([keras.layers.Dense(1)]) + tb = keras.callbacks.TensorBoard(self.logdir) + self.assertTrue(tb._supports_tf_logs) + cb_list = keras.callbacks.CallbackList([tb], + model=model, + epochs=1, + steps=100, + verbose=0) + + tensor = ops.convert_to_tensor(1.) + + def mock_numpy(): + raise RuntimeError( + 'If this error is seen, TensorBoard is causing a blocking ' + 'NumPy conversion.') + + with test.mock.patch.object(tensor, 'numpy', mock_numpy): + logs = {'metric': tensor} + + cb_list.on_train_begin(logs) + cb_list.on_epoch_begin(0, logs) + cb_list.on_train_batch_begin(0, logs) + cb_list.on_train_batch_end(0, logs) + cb_list.on_epoch_end(0, logs) + cb_list.on_train_end(logs) + + cb_list.on_test_begin(logs) + cb_list.on_test_batch_begin(0, logs) + cb_list.on_test_batch_end(0, logs) + cb_list.on_test_end(logs) + + cb_list.on_predict_begin(logs) + cb_list.on_predict_batch_begin(logs) + cb_list.on_predict_batch_end(logs) + cb_list.on_predict_end(logs) + # Note that this test specifies model_type explicitly. @keras_parameterized.run_all_keras_modes(always_skip_v1=True) From 2234cafad6c101182b793b0922aa9d45dad02cc5 Mon Sep 17 00:00:00 2001 From: Rick Chao Date: Fri, 20 Mar 2020 14:42:16 -0700 Subject: [PATCH 347/492] Create _Subprocess class in multi_process_runner to better modularize the methods belonging to parent and subprocesses. PiperOrigin-RevId: 302102237 Change-Id: I473fd6fe65b49b851ee3549cf400d6b7d0c742b3 --- .../python/distribute/multi_process_runner.py | 286 +++++++++--------- .../distribute/multi_process_runner_test.py | 13 +- 2 files changed, 149 insertions(+), 150 deletions(-) diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 0091f0e4109..4b28f1aa05a 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -81,13 +81,6 @@ BARRIER = 'barrier' _DEFAULT_MAX_SUBPROCESS_COUNT = 20 -# Threads to be joined at the time subprocesses successfully exit. -# TODO(rchao): Refactor multi_process_runner so that _threads lives in -# parent process' class which is separated from subprocess' class. Currently -# this needs to be global so it doesn't get pickled into subprocess' function -# runs, which would fail. -_threads = [] - # Next pipe index to be global so that pipes are not reused across multiple # MultiProcessRunner usages. # TODO(rchao): Investigate possibility to remove this variable. 
@@ -115,7 +108,7 @@ class MultiProcessRunner(object): cluster_spec, rpc_layer=None, max_run_time=None, - grpc_fail_fast=False, + grpc_fail_fast=None, stream_stdout=True, list_stdout=False, args=None, @@ -139,7 +132,8 @@ class MultiProcessRunner(object): since Python signal handler does not get executed when it runs lower level C/C++ code. So it can be delayed for arbitrarily long time. grpc_fail_fast: Whether GRPC connection between processes should fail - without retrying. Defaults to False. + without retrying. Defaults to None, in which case the environment + variable is not explicitly set. stream_stdout: True if the output/error from the subprocesses should be streamed to be printed in parent process' log. Defaults to True. list_stdout: True if the output/error from the subprocesses should be @@ -190,112 +184,6 @@ class MultiProcessRunner(object): # This flag will be set to True once terminate_all() is called. self._all_forced_terminated = False - @contextlib.contextmanager - def _runtime_mode(self): - if self._executing_eagerly: - with context.eager_mode(): - yield - else: - with context.graph_mode(): - yield - - def _finish_process(self, process_status_info, return_value): - """Adds data to queues before program exits.""" - # Clear the alarm. - signal.alarm(0) - - if return_value is not None: - self._add_return_data(return_value) - self._get_process_status_queue().put(process_status_info) - - def _message_checking_func(self, task_type, task_id): - """A function that regularly checks messages from parent process.""" - # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess. - while True: - try: - message = self._get_parent_to_sub_queue().get(block=False) - # Currently the only possible message is termination. - assert message.startswith('terminate') - if message == 'terminate {} {}'.format(task_type, task_id): - break - else: - # If the message is not targeting this process, put it back to the - # queue. - self._get_parent_to_sub_queue().put(message) - time.sleep(1) - except Queue.Empty: - time.sleep(0.1) - self._finish_process( - _ProcessStatusInfo( - task_type=task_type, is_successful=True, exc_info=None), None) - # `os._exit(0)` is used to more reliably terminate a subprocess. - os._exit(0) # pylint: disable=protected-access - - def _proc_func_wrapper(self, proc_func, task_type, task_id, - per_process_cluster_spec, rpc_layer, pipe_w, *arg, - **kwargs): - """The wrapper function that actually gets run in child process(es).""" - - pid = os.getpid() - logging.info('Subprocess with PID %d is now being started.', pid) - self._get_subprocess_info_queue().put(_SubprocessInfo(pid=pid)) - - # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print() and - # logging.*() write directly to `pipe_w`. Unfortunately since we cannot - # prepend task_type and task_id information to the streamed logs we will - # need a thread per subprocess to distinguish where the piece of message is - # from. - os.dup2(pipe_w.fileno(), sys.stdout.fileno()) - os.dup2(pipe_w.fileno(), sys.stderr.fileno()) - - # The thread will be dedicated to checking messages from the parent process. 
- threading.Thread( # pylint: disable=unexpected-keyword-arg - target=self._message_checking_func, - args=(task_type, task_id), - daemon=True).start() - - os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast) - tf_config_dict = { - 'cluster': per_process_cluster_spec, - 'task': { - 'type': task_type, - 'index': task_id, - }, - } - if rpc_layer is not None: - tf_config_dict['rpc_layer'] = rpc_layer - os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) - - if self._v2_enabled: - v2_compat.enable_v2_behavior() - - return_value = None - - try: - with self._runtime_mode(): - return_value = proc_func(*arg, **kwargs) - - except Exception: # pylint: disable=broad-except - # Capture all exceptions to be reported to parent process. - self._finish_process( - _ProcessStatusInfo( - task_type=task_type, is_successful=False, - exc_info=sys.exc_info()), return_value) - - # Re-raise the exception in addition to reporting it to the parent - # process, so that even if `--test_timeout` flag is set and the - # error doesn't make it to be shown in parent process before bazel's - # timeout, the log would still show what happens in this subprocess, - # instead of silently suppressing the error due to early bazel - # timeout. Raising an error in the subprocess produces stack trace in - # the log, but the program continues running. - raise - - self._finish_process( - _ProcessStatusInfo( - task_type=task_type, is_successful=True, exc_info=None), - return_value) - def _continuously_readline_from_sub(self, pipe_r, task_type, task_id): """Function to continuously read lines from subprocesses.""" reader = os.fdopen(pipe_r.fileno(), 'r') @@ -303,6 +191,10 @@ class MultiProcessRunner(object): read_line = reader.readline() if read_line == 'EOF': reader.close() + # The thread that runs `_continuously_readline_from_sub` stops here. + # However the threads don't exit until the test exits, so we do not + # attempt to join the threads (which leads to timeout). + # TODO(rchao): Understand why and do thread joining. break task_string = '[{}-{}]:'.format(task_type, task_id) formatted_line = '{} {}'.format(task_string.ljust(14), read_line) @@ -320,20 +212,20 @@ class MultiProcessRunner(object): def _add_stdout_in_queue(self, formatted_line, task_type, task_id): del task_type, task_id # A queue instead of a simple list is used here due to b/150652733. - multi_process_lib.get_user_data()[STREAMING_QUEUE].put(formatted_line) + _resource(STREAMING_QUEUE).put(formatted_line) def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id, args, kwargs): """Start a subprocess and a thread the reads lines from the subprocess.""" global _next_pipe_index - pipe_r, pipe_w = multi_process_lib.get_user_data( - )[STREAMING_PIPE][_next_pipe_index] + pipe_r, pipe_w = _resource(STREAMING_PIPE)[_next_pipe_index] _next_pipe_index += 1 p = multi_process_lib.Process( - target=self._proc_func_wrapper, + target=_Subprocess(), args=(proc_func, task_type, task_id, self._cluster_spec, - self._rpc_layer, pipe_w) + args, + self._rpc_layer, self._grpc_fail_fast, self._v2_enabled, + self._executing_eagerly, pipe_w) + args, kwargs=kwargs) p.start() self._outstanding_subprocess_count += 1 @@ -342,10 +234,8 @@ class MultiProcessRunner(object): # from them. 
thread = threading.Thread( # pylint: disable=unexpected-keyword-arg target=self._continuously_readline_from_sub, - args=(pipe_r, task_type, task_id), - daemon=True) + args=(pipe_r, task_type, task_id)) thread.start() - _threads.append(thread) def start(self): """Starts processes, one for each task in `cluster_spec`.""" @@ -439,7 +329,7 @@ class MultiProcessRunner(object): while self._outstanding_subprocess_count > 0: while True: try: - process_status = self._get_process_status_queue().get(timeout=10) + process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10) break except Queue.Empty: if self._all_forced_terminated: @@ -464,28 +354,23 @@ class MultiProcessRunner(object): # Giving threads some time to finish the message reading from subprocesses. time.sleep(5) - stdout = self._queue_to_list( - multi_process_lib.get_user_data()[STREAMING_QUEUE]) - return_value = self._queue_to_list( - multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE]) + stdout = self._queue_to_list(_resource(STREAMING_QUEUE)) + return_value = self._queue_to_list(_resource(RETURN_VALUE_QUEUE)) # Notifying the threads that are reading lines that we should stop. for pipe_index in range(self._starting_pipe_index, _next_pipe_index): # pylint: disable=protected-access - _, pipe_w = multi_process_lib.get_user_data()[STREAMING_PIPE][pipe_index] + _, pipe_w = _resource(STREAMING_PIPE)[pipe_index] writer = os.fdopen(pipe_w.fileno(), 'w') # Writing end of file message so the threads that's actively reading lines # know to stop. writer.writelines(['EOF']) writer.close() - for thread in _threads: - thread.join(5) - return MultiProcessRunnerResult(stdout=stdout, return_value=return_value) def terminate(self, task_type, task_id): """Terminates the process with `task_type` and `task_id`.""" - self._get_parent_to_sub_queue().put('terminate {} {}'.format( + _resource(PARENT_TO_SUB_QUEUE).put('terminate {} {}'.format( task_type, task_id)) def terminate_all(self): @@ -494,7 +379,7 @@ class MultiProcessRunner(object): while True: try: - subprocess_info = self._get_subprocess_info_queue().get(block=False) + subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False) subprocess_infos.append(subprocess_info) except Queue.Empty: break @@ -505,6 +390,122 @@ class MultiProcessRunner(object): self._all_forced_terminated = True + +class _Subprocess(object): + """Represents an internal subprocess used in MultiProcessRunner's context.""" + + @contextlib.contextmanager + def _runtime_mode(self, executing_eagerly): + if executing_eagerly: + with context.eager_mode(): + yield + else: + with context.graph_mode(): + yield + + def _finish_process(self, process_status_info, return_value): + """Adds data to queues before program exits.""" + # Clear the alarm. + signal.alarm(0) + + if return_value is not None: + self._add_return_data(return_value) + _resource(PROCESS_STATUS_QUEUE).put(process_status_info) + + def _message_checking_func(self, task_type, task_id): + """A function that regularly checks messages from parent process.""" + # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess. + while True: + try: + message = _resource(PARENT_TO_SUB_QUEUE).get(block=False) + + # Currently the only possible message is termination. + if not message.startswith('terminate'): + raise ValueError('Unrecognized message: {}'.format(message)) + + if message == 'terminate {} {}'.format(task_type, task_id): + break + else: + # If the message is not targeting this process, put it back to the + # queue. 
+ _resource(PARENT_TO_SUB_QUEUE).put(message) + time.sleep(1) + except Queue.Empty: + time.sleep(0.1) + self._finish_process( + _ProcessStatusInfo( + task_type=task_type, is_successful=True, exc_info=None), None) + # `os._exit(0)` is used to more reliably terminate a subprocess. + os._exit(0) # pylint: disable=protected-access + + def __call__(self, proc_func, task_type, task_id, per_process_cluster_spec, + rpc_layer, grpc_fail_fast, v2_enabled, executing_eagerly, pipe_w, + *arg, **kwargs): + """The wrapper function that actually gets run in child process(es).""" + + pid = os.getpid() + logging.info('Subprocess with PID %d is now being started.', pid) + _resource(SUBPROCESS_INFO_QUEUE).put(_SubprocessInfo(pid=pid)) + + # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print() and + # logging.*() write directly to `pipe_w`. Unfortunately since we cannot + # prepend task_type and task_id information to the streamed logs we will + # need a thread per subprocess to distinguish where the piece of message is + # from. + os.dup2(pipe_w.fileno(), sys.stdout.fileno()) + os.dup2(pipe_w.fileno(), sys.stderr.fileno()) + + # The thread will be dedicated to checking messages from the parent process. + threading.Thread( # pylint: disable=unexpected-keyword-arg + target=self._message_checking_func, + args=(task_type, task_id), + daemon=True).start() + + if grpc_fail_fast is not None: + os.environ['GRPC_FAIL_FAST'] = str(grpc_fail_fast) + tf_config_dict = { + 'cluster': per_process_cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id, + }, + } + if rpc_layer is not None: + tf_config_dict['rpc_layer'] = rpc_layer + os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) + + if v2_enabled: + v2_compat.enable_v2_behavior() + + try: + with self._runtime_mode(executing_eagerly): + return_value = proc_func(*arg, **kwargs) + is_successful = True + exc_info = None + + except Exception: # pylint: disable=broad-except + # Capture all exceptions to be reported to parent process. + return_value = None + is_successful = False + exc_info = sys.exc_info() + + # Re-raise the exception in addition to reporting it to the parent + # process, so that even if `--test_timeout` flag is set and the + # error doesn't make it to be shown in parent process before bazel's + # timeout, the log would still show what happens in this subprocess, + # instead of silently suppressing the error due to early bazel + # timeout. Raising an error in the subprocess produces stack trace in + # the log, but the program continues running. + raise + + finally: + self._finish_process( + _ProcessStatusInfo( + task_type=task_type, + is_successful=is_successful, + exc_info=exc_info), + return_value) + def _add_return_data(self, data): """Adds return data that will be returned by `join`. @@ -518,27 +519,22 @@ class MultiProcessRunner(object): # TODO(rchao): Incorporate the task type and id information in a data # wrapper that becomes what is stored in the queue so we can tell where # the data is from. 
- multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE].put(data) - - def _get_process_status_queue(self): - return multi_process_lib.get_user_data()[PROCESS_STATUS_QUEUE] - - def _get_parent_to_sub_queue(self): - return multi_process_lib.get_user_data()[PARENT_TO_SUB_QUEUE] - - def _get_subprocess_info_queue(self): - return multi_process_lib.get_user_data()[SUBPROCESS_INFO_QUEUE] + _resource(RETURN_VALUE_QUEUE).put(data) def barrier(): return multi_process_lib.get_user_data()[BARRIER] +def _resource(resource_name): + return multi_process_lib.get_user_data()[resource_name] + + def run(proc_func, cluster_spec, rpc_layer=None, max_run_time=None, - grpc_fail_fast=False, + grpc_fail_fast=None, stream_stdout=True, list_stdout=False, timeout=None, diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index a21ae45fd56..b374429257b 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -125,18 +125,20 @@ class MultiProcessRunnerTest(test.TestCase): def func_to_exit_in_15_sec(): time.sleep(5) - mpr._add_return_data('foo') + print('foo', flush=True) time.sleep(20) - mpr._add_return_data('bar') + print('bar', flush=True) mpr = multi_process_runner.MultiProcessRunner( func_to_exit_in_15_sec, multi_worker_test_base.create_cluster_spec(num_workers=1), + list_stdout=True, max_run_time=15) mpr.start() - return_value = mpr.join().return_value - self.assertLen(return_value, 1) + stdout = mpr.join().stdout + self.assertLen([msg for msg in stdout if 'foo' in msg], 1) + self.assertLen([msg for msg in stdout if 'bar' in msg], 0) def test_signal_doesnt_fire_after_process_exits(self): mpr = multi_process_runner.MultiProcessRunner( @@ -148,7 +150,8 @@ class MultiProcessRunnerTest(test.TestCase): with self.assertRaisesRegexp(Queue.Empty, ''): # If the signal was fired, another message would be added to internal # queue, so verifying it's empty. - mpr._get_process_status_queue().get(block=False) + multi_process_runner._resource( + multi_process_runner.PROCESS_STATUS_QUEUE).get(block=False) def test_termination(self): From fde8d3886b0bd7c0e1d3eaef17775ac1e60b79e3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 14:54:59 -0700 Subject: [PATCH 348/492] Fix replicated input sharding handling. PiperOrigin-RevId: 302104485 Change-Id: I08b91a5684cdc352c6d9f8d8b8b0d5de47e3b44f --- .../mlir/tensorflow/tests/tpu_rewrite.mlir | 36 +++++++++++++++++++ .../tensorflow/utils/xla_sharding_util.cc | 2 +- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index f6eb08bb58c..34dbee5cba9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1319,6 +1319,42 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc } } + +// ----- + +// Tests that inputs are inputs with maximal and replicate sharding are set properly +// for replicated model parallelism. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<8xi32>, %[[ARG_1:[a-z0-9]*]]: tensor<8xi32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi1>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi1>, %[[ARG_4:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_5:[a-z0-9]*]]: tensor<*xi32>) + func @parallel_execute_with_input_with_sharding_configurations(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>, %arg2: tensor<*xi1>, %arg3: tensor<*xi1>, %arg4: tensor<*xi32>, %arg5: tensor<*xi32>) -> (tensor<8xi32>, tensor<8xi32>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<8xi32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi1> + // CHECK-SAME: [%[[ARG_4]], %[[ARG_5]]] as %[[RI_2:[a-z0-9]*]]: tensor<*xi32> + %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>, [%arg2, %arg3] as %ri2: tensor<*xi1>, [%arg4, %arg5] as %ri3: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" + // CHECK: "tf._TPUCompileMlir" + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[RI_0]], %[[RI_1]], %[[RI_2]], %[[COMPILE]]#1) + // CHECK-NEXT: tf_device.return %[[EXECUTE_OUTPUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute"(%[[RI_1]], %[[RI_2]], %[[COMPILE]]#2) + // CHECK: device = "TPU_REPLICATED_CORE_1" + %1 = "tf_device.launch_func"(%ri, %ri2, %ri3) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32> + tf_device.return %1 : tensor<8xi32> + } + return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> + } + func @tpu0_func(%arg0: tensor<8xi32>, %arg1: tensor<*xi1>, %arg2: tensor<*xi32>) -> tensor<8xi32> { + %1 = "tf.A"(%arg0, %arg1, %arg2) : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> (tensor<8xi32>) + return %1 : tensor<8xi32> + } +} + // ----- // Tests devices are set properly for replicated model parallelism with diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index bcf6e1b3496..ede8130c953 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -219,7 +219,7 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]); } } else if (input_sharing_type == 
xla::OpSharding::REPLICATED) { - for (auto inputs : *input_list) inputs.emplace_back(input_value); + for (auto& inputs : *input_list) inputs.emplace_back(input_value); } else { assert(input_sharing_type == xla::OpSharding::MAXIMAL); const int logical_device_id = sharding.tile_assignment_devices(0); From 63c256e4fa35a91f18f2001845c30901f818f78e Mon Sep 17 00:00:00 2001 From: Karim Nosir Date: Fri, 20 Mar 2020 14:56:17 -0700 Subject: [PATCH 349/492] Call verify method for tfl Runtime verifier. PiperOrigin-RevId: 302104741 Change-Id: I32c37777e4ba9fb88543bd58976f4214d34d5c30 --- tensorflow/compiler/mlir/lite/converter_gen.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index db2b924278f..b1fa1675845 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -461,7 +461,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { "operand"); GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(), "result"); - os << " return mlir::success();\n}\n"; + os << " return top.verify();\n}\n"; } return false; From 660b1908aab1b958461b391df74e6dbf4903b845 Mon Sep 17 00:00:00 2001 From: Yi Situ Date: Fri, 20 Mar 2020 15:00:28 -0700 Subject: [PATCH 350/492] Turned down verbosity of TC eligibility detection. PiperOrigin-RevId: 302105550 Change-Id: Ib213516491a7e11d2217fd5d7a2aa9900266a069 --- .../core/profiler/convert/xplane_to_kernel_stats_db.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc index e0fd2bb6339..0deb4309ff8 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc @@ -66,12 +66,12 @@ KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb( kernel.set_op_name(tf_op.name.data(), tf_op.name.size()); bool tensor_core_eligible = IsEinsumTensorCoreEligible(equation) || IsOpTensorCoreEligible(kernel.op_name()); -#if defined(LOG_IF) - LOG_IF(INFO, - !tensor_core_eligible && kernel.is_kernel_using_tensor_core()) +#if defined(VLOG_IF) + VLOG_IF(1, + !tensor_core_eligible && kernel.is_kernel_using_tensor_core()) << "Detected new Op using TensorCores: " << kernel.op_name() << std::endl; -#endif // defined(LOG_IF) +#endif // defined(VLOG_IF) tensor_core_eligible |= kernel.is_kernel_using_tensor_core(); kernel.set_is_op_tensor_core_eligible(tensor_core_eligible); } From ce3ba1058f055d3c924dd0a8e71ef343ef649630 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 20 Mar 2020 15:10:05 -0700 Subject: [PATCH 351/492] Fix crash in Model.fit() if a gradient is None, attempt 2. I first submitted this in 3931d39379b9feb44d4f8edba0906e96629d6884 but was rolled back since Nones were filtered out from the gradients, but not the variables. I now add Nones back to the gradients so they properly match up. 
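As a rough standalone sketch of the alignment problem described above (hypothetical values and names, not part of this patch): if None gradients are dropped from the gradient list while the variable list keeps its full length, every later gradient pairs with the wrong variable, so None placeholders have to be re-inserted before the two lists are zipped together.

# Hypothetical illustration only; not TensorFlow source code.
grads = [0.1, None, 0.3]          # one variable produced no gradient
variables = ['w0', 'w1', 'w2']

# Processing only the non-None gradients (a stand-in for the all-reduce step)...
reduced = [g * 2 for g in grads if g is not None]

# ...and pairing the shorter result directly with the variables misaligns them:
print(list(zip(reduced, variables)))   # [(0.2, 'w0'), (0.6, 'w1')] -- wrong pairing

# Re-inserting None placeholders keeps each gradient next to its own variable:
it = iter(reduced)
restored = [None if g is None else next(it) for g in grads]
print(list(zip(restored, variables)))  # [(0.2, 'w0'), (None, 'w1'), (0.6, 'w2')]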
PiperOrigin-RevId: 302107549 Change-Id: I81b7fb71c9cdaa458475d83f784366ce8405fb74 --- .../distribute/distribute_strategy_test.py | 27 +++++++++++++++++++ .../python/keras/engine/training_test.py | 24 +++++++++++++++++ .../python/keras/optimizer_v2/optimizer_v2.py | 27 ++++++++++++++----- 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 4ca3cf2b142..f5c3dc9bcfe 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -796,6 +796,33 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, atol=1e-4, rtol=1e-4) + @combinations.generate(all_strategy_combinations_plus_run_distributed()) + def test_gradients_are_none(self, distribution): + + if not context.executing_eagerly(): + self.skipTest('None gradients are not supported in graph mode') + + class DenseWithExtraWeight(keras.layers.Dense): + + def build(self, input_shape): + # Gradients w.r.t. extra_weights are None + self.extra_weight_1 = self.add_weight('extra_weight_1', shape=(), + initializer='ones') + super(DenseWithExtraWeight, self).build(input_shape) + self.extra_weight_2 = self.add_weight('extra_weight_2', shape=(), + initializer='ones') + + with distribution.scope(): + model = keras.Sequential([DenseWithExtraWeight(4, input_shape=(4,))]) + model.compile('adam', 'mse') + + inputs = np.random.normal(size=(64, 4)) + targets = np.random.normal(size=(64, 4)) + old_kernel = model.get_weights()[1] + model.fit(inputs, targets) + new_kernel = model.get_weights()[1] + self.assertNotAllEqual(old_kernel, new_kernel) + class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 35497721f6d..22125df6512 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1383,6 +1383,30 @@ class TrainingTest(keras_parameterized.TestCase): model.fit(x, y) self.assertEqual(model.optimizer.aggregate_gradients_called, True) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_gradients_are_none(self): + + class DenseWithExtraWeight(layers_module.Dense): + + def build(self, input_shape): + # Gradients w.r.t. 
extra_weights are None + self.extra_weight_1 = self.add_weight('extra_weight_1', shape=(), + initializer='ones') + super(DenseWithExtraWeight, self).build(input_shape) + self.extra_weight_2 = self.add_weight('extra_weight_2', shape=(), + initializer='ones') + + model = sequential.Sequential([DenseWithExtraWeight(4, input_shape=(4,))]) + # Test clipping can handle None gradients + opt = optimizer_v2.adam.Adam(clipnorm=1.0, clipvalue=1.0) + model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly()) + inputs = np.random.normal(size=(64, 4)) + targets = np.random.normal(size=(64, 4)) + old_kernel = model.get_weights()[1] + model.fit(inputs, targets) + new_kernel = model.get_weights()[1] + self.assertNotAllEqual(old_kernel, new_kernel) + class TestExceptionsAndWarnings(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index d9f090c0a60..3026816de8f 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -343,15 +343,16 @@ class OptimizerV2(trackable.Trackable): raise ValueError("Gradient clipping in the optimizer " "(by setting clipnorm or clipvalue) is currently " "unsupported when using a distribution strategy.") - grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads] + grads = [None if g is None else clip_ops.clip_by_norm(g, self.clipnorm) + for g in grads] if self.clipvalue is not None: if distribute_ctx.has_strategy(): raise ValueError("Gradient clipping in the optimizer " "(by setting clipnorm or clipvalue) is currently " "unsupported when using a distribution strategy.") + v = self.clipvalue grads = [ - clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue) - for g in grads + None if g is None else clip_ops.clip_by_value(g, -v, v) for g in grads ] return grads @@ -511,6 +512,7 @@ class OptimizerV2(trackable.Trackable): A list of all-reduced gradients. """ grads_and_vars = list(grads_and_vars) + filtered_grads_and_vars = _filter_grads(grads_and_vars) def all_reduce_fn(distribution, grads_and_vars): return distribution.extended.batch_reduce_to( ds_reduce_util.ReduceOp.SUM, grads_and_vars) @@ -519,9 +521,22 @@ class OptimizerV2(trackable.Trackable): # replica context. # TODO(b/150507409): Do not switch to a cross-replica context once the bug # is fixed. - if grads_and_vars: - return distribute_ctx.get_replica_context().merge_call( - all_reduce_fn, args=(grads_and_vars,)) + if filtered_grads_and_vars: + reduced = distribute_ctx.get_replica_context().merge_call( + all_reduce_fn, args=(filtered_grads_and_vars,)) + else: + reduced = [] + # Copy 'reduced' but add None gradients back in + reduced_with_nones = [] + reduced_pos = 0 + for g, _ in grads_and_vars: + if g is None: + reduced_with_nones.append(None) + else: + reduced_with_nones.append(reduced[reduced_pos]) + reduced_pos += 1 + assert reduced_pos == len(reduced), "Failed to add all gradients" + return reduced_with_nones def _distributed_apply(self, distribution, grads_and_vars, name, apply_state): """`apply_gradients` using a `DistributionStrategy`.""" From aecf5e0104c3519359cf4cbc932bf9dc2fec1e0a Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 20 Mar 2020 15:24:59 -0700 Subject: [PATCH 352/492] Fix build on platforms that don't support alwayslink=1 for header-only targets PiperOrigin-RevId: 302110502 Change-Id: I7ddaefad33f301d3c1e6ab64a95d39d7d521cfd7 --- tensorflow/lite/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 70b1566600d..f0e110cfaff 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -291,7 +291,6 @@ cc_library( "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/schema:schema_fbs", ] + tflite_experimental_runtime_linkopts(), - alwayslink = 1, ) cc_library( From 40882b8f893da73fc4c77b217c79dfad740edc73 Mon Sep 17 00:00:00 2001 From: "Ahmed S. Taei" Date: Fri, 20 Mar 2020 15:25:40 -0700 Subject: [PATCH 353/492] Lower xla_hlo.cos -> linalg.generic PiperOrigin-RevId: 302110626 Change-Id: I999a264fe9e6c50b1f26dacd6b3dd36f6506d319 --- .../mlir/xla/tests/hlo-legalize-to-linalg.mlir | 10 ++++++++++ .../mlir/xla/transforms/xla_legalize_to_linalg.cc | 1 + 2 files changed, 11 insertions(+) diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index 0f7b7369035..e1e11d4d37d 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -212,6 +212,16 @@ func @int_cmp(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @float_cos +func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: cos + %0 = "xla_hlo.cos"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @copy // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index b09a0159bcb..5ef3b445db4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -582,6 +582,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, From 3947c77855f8d657cdcdcae985a2fd7670dc1e3e Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Fri, 20 Mar 2020 15:27:39 -0700 Subject: [PATCH 354/492] Expose RewriteLayoutWithShardedShape from XlaCompiler. This call can be reused when determining argument layouts with sharding. PiperOrigin-RevId: 302111008 Change-Id: I3607e41dc987e348e8405b96f09ebc549a8427bc --- tensorflow/compiler/tf2xla/xla_compiler.cc | 80 +++++++++++----------- tensorflow/compiler/tf2xla/xla_compiler.h | 6 ++ 2 files changed, 46 insertions(+), 40 deletions(-) diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index c30b1c0e17d..9b17ebe0260 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -138,46 +138,6 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, return Status::OK(); } -// Rewrites the layout of xla_shape if there is tiled sharding. 
-Status RewriteLayoutWithShardedShape( - const absl::optional& sharding, bool use_fast_memory, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_shape) { - if (sharding && !sharding->IsTileMaximal()) { - // After sharding, per core shape might have different layout. For example, - // before sharding, a shape [128, 128] will be assigned default - // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2, - // the sharded shapes will have minor-to-major {0, 1}. - // - // As a result, for sharded shapes, we set their layout to per core shape's - // layout. - // - // TODO(endlessroad): for variable input & update, we might have - // different layouts which will prevent input output aliasing and - // increase memory usage. Investigate such cases. - int64 device = *sharding->tile_assignment().begin(); - std::vector offset = - sharding->TileOffsetForDevice(*xla_shape, device); - std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); - std::vector dimensions(xla_shape->rank()); - for (int64 i = 0; i < xla_shape->rank(); ++i) { - dimensions[i] = limit[i] - offset[i]; - } - xla::Shape per_device_xla_shape = - xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); - TensorShape per_device_tensor_shape; - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - xla_shape->element_type())); - TF_ASSIGN_OR_RETURN(per_device_xla_shape, - shape_representation_fn(per_device_tensor_shape, dtype, - use_fast_memory)); - *xla_shape->mutable_layout() = per_device_xla_shape.layout(); - } - return Status::OK(); -} - // There is a shape_representation_fn or sharding for an output, this function // uses a reshape to fix the layout. xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( @@ -1542,4 +1502,44 @@ xla::StatusOr XlaCompiler::GetNodeToken(const string& node_name) { return iter->second; } +// Rewrites the layout of xla_shape if there is tiled sharding. +Status RewriteLayoutWithShardedShape( + const absl::optional& sharding, bool use_fast_memory, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_shape) { + if (sharding && !sharding->IsTileMaximal()) { + // After sharding, per core shape might have different layout. For example, + // before sharding, a shape [128, 128] will be assigned default + // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2, + // the sharded shapes will have minor-to-major {0, 1}. + // + // As a result, for sharded shapes, we set their layout to per core shape's + // layout. + // + // TODO(endlessroad): for variable input & update, we might have + // different layouts which will prevent input output aliasing and + // increase memory usage. Investigate such cases. 
+ int64 device = *sharding->tile_assignment().begin(); + std::vector offset = + sharding->TileOffsetForDevice(*xla_shape, device); + std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); + std::vector dimensions(xla_shape->rank()); + for (int64 i = 0; i < xla_shape->rank(); ++i) { + dimensions[i] = limit[i] - offset[i]; + } + xla::Shape per_device_xla_shape = + xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); + TensorShape per_device_tensor_shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + xla_shape->element_type())); + TF_ASSIGN_OR_RETURN(per_device_xla_shape, + shape_representation_fn(per_device_tensor_shape, dtype, + use_fast_memory)); + *xla_shape->mutable_layout() = per_device_xla_shape.layout(); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 6a56136a9f6..d67b1f26696 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -518,6 +518,12 @@ class XlaCompiler { TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; +// Rewrites the layout of xla_shape if there is tiled sharding. +Status RewriteLayoutWithShardedShape( + const absl::optional& sharding, bool use_fast_memory, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ From c5927e4a69e7a9ad47585ef07e2295f648fcb89b Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Fri, 20 Mar 2020 15:45:40 -0700 Subject: [PATCH 355/492] Update replicate invariant op hoisting to handle ops with regions. tf._TPUCompileMlir in TPU rewrite is now being wrapped by a tf_device.launch for device assignment, but it is replicate invariant, so it should be hoisted out. 
PiperOrigin-RevId: 302114475 Change-Id: I8b08d7e95ea8dafe11a4e2f0b44ea39ddf936aaa --- .../replicate_invariant_op_hoisting.mlir | 39 +++++++++++++++++++ .../replicate_invariant_op_hoisting.cc | 20 ++++++---- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir index e582ed49cd3..4e3564fb6a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir @@ -85,3 +85,42 @@ func @dependent_invariants(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[SHAPE]], %[[OP_A]]) // CHECK: tf_device.replicate // CHECK: tf_device.return %[[SHAPE]], %[[OP_A]], %[[OP_B]] + + +// CHECK-LABEL: func @nested_ops +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*xf32>, %{{[a-z0-9]*}}: tensor<*xf32>) +func @nested_ops(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { + %0:8 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<*xf32>) {n = 2: i32} { + %1 = "tf.Shape"(%ri) {device = "", T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor + %2 = "tf_device.launch"() ( { + %a = "tf.opA"(%1) : (tensor) -> tensor<*xi32> + tf_device.return %a : tensor<*xi32> + }) {device = "a"} : () -> tensor<*xi32> + %3 = "tf_device.launch"() ( { + %b = "tf.opB"(%1, %2) : (tensor, tensor<*xi32>) -> tensor<*xf32> + tf_device.return %b : tensor<*xf32> + }) {device = "b"} : () -> tensor<*xf32> + %4 = "tf_device.launch"() ( { + %c = "tf.opC"(%ri, %3) : (tensor<*xf32>, tensor<*xf32>) -> tensor + tf_device.return %c : tensor + }) {device = "c"} : () -> tensor + tf_device.return %1, %2, %3, %4 : tensor, tensor<*xi32>, tensor<*xf32>, tensor + } + return +} + +// CHECK: %[[SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_0]]) +// CHECK-NEXT: %[[LAUNCH_A:[0-9]*]] = "tf_device.launch" +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[SHAPE]]) +// CHECK-NEXT: tf_device.return %[[OP_A]] +// CHECK-NEXT: device = "a" +// CHECK-NEXT: %[[LAUNCH_B:[0-9]*]] = "tf_device.launch" +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[SHAPE]], %[[LAUNCH_A]]) +// CHECK-NEXT: tf_device.return %[[OP_B]] +// CHECK-NEXT: device = "b" +// CHECK-NEXT: tf_device.replicate([{{.*}}] as %[[RI:[a-z0-9]+]]: tensor<*xf32>) +// CHECK-NEXT: %[[LAUNCH_C:[0-9]*]] = "tf_device.launch" +// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[RI]], %[[LAUNCH_B]]) +// CHECK-NEXT: tf_device.return %[[OP_C]] +// CHECK-NEXT: device = "c" +// CHECK-NEXT: tf_device.return %[[SHAPE]], %[[LAUNCH_A]], %[[LAUNCH_B]], %[[LAUNCH_C]] diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 7b4ae38726d..03e0b99a6ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "mlir/IR/Visitors.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -110,17 +111,20 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, // Checks if op and inner op operands are all replicate invariant. bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) { - auto result = op->walk([&](Operation* inner_op) { - for (Value operand : inner_op->getOperands()) { - Region* parent_region = operand.getParentRegion(); - if (!parent_region || !parent_region->isProperAncestor(replicate_region)) - return WalkResult::interrupt(); - } + auto ancestor_of_replicate = [&](Region* region) { + return region && region->isProperAncestor(replicate_region); + }; - return WalkResult::advance(); + for (Value operand : op->getOperands()) + if (!ancestor_of_replicate(operand.getParentRegion())) return false; + + bool has_replicate_operands = false; + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) { + if (!ancestor_of_replicate(operand->get().getParentRegion())) + has_replicate_operands = true; }); - return !result.wasInterrupted(); + return !has_replicate_operands; } // Hoists replicate invariant ops out of associated `tf_device.replicate` op. From 158719872aebb1c6607ebaa244f5d777b4b2e866 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 15:46:09 -0700 Subject: [PATCH 356/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302114563 Change-Id: I6ffc2792c08c236aa57cc447007d4d7dfa9731ce --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 75d86f71b78..68bb1dc49f5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. 
-// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From a5e787c1718678a6cff4af25b9683269b84deb91 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Fri, 20 Mar 2020 16:19:34 -0700 Subject: [PATCH 357/492] Fix typos: should be ChannelDimIndexInterface PiperOrigin-RevId: 302120701 Change-Id: I6c9290b8a379d572440a03729e0ba2223cc036fd --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 53bec976186..c90fdfbfe1c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -693,7 +693,7 @@ def TFL_ExternalConstOp : Op { def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> { let extraClassDeclaration = [{ - // StatefulOpInterface: + // ChannelDimIndexInterface: int GetChannelDimIndex() { return 0; } }]; } @@ -718,7 +718,7 @@ def TFL_DepthwiseConv2DOp : let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier)); let extraClassDeclaration = [{ - // StatefulOpInterface: + // ChannelDimIndexInterface: int GetChannelDimIndex() { return 3; } }]; } From c7fb55cf532669b3ed40a867cb795bda2470fea3 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Fri, 20 Mar 2020 16:20:23 -0700 Subject: [PATCH 358/492] [tf.data] Remove unused forwarding library. To access the "dataset.h" header, `#include "tensorflow/core/framework/dataset.h"` and depend on `"//tensorflow/core:framework"`. PiperOrigin-RevId: 302120830 Change-Id: I430cf0af9996dacc420bb983a0d82355386f910d --- tensorflow/core/kernels/data/BUILD | 8 -------- tensorflow/core/kernels/data/dataset.h | 20 ------------------- .../core/kernels/data/experimental/BUILD | 2 -- 3 files changed, 30 deletions(-) delete mode 100644 tensorflow/core/kernels/data/dataset.h diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 5f0e2343203..823a800e7bb 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -12,14 +12,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -# TODO(mrry): Remove this empty forwarding library. -cc_library( - name = "dataset", - srcs = [], - hdrs = ["dataset.h"], - deps = ["//tensorflow/core:framework"], -) - cc_library( name = "dataset_test_base", testonly = 1, diff --git a/tensorflow/core/kernels/data/dataset.h b/tensorflow/core/kernels/data/dataset.h deleted file mode 100644 index 2c6fc8d5b4f..00000000000 --- a/tensorflow/core/kernels/data/dataset.h +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ -#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ - -#include "tensorflow/core/framework/dataset.h" - -#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 298982eb356..baf804fa1f8 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -106,7 +106,6 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/kernels/data:dataset", ], ) @@ -239,7 +238,6 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/kernels/data:dataset", ], ) From 3caa238d42bfb4adb5f132d47f28b653f442b8f4 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Fri, 20 Mar 2020 16:25:06 -0700 Subject: [PATCH 359/492] Fix mirrored_strategy_test failure after contant tensors are always on CPU We use constant tensors a lot in our tests, and NCCL will complain if inputs are on the same device. PiperOrigin-RevId: 302121666 Change-Id: I6872ac2d63fbdfdac253f6e7c8b8602a8cd2fe7e --- tensorflow/python/distribute/BUILD | 2 +- .../python/distribute/distribute_lib.py | 15 ++++++++------- .../distribute/mirrored_strategy_test.py | 9 +++------ .../python/distribute/strategy_test_lib.py | 19 +++++++++---------- 4 files changed, 21 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 8b5308c4d52..72475136e8b 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1206,7 +1206,7 @@ cuda_py_test( srcs = ["mirrored_strategy_test.py"], shard_count = 5, tags = [ - # "multi_and_single_gpu", # b/151862653 + "multi_and_single_gpu", "no_windows_gpu", # TODO(b/130551176) ], deps = [ diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 9b819987899..c7ae3f22add 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -1035,15 +1035,16 @@ class StrategyBase(object): if dim is not None: # By returning a python value in the static shape case, we can # maybe get a fast path for reducing the denominator. - return numer, array_ops.constant(dim, dtype=dtypes.int64) + # TODO(b/151871486): Remove array_ops.identity after we fallback to + # simple reduction if inputs are all on CPU. + return numer, array_ops.identity( + constant_op.constant(dim, dtype=dtypes.int64)) elif axis < 0: axis = axis + array_ops.rank(v) - if v.shape.rank == 1: - # TODO(b/139422050): Currently tf.shape is not supported in TPU dynamic - # padder, use tf.size instead to workaround if the rank is 1. - denom = array_ops.size(v, out_type=dtypes.int64) - else: - denom = array_ops.shape_v2(v, out_type=dtypes.int64)[axis] + # TODO(b/151871486): Remove array_ops.identity after we fallback to simple + # reduction if inputs are all on CPU. + denom = array_ops.identity( + array_ops.shape_v2(v, out_type=dtypes.int64)[axis]) # TODO(josh11b): Should we cast denom to v.dtype here instead of after the # reduce is complete? 
return numer, denom diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 73a0f34c6bd..f06360ca021 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -127,7 +127,8 @@ class MirroredTwoDeviceDistributionTest( def replica_squared_fn(dtype=dtype): # Lists with different lengths on different replicas. replica_id = _replica_id_as_int() - return math_ops.cast([replica_id] * (replica_id + 1), dtype) + return array_ops.identity( + math_ops.cast([replica_id] * (replica_id + 1), dtype)) self.reduce_axis_helper(distribution, replica_squared_fn) @@ -1406,11 +1407,7 @@ def _replica_id(): replica_id = ds_context.get_replica_context().replica_id_in_sync_group if not isinstance(replica_id, ops.Tensor): replica_id = constant_op.constant(replica_id) - # TODO(b/149852830): Workaround for small Tensor caching (which is only on - # CPU) to ensure the value is on the correct device. - replica_id = math_ops.cast(replica_id, dtypes.float32) - replica_id = math_ops.cast(replica_id, dtypes.int32) - return replica_id + return array_ops.identity(replica_id) def _replica_id_as_int(): diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py index 148fda8008c..b3ececcdcba 100644 --- a/tensorflow/python/distribute/strategy_test_lib.py +++ b/tensorflow/python/distribute/strategy_test_lib.py @@ -34,7 +34,6 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -120,7 +119,7 @@ class DistributionTestBase(test.TestCase): l = core.Dense(1, use_bias=False) def loss(x): - y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + y = array_ops.reshape(l(x), []) - array_ops.identity(1.) return y * y # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a # common `implicit_grad` function and put it in DistributionStrategy. @@ -130,7 +129,7 @@ class DistributionTestBase(test.TestCase): def update(v, g): return v.assign_sub(0.2 * g) - one = constant_op.constant([[1.]]) + one = array_ops.identity([[1.]]) def step(): """Perform one optimization step.""" @@ -177,7 +176,7 @@ class DistributionTestBase(test.TestCase): l = core.Dense(1, use_bias=False) def loss(x): - y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + y = array_ops.reshape(l(x), []) - array_ops.identity(1.) 
return y * y grad_fn = backprop.implicit_grad(loss) @@ -185,7 +184,7 @@ class DistributionTestBase(test.TestCase): def update(v, g): return v.assign_sub(learning_rate * g) - one = constant_op.constant([[1.]]) + one = array_ops.identity([[1.]]) def step(): """Perform one optimization step.""" @@ -453,7 +452,7 @@ class OneDeviceDistributionTestBase(test.TestCase): """Some tests that should work with any one-device DistributionStrategy.""" def _test_run(self, strategy): - out1 = strategy.run(lambda: constant_op.constant(4.)) + out1 = strategy.run(lambda: array_ops.identity(4.)) self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1))) out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,)) @@ -506,7 +505,7 @@ class OneDeviceDistributionTestBase(test.TestCase): self.skipTest("`tf.gradients` is not supported with eager execution.") def step(c): - x = constant_op.constant(42.) + x = array_ops.identity(42.) y = comm_fn(x) * c return gradients_impl.gradients(y, [x])[0] @@ -524,7 +523,7 @@ class OneDeviceDistributionTestBase(test.TestCase): expected_grads): def step(c): - x = constant_op.constant(42.) + x = array_ops.identity(42.) with backprop.GradientTape() as tape: tape.watch(x) y = comm_fn(x) * c @@ -634,7 +633,7 @@ class TwoDeviceDistributionTestBase(test.TestCase): self.skipTest("`tf.gradients` is not supported with eager execution.") def step(c): - x = constant_op.constant(42.) + x = array_ops.identity(42.) y = comm_fn(x) * c return gradients_impl.gradients(y, [x])[0] @@ -652,7 +651,7 @@ class TwoDeviceDistributionTestBase(test.TestCase): expected_grads): def step(c): - x = constant_op.constant(42.) + x = array_ops.identity(42.) with backprop.GradientTape() as tape: tape.watch(x) y = comm_fn(x) * c From c984ec0b3605c4914546efd81ca0ac0729faa992 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Fri, 20 Mar 2020 16:46:25 -0700 Subject: [PATCH 360/492] Re-organize values_test This change separated common parts into DistributedVariableTest, and AggregateVaraible tests into its own ones as well. 
PiperOrigin-RevId: 302125354 Change-Id: I1cfba4d5956a70b7b743913eea4d0301c4c8d1ce --- tensorflow/python/distribute/BUILD | 2 + tensorflow/python/distribute/values_test.py | 280 +++++++++++--------- 2 files changed, 159 insertions(+), 123 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 72475136e8b..8f6231b7655 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -935,8 +935,10 @@ py_library( distribute_py_test( name = "values_test", + size = "medium", srcs = ["values_test.py"], main = "values_test.py", + shard_count = 5, tags = [ # "multi_and_single_gpu", # b/151865826 ], diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 290ea7d011a..0c7b3dffd2b 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -537,6 +537,126 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) +@combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + strategy_combinations.central_storage_strategy_with_two_gpus, + ], + synchronization=[ + variables_lib.VariableSynchronization.ON_READ, + variables_lib.VariableSynchronization.ON_WRITE, + ], + aggregation=[ + variables_lib.VariableAggregation.MEAN, + variables_lib.VariableAggregation.SUM, + variables_lib.VariableAggregation.ONLY_FIRST_REPLICA, + ], + mode=["graph", "eager"])) +class DistributedVariableTest(test.TestCase, parameterized.TestCase): + + def testExtendsVariable(self, distribution, synchronization, aggregation): + with distribution.scope(): + v = variables_lib.Variable( + 1., synchronization=synchronization, aggregation=aggregation) + self.assertIsInstance(v, variables_lib.Variable) + + def testCheckpointing(self, distribution, synchronization, aggregation): + with distribution.scope(): + v = variables_lib.Variable( + constant_op.constant([1., 2., 3., 4]), + synchronization=synchronization, + aggregation=aggregation) + + self.evaluate(v.initializer) + before_save = self.evaluate(v.read_value()) + + # Save random weights into checkpoint. + checkpoint = trackable_utils.Checkpoint(v=v) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + with self.test_session(): + save_path = checkpoint.save(prefix) + + # Assign inverted value. + self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.]))) + after_assign = self.evaluate(v.read_value()) + self.assertNotAllClose(before_save, after_assign) + + # Restore from the checkpoint. 
+ with self.test_session(): + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + after_restore = self.evaluate(v) + self.assertAllClose(before_save, after_restore) + + def testTraceback(self, distribution, synchronization, aggregation): + if context.executing_eagerly(): + self.skipTest("does not apply to eager") + with distribution.scope(): + variable_scope.get_variable( + name="testVar", + initializer=1., + use_resource=True, + synchronization=synchronization, + aggregation=aggregation) + with self.assertRaisesRegex(ValueError, + "Variable testVar already exists"): + variable_scope.get_variable( + name="testVar", + initializer=1., + use_resource=True, + synchronization=synchronization, + aggregation=aggregation) + + def testSelectReplica(self, distribution, synchronization, aggregation): + with distribution.scope(): + v = variables_lib.Variable( + 1., synchronization=synchronization, aggregation=aggregation) + self.assertIs(v, values.select_replica(0, v)) + + def testIsTensorLike(self, distribution, synchronization, aggregation): + if isinstance(distribution.extended, + tpu_strategy.TPUExtended) and context.executing_eagerly(): + self.skipTest("TPU doesn't support pure eager") + + with distribution.scope(): + v = variables_lib.Variable( + 0., synchronization=synchronization, aggregation=aggregation) + # In cross replica context. + self.assertTrue(ops.is_dense_tensor_like(v)) + # In replica context. + distribution.run( + lambda v: self.assertTrue(ops.is_dense_tensor_like(v)), args=(v,)) + + def testAssignReturnValueIsTensorLike(self, distribution, synchronization, + aggregation): + if isinstance(distribution.extended, tpu_strategy.TPUExtended): + if context.executing_eagerly(): + self.skipTest("TPU doesn't support pure eager") + else: + self.skipTest("b/152076846") + + with distribution.scope(): + v = variables_lib.Variable( + 0., synchronization=synchronization, aggregation=aggregation) + + def assert_is_tensor_like(v): + # We can't use Python literals because they are treated as non-distributed + # values is not allowed when aggregation is SUM. See + # `cross_device_ops.reduce_non_distributed_value`. + delta = array_ops.identity(1.) + self.assertTrue(ops.is_dense_tensor_like(v.assign(delta))) + self.assertTrue(ops.is_dense_tensor_like(v.assign_sub(delta))) + self.assertTrue(ops.is_dense_tensor_like(v.assign_add(delta))) + + # In cross replica context we return a PerReplica which is not Tensor like + # yet. + + # In replica context. + distribution.run(assert_is_tensor_like, args=(v,)) + + class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() @@ -752,7 +872,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): strategy_combinations.tpu_strategy, ], mode=["graph", "eager"])) - def testAssignOutOfScope_mirrored(self, distribution): + def testAssignOutOfScope(self, distribution): with distribution.scope(): mirrored = variables_lib.Variable(1.) self.evaluate(mirrored.assign(3.)) @@ -760,20 +880,6 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): for component in mirrored.values: self.assertEqual(self.evaluate(component.read_value()), 3.) - @combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.central_storage_strategy_with_two_gpus - ], - mode=["graph", "eager"])) - def testAssignOutOfScope_aggregating(self, distribution): - with distribution.scope(): - aggregating = variables_lib.Variable(1.) 
- self.assertIsInstance(aggregating, values.AggregatingVariable) - self.evaluate(aggregating.assign(3.)) - self.assertEqual(self.evaluate(aggregating.read_value()), 3.) - self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.) - @combinations.generate( combinations.combine( distribution=[ @@ -835,80 +941,16 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.combine( distribution=[ - strategy_combinations.mirrored_strategy_with_one_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, - ], - mode=["graph", "eager"])) - def testExtendsVariable(self, distribution): - with distribution.scope(): - v = variables_lib.Variable(1.) - self.assertIsInstance(v, variables_lib.Variable) - - @combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_one_cpu, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, - ], - mode=["graph", "eager"])) - def testCheckpointing(self, distribution): - with distribution.scope(): - v = variables_lib.Variable(constant_op.constant([1., 2., 3., 4])) - - self.evaluate(v.initializer) - before_save = self.evaluate(v.read_value()) - - # Save random weights into checkpoint. - checkpoint = trackable_utils.Checkpoint(v=v) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - with self.test_session(): - save_path = checkpoint.save(prefix) - - # Assign inverted value. - self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.]))) - after_assign = self.evaluate(v.read_value()) - self.assertNotAllClose(before_save, after_assign) - - # Restore from the checkpoint. - with self.test_session(): - checkpoint.restore(save_path).assert_consumed().run_restore_ops() - after_restore = self.evaluate(v) - self.assertAllClose(before_save, after_restore) - - @combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_one_cpu, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, - ], - mode=["graph"])) - def testTraceback(self, distribution): - with distribution.scope(): - variable_scope.get_variable( - name="testVar", initializer=1., use_resource=True) - with self.assertRaisesRegex( - ValueError, "Variable testVar already exists"): - variable_scope.get_variable( - name="testVar", initializer=1., use_resource=True) - - @combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, ], mode=["eager"])) def testInitializedToSameValueInsideEagerRun(self, distribution): v = [None] + @def_function.function def step(): + def f(): if v[0] is None: v[0] = variables_lib.Variable(random_ops.random_normal([])) @@ -926,45 +968,6 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): strategy_combinations.mirrored_strategy_with_one_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, - ], - mode=["graph", "eager"])) - def testSelectReplica(self, distribution): - with distribution.scope(): - v = variables_lib.Variable(1.) 
- self.assertIs(v, values.select_replica(0, v)) - - @combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_one_cpu, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, - ], - mode=["graph", "eager"])) - def testModAfterAssign(self, distribution): - with distribution.scope(): - v = variables_lib.Variable(0) - def replica_fn(): - def merge_fn(_): - return math_ops.mod(v.assign_add(1), 2) - return distribution_strategy_context.get_replica_context().merge_call( - merge_fn) - - @def_function.function - def foo(): - distribution.run(replica_fn) - - foo() - - @combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_one_cpu, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, ], mode=["graph", "eager"])) def testAggregationOnlyFirstReplica(self, distribution): @@ -992,7 +995,6 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, ], mode=["graph", "eager"])) def testAssignAdd(self, distribution): @@ -1006,8 +1008,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): return v.assign_add(2) per_replica_results = self.evaluate( - distribution.experimental_local_results( - distribution.run(assign))) + distribution.experimental_local_results(distribution.run(assign))) # The per-replica values should always match the first replicas value. self.assertAllEqual([3, 3], per_replica_results) @@ -1791,6 +1792,39 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(vals[0], vals[1]) +@combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.central_storage_strategy_with_two_gpus + ], + mode=["graph", "eager"])) +class AggregatingVariableTest(test.TestCase, parameterized.TestCase): + + def testAssignOutOfScope(self, distribution): + with distribution.scope(): + aggregating = variables_lib.Variable(1.) + self.assertIsInstance(aggregating, values.AggregatingVariable) + self.evaluate(aggregating.assign(3.)) + self.assertEqual(self.evaluate(aggregating.read_value()), 3.) + self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.) + + def testAssignAdd(self, distribution): + self.skipTest("b/151250566") + with distribution.scope(): + v = variable_scope.variable( + 1, aggregation=variables_lib.VariableAggregation.MEAN) + self.evaluate(variables_lib.global_variables_initializer()) + + @def_function.function + def assign(): + return v.assign_add(2) + + per_replica_results = self.evaluate( + distribution.experimental_local_results( + distribution.experimental_run_v2(assign))) + self.assertAllEqual([3], per_replica_results) + + class MirroredTest(test.TestCase): def testAddOp(self): From 7de80adca641ff8f64bdc5c8a7dbd1f6c51fea1c Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Fri, 20 Mar 2020 17:03:05 -0700 Subject: [PATCH 361/492] [tf.data] Fix "passing result of std::move() as const reference argument" warning. 
To avoid an implicit string-to-tstring conversion that was causing the warning, this change switches to using tstring as the string type for reading a line from a text file, and adds the necessary tstring overload in `BufferedInputStream`. PiperOrigin-RevId: 302128176 Change-Id: I55592893c61a66922285896f6c1a42a6ca8e4785 --- .../core/kernels/data/text_line_dataset_op.cc | 11 +++++----- .../core/lib/io/buffered_inputstream.cc | 20 ++++++++++++------- tensorflow/core/lib/io/buffered_inputstream.h | 4 +++- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc index dc193f53a8d..550a859093d 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc @@ -100,18 +100,17 @@ class TextLineDatasetOp::Dataset : public DatasetBase { do { // We are currently processing a file, so try to read the next line. if (buffered_input_stream_) { - string line_contents; - Status s = buffered_input_stream_->ReadLine(&line_contents); + Tensor line_contents(tstring{}); + tstring& line_contents_str = line_contents.scalar()(); + Status s = buffered_input_stream_->ReadLine(&line_contents_str); if (s.ok()) { // Produce the line as output. static monitoring::CounterCell* bytes_counter = metrics::GetTFDataBytesReadCounter( name_utils::OpName(TextLineDatasetOp::kDatasetType)); - bytes_counter->IncrementBy(line_contents.size()); - out_tensors->emplace_back(ctx->allocator({}), DT_STRING, - TensorShape({})); - out_tensors->back().scalar()() = line_contents; + bytes_counter->IncrementBy(line_contents_str.size()); + out_tensors->push_back(std::move(line_contents)); *end_of_sequence = false; return Status::OK(); } else if (!errors::IsOutOfRange(s)) { diff --git a/tensorflow/core/lib/io/buffered_inputstream.cc b/tensorflow/core/lib/io/buffered_inputstream.cc index 5e3e8bfed71..94479a1149f 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.cc +++ b/tensorflow/core/lib/io/buffered_inputstream.cc @@ -21,17 +21,17 @@ namespace tensorflow { namespace io { BufferedInputStream::BufferedInputStream(InputStreamInterface* input_stream, - size_t buffer_size, + size_t buffer_bytes, bool owns_input_stream) : input_stream_(input_stream), - size_(buffer_size), + size_(buffer_bytes), owns_input_stream_(owns_input_stream) { buf_.reserve(size_); } BufferedInputStream::BufferedInputStream(RandomAccessFile* file, - size_t buffer_size) - : BufferedInputStream(new RandomAccessInputStream(file), buffer_size, + size_t buffer_bytes) + : BufferedInputStream(new RandomAccessInputStream(file), buffer_bytes, true) {} BufferedInputStream::~BufferedInputStream() { @@ -56,7 +56,9 @@ Status BufferedInputStream::FillBuffer() { return s; } -Status BufferedInputStream::ReadLineHelper(string* result, bool include_eol) { +template +Status BufferedInputStream::ReadLineHelper(StringType* result, + bool include_eol) { result->clear(); Status s; while (true) { @@ -70,13 +72,13 @@ Status BufferedInputStream::ReadLineHelper(string* result, bool include_eol) { char c = buf_[pos_++]; if (c == '\n') { if (include_eol) { - *result += c; + result->append(1, c); } return Status::OK(); } // We don't append '\r' to *result if (c != '\r') { - *result += c; + result->append(1, c); } } if (errors::IsOutOfRange(s) && !result->empty()) { @@ -202,6 +204,10 @@ Status BufferedInputStream::ReadLine(string* result) { return ReadLineHelper(result, false); } +Status BufferedInputStream::ReadLine(tstring* 
result) { + return ReadLineHelper(result, false); +} + string BufferedInputStream::ReadLineAsString() { string result; ReadLineHelper(&result, true).IgnoreError(); diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h index a247bb41675..fde3088f824 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.h +++ b/tensorflow/core/lib/io/buffered_inputstream.h @@ -66,6 +66,7 @@ class BufferedInputStream : public InputStreamInterface { // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. tensorflow::Status ReadLine(string* result); + tensorflow::Status ReadLine(tstring* result); // Returns one text line of data until end-of-file or a '\n' is read. The '\n' // is included in the result. @@ -86,7 +87,8 @@ class BufferedInputStream : public InputStreamInterface { private: tensorflow::Status FillBuffer(); - tensorflow::Status ReadLineHelper(string* result, bool include_eol); + template + tensorflow::Status ReadLineHelper(StringType* result, bool include_eol); InputStreamInterface* input_stream_; // not owned. size_t size_; // buffer size. From d6af13cafff28fc238898a08d6795b7b506b4a18 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Mar 2020 17:16:45 -0700 Subject: [PATCH 362/492] Update Eigen to: https://gitlab.com/libeigen/eigen/-/commit/4da2c6b1974827b1999bab652a3d4703e1992d26 PiperOrigin-RevId: 302130422 Change-Id: I96fc471914c445c0a9e065a8b2378bedb3f60e2a --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 1066479823a..6f1feead83e 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -201,11 +201,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "ce221392c106e90fa28a2ffccf6e45869477b40e17a0b0728334e5e1970de294", # SHARED_EIGEN_SHA - strip_prefix = "eigen-7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83", + sha256 = "eb50646c27d32791d6b09b0422f29f52b8ff0385354abd117f68aa66da1e2e92", # SHARED_EIGEN_SHA + strip_prefix = "eigen-4da2c6b1974827b1999bab652a3d4703e1992d26", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83/eigen-7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83.tar.gz", - "https://gitlab.com/libeigen/eigen/-/archive/7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83/eigen-7158ed4e0e34d40cd0f358a3bf69a5c30d8d0f83.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/4da2c6b1974827b1999bab652a3d4703e1992d26/eigen-4da2c6b1974827b1999bab652a3d4703e1992d26.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/4da2c6b1974827b1999bab652a3d4703e1992d26/eigen-4da2c6b1974827b1999bab652a3d4703e1992d26.tar.gz", ], ) From ca16bf312b95d5768ddee67734af8d5ad99f94b3 Mon Sep 17 00:00:00 2001 From: Robert David Date: Fri, 20 Mar 2020 17:18:49 -0700 Subject: [PATCH 363/492] Small Softmax cleanups: - Remove OpData. Use SoftmaxParams directly. - Only call CalculateSoftmaxOpData for quantized case, rename to CalculateQuantizedSoftmaxParams. - Add stricter type checks to CalculateQuantizedSoftmaxParams. 
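Taken together, the last two points give the quantized path of Eval the
following shape (a simplified sketch of the new calling pattern; the complete
change is in the diff below):

    case kTfLiteInt8:
    case kTfLiteUInt8: {
      // SoftmaxParams is computed only for quantized inputs, with the
      // input/output types validated inside the helper.
      SoftmaxParams op_params;
      TF_LITE_ENSURE_STATUS(CalculateQuantizedSoftmaxParams(
          context, input, output, params, &op_params));
      Softmax2DQuantized(input, output, params, op_params);
      return kTfLiteOk;
    }
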
PiperOrigin-RevId: 302130753 Change-Id: Icb5e9cfa28e9179d4c91f67325f9721edbd8eb9b --- tensorflow/lite/micro/kernels/softmax.cc | 95 ++++++++++-------------- 1 file changed, 40 insertions(+), 55 deletions(-) diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc index 85952de9d50..1f30ddc5949 100644 --- a/tensorflow/lite/micro/kernels/softmax.cc +++ b/tensorflow/lite/micro/kernels/softmax.cc @@ -29,41 +29,37 @@ namespace micro { namespace activations { namespace { -struct OpData { - int32_t input_multiplier = 0; - int input_left_shift = 0; - int32_t input_range_radius = 0; - int diff_min = 0; -}; - -TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, - const TfLiteTensor* input, - TfLiteTensor* output, - const TfLiteSoftmaxParams* params, - OpData* data) { - if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { - if (input->type == kTfLiteUInt8) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); +TfLiteStatus CalculateQuantizedSoftmaxParams(TfLiteContext* context, + const TfLiteTensor* input, + TfLiteTensor* output, + const TfLiteSoftmaxParams* params, + SoftmaxParams* data) { + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8); + if (output->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768); + // NOTE: Current int16 softmax output does not require symmetric scaling + // - so no need to verify scale here. } else { - if (output->type == kTfLiteInt16) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768); - // NOTE: Current int16 softmax output does not require symmetric scaling - // - so no need to verify scale here. - } else { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); - TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); - } + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); + TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); } - - static const int kScaledDiffIntegerBits = 5; - - tflite::PreprocessSoftmaxScaling( - static_cast(params->beta), - static_cast(input->params.scale), kScaledDiffIntegerBits, - &data->input_multiplier, &data->input_left_shift); - data->diff_min = -1.0 * tflite::CalculateInputRadius( - kScaledDiffIntegerBits, data->input_left_shift); } + + static const int kScaledDiffIntegerBits = 5; + + int input_left_shift; + tflite::PreprocessSoftmaxScaling(static_cast(params->beta), + static_cast(input->params.scale), + kScaledDiffIntegerBits, + &data->input_multiplier, &input_left_shift); + data->input_left_shift = input_left_shift; + data->diff_min = -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, + data->input_left_shift); return kTfLiteOk; } @@ -97,7 +93,8 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, } void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { + TfLiteSoftmaxParams* params, + const SoftmaxParams& op_params) { // TODO(ahentz): this is arguably a dirty trick. Since the implementation // always traverses the last dimension of a 4D tensor, we will pretend our 1D // tensor is 4D in a special way. 
We will convert a (Y) shape into a (1, @@ -105,10 +102,6 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, const int input_size = input->dims->data[0]; const int32_t shape_data[4] = {1, 1, 1, input_size}; RuntimeShape shape(4, shape_data); - SoftmaxParams op_params; - op_params.input_multiplier = data->input_multiplier; - op_params.input_left_shift = data->input_left_shift; - op_params.diff_min = data->diff_min; if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax(op_params, shape, GetTensorData(input), shape, @@ -127,7 +120,8 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, } void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { + TfLiteSoftmaxParams* params, + const SoftmaxParams& op_params) { // TODO(ahentz): this is arguably a dirty trick. Since the implementation // always traverses the last dimension of a 4D tensor, we will pretend our 2D // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X, @@ -136,10 +130,6 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, const int input_size = input->dims->data[1]; const int32_t shape_data[4] = {batch_size, 1, 1, input_size}; RuntimeShape shape(4, shape_data); - SoftmaxParams op_params; - op_params.input_multiplier = data->input_multiplier; - op_params.input_left_shift = data->input_left_shift; - op_params.diff_min = data->diff_min; if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax(op_params, shape, GetTensorData(input), shape, @@ -168,11 +158,8 @@ void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, } void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { - SoftmaxParams op_params; - op_params.input_multiplier = data->input_multiplier; - op_params.input_left_shift = data->input_left_shift; - op_params.diff_min = data->diff_min; + TfLiteSoftmaxParams* params, + const SoftmaxParams& op_params) { if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax( op_params, GetTensorShape(input), GetTensorData(input), @@ -196,11 +183,6 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - OpData local_data_object; - OpData* data = &local_data_object; - TF_LITE_ENSURE_STATUS( - CalculateSoftmaxOpData(context, input, output, params, data)); - // TODO(ahentz): consider an implementation that works for many (all?) // dimensions. switch (input->type) { @@ -224,16 +206,19 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } case kTfLiteInt8: case kTfLiteUInt8: { + SoftmaxParams op_params; + TF_LITE_ENSURE_STATUS(CalculateQuantizedSoftmaxParams( + context, input, output, params, &op_params)); if (NumDimensions(input) == 1) { - Softmax1DQuantized(input, output, params, data); + Softmax1DQuantized(input, output, params, op_params); return kTfLiteOk; } if (NumDimensions(input) == 2) { - Softmax2DQuantized(input, output, params, data); + Softmax2DQuantized(input, output, params, op_params); return kTfLiteOk; } if (NumDimensions(input) == 4) { - Softmax4DQuantized(input, output, params, data); + Softmax4DQuantized(input, output, params, op_params); return kTfLiteOk; } TF_LITE_KERNEL_LOG( From c5239755706a3dcbaeb3e3198b0cae8733bf50bb Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 20 Mar 2020 17:19:47 -0700 Subject: [PATCH 364/492] Add a few more test cases for the shape function of Reshape. Add a missing return statement in shape inference. PiperOrigin-RevId: 302130894 Change-Id: I021d501adca880a4cb49e0adff5fe83e5a685ac1 --- tensorflow/core/grappler/costs/graph_properties.cc | 1 + tensorflow/core/ops/array_ops.cc | 3 +-- tensorflow/core/ops/array_ops_test.cc | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 569c7cadeef..e3688cc0a6d 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -1777,6 +1777,7 @@ class SymbolicShapeRefiner { } if (!has_values_smaller_than_minus_1) { *tensors_as_shapes = ic->MakeShape(dims); + return true; } } else if (IsIntegerScalar(tensor)) { // Scalar constant. diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 9c4c59872f9..98803bfe086 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -194,8 +194,7 @@ Status SetOutputShapeForReshape(InferenceContext* c) { c->set_output(0, out); return Status::OK(); } - - if (c->RankKnown(out) && c->RankKnown(in)) { + if (c->RankKnown(in)) { // We don't know the number of output elements, but we can try to infer // the missing dimension. bool too_many_unknown = false; diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 5d9700a6f67..443e1124df8 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -862,12 +862,14 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) { // No valid shape provided. INFER_OK(op, "?;?", "?"); INFER_OK(op, "[?];?", "?"); + INFER_OK(op, "?;[?]", "?"); INFER_OK(op, "[?];[?]", "?"); INFER_OK(op, "[4];[?]", "?"); // All dimensions provided. Tensor new_shape = test::AsTensor({1, 2, 3}); op.input_tensors[1] = &new_shape; + INFER_OK(op, "?;[3]", "[1,2,3]"); INFER_OK(op, "[?];[3]", "[1,2,3]"); INFER_OK(op, "[6];[3]", "[1,2,3]"); // The number of elements should match for the reshape to succeed. @@ -878,6 +880,7 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) { // Unknown dimensions. // Flatten: new_shape = test::AsTensor({-1}); + INFER_OK(op, "?;[1]", "[?]"); INFER_OK(op, "[?];[1]", "[d0_0]"); INFER_OK(op, "[2,2];[1]", "[4]"); // The first dimension is inferred: @@ -890,6 +893,7 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) { // Multiple missing dimensions cannot be inferred. new_shape = test::AsTensor({-1, -1, 2}); INFER_OK(op, "[8];[3]", "[?,?,2]"); + INFER_OK(op, "?;[3]", "[?,?,2]"); // Symbolic shape propagation new_shape = test::AsTensor({-1, 2, 3}); From f93485dad7d9b21f7c2bc3aee3134d532131a939 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Fri, 20 Mar 2020 17:36:02 -0700 Subject: [PATCH 365/492] [tf.data service] Add initial master and worker implementations. 
PiperOrigin-RevId: 302133273 Change-Id: I89707e53e7a014380264b7a2594b39053450ce51 --- tensorflow/core/data/service/BUILD | 130 +++++++++ .../core/data/service/data_service_test.cc | 267 ++++++++++++++++++ .../core/data/service/grpc_master_impl.cc | 48 ++++ .../core/data/service/grpc_master_impl.h | 60 ++++ .../core/data/service/grpc_worker_impl.cc | 52 ++++ .../core/data/service/grpc_worker_impl.h | 61 ++++ .../data/service/local_credentials_factory.cc | 48 ++++ tensorflow/core/data/service/master_impl.cc | 208 ++++++++++++++ tensorflow/core/data/service/master_impl.h | 123 ++++++++ tensorflow/core/data/service/server_lib.cc | 88 ++++++ tensorflow/core/data/service/server_lib.h | 73 +++++ tensorflow/core/data/service/worker_impl.cc | 157 ++++++++++ tensorflow/core/data/service/worker_impl.h | 81 ++++++ 13 files changed, 1396 insertions(+) create mode 100644 tensorflow/core/data/service/data_service_test.cc create mode 100644 tensorflow/core/data/service/grpc_master_impl.cc create mode 100644 tensorflow/core/data/service/grpc_master_impl.h create mode 100644 tensorflow/core/data/service/grpc_worker_impl.cc create mode 100644 tensorflow/core/data/service/grpc_worker_impl.h create mode 100644 tensorflow/core/data/service/local_credentials_factory.cc create mode 100644 tensorflow/core/data/service/master_impl.cc create mode 100644 tensorflow/core/data/service/master_impl.h create mode 100644 tensorflow/core/data/service/server_lib.cc create mode 100644 tensorflow/core/data/service/server_lib.h create mode 100644 tensorflow/core/data/service/worker_impl.cc create mode 100644 tensorflow/core/data/service/worker_impl.h diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 6c8116a6de8..b0b6ce3f3e7 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -45,6 +45,63 @@ tf_proto_library( ], ) +cc_library( + name = "master_impl", + srcs = ["master_impl.cc"], + hdrs = [ + "master_impl.h", + ], + deps = [ + ":common_proto_cc", + ":credentials_factory", + ":grpc_util", + ":master_proto_cc", + ":worker_cc_grpc_proto", + ":worker_proto_cc", + "//tensorflow:grpc++", + "//tensorflow/c:c_api_internal", + "//tensorflow/c:tf_status_helper", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels/data:dataset_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "worker_impl", + srcs = ["worker_impl.cc"], + hdrs = [ + "worker_impl.h", + ], + deps = [ + ":common_proto_cc", + ":compression_utils", + ":credentials_factory", + ":grpc_util", + ":master_cc_grpc_proto", + ":master_proto_cc", + ":worker_proto_cc", + "//tensorflow:grpc++", + "//tensorflow/c:c_api_internal", + "//tensorflow/c:tf_status_helper", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/data:standalone", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + ], +) + cc_library( name = "grpc_util", srcs = ["grpc_util.cc"], @@ -125,6 +182,17 @@ tf_cc_test( ], ) +# Link this target to enable LOCAL credentials for the dataset service. 
+cc_library( + name = "local_credentials_factory", + srcs = ["local_credentials_factory.cc"], + deps = [ + ":credentials_factory", + "//tensorflow:grpc++", + ], + alwayslink = 1, +) + cc_library( name = "test_util", testonly = True, @@ -155,6 +223,68 @@ tf_cc_test( ], ) +cc_library( + name = "grpc_master_impl", + srcs = ["grpc_master_impl.cc"], + hdrs = ["grpc_master_impl.h"], + deps = [ + ":master_cc_grpc_proto", + ":master_impl", + "//tensorflow:grpc++", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + ], +) + +cc_library( + name = "grpc_worker_impl", + srcs = ["grpc_worker_impl.cc"], + hdrs = ["grpc_worker_impl.h"], + deps = [ + ":worker_cc_grpc_proto", + ":worker_impl", + "//tensorflow:grpc++", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + ], +) + +cc_library( + name = "server_lib", + srcs = ["server_lib.cc"], + hdrs = ["server_lib.h"], + deps = [ + ":credentials_factory", + ":grpc_master_impl", + ":grpc_worker_impl", + "//tensorflow:grpc++", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow", + ], +) + +tf_cc_test( + name = "data_service_test", + srcs = ["data_service_test.cc"], + deps = [ + ":compression_utils", + ":grpc_master_impl", + ":grpc_util", + ":grpc_worker_impl", + ":local_credentials_factory", + ":master_cc_grpc_proto", + ":master_proto_cc", + ":server_lib", + ":test_util", + ":worker_cc_grpc_proto", + ":worker_proto_cc", + "//tensorflow:grpc++", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/kernels/data:dataset_test_base", + "@com_google_absl//absl/strings", + ], +) + cc_grpc_library( name = "master_cc_grpc_proto", srcs = [":master_proto"], diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc new file mode 100644 index 00000000000..0eb3ca55c05 --- /dev/null +++ b/tensorflow/core/data/service/data_service_test.cc @@ -0,0 +1,267 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" +#include "absl/strings/str_split.h" +#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/service/grpc_util.h" +#include "tensorflow/core/data/service/master.grpc.pb.h" +#include "tensorflow/core/data/service/master.pb.h" +#include "tensorflow/core/data/service/server_lib.h" +#include "tensorflow/core/data/service/test_util.h" +#include "tensorflow/core/data/service/worker.grpc.pb.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { + +namespace { +const char kProtocol[] = "grpc+local"; + +// Parse the address from a string in the form "://
". +Status AddressFromTarget(const std::string& target, std::string* address) { + std::vector parts = absl::StrSplit(target, "://"); + if (parts.size() != 2) { + return errors::InvalidArgument("target ", target, " split into ", + parts.size(), " parts, not 2"); + } + *address = parts[1]; + return Status::OK(); +} + +class TestCluster { + public: + explicit TestCluster(int num_workers) : num_workers_(num_workers) {} + + Status Initialize() { + TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_)); + TF_RETURN_IF_ERROR(master_->Start()); + TF_RETURN_IF_ERROR(AddressFromTarget(master_->Target(), &master_address_)); + workers_.reserve(num_workers_); + worker_addresses_.reserve(num_workers_); + for (int i = 0; i < num_workers_; ++i) { + TF_RETURN_IF_ERROR(AddWorker()); + } + return Status::OK(); + } + + Status AddWorker() { + workers_.emplace_back(); + TF_RETURN_IF_ERROR(NewWorkerServer(/*port=*/0, kProtocol, master_address_, + &workers_.back())); + TF_RETURN_IF_ERROR(workers_.back()->Start()); + worker_addresses_.emplace_back(); + TF_RETURN_IF_ERROR(AddressFromTarget(workers_.back()->Target(), + &worker_addresses_.back())); + return Status::OK(); + } + + std::string MasterAddress() { return master_address_; } + + std::string WorkerAddress(int index) { return worker_addresses_[index]; } + + private: + int num_workers_; + std::unique_ptr master_; + std::string master_address_; + std::vector> workers_; + std::vector worker_addresses_; +}; + +Status RegisterDataset(MasterService::Stub* master_stub, + const GraphDef& dataset_graph, int64* dataset_id) { + grpc_impl::ClientContext ctx; + GetOrRegisterDatasetRequest req; + *req.mutable_dataset()->mutable_graph() = dataset_graph; + GetOrRegisterDatasetResponse resp; + grpc::Status s = master_stub->GetOrRegisterDataset(&ctx, req, &resp); + if (!s.ok()) { + return grpc_util::WrapError("Failed to register dataset", s); + } + *dataset_id = resp.dataset_id(); + return Status::OK(); +} + +Status BeginEpoch(MasterService::Stub* master_stub, int64 dataset_id, + int64* epoch_id) { + grpc_impl::ClientContext ctx; + BeginEpochRequest req; + req.set_dataset_id(dataset_id); + BeginEpochResponse resp; + grpc::Status s = master_stub->BeginEpoch(&ctx, req, &resp); + if (!s.ok()) { + return grpc_util::WrapError("Failed to begin epoch", s); + } + *epoch_id = resp.epoch_id(); + return Status::OK(); +} + +Status GetTasks(MasterService::Stub* master_stub, int64 epoch_id, + std::vector* tasks) { + grpc_impl::ClientContext ctx; + GetTasksRequest req; + req.set_epoch_id(epoch_id); + GetTasksResponse resp; + grpc::Status s = master_stub->GetTasks(&ctx, req, &resp); + if (!s.ok()) { + return grpc_util::WrapError("Failed to get tasks", s); + } + tasks->clear(); + for (auto& task : resp.task_info()) { + tasks->push_back(task); + } + return Status::OK(); +} + +Status GetElement(WorkerService::Stub* worker_stub, int64 task_id, + std::vector* element, bool* end_of_sequence) { + grpc_impl::ClientContext ctx; + GetElementRequest req; + req.set_task_id(task_id); + GetElementResponse resp; + grpc::Status s = worker_stub->GetElement(&ctx, req, &resp); + if (!s.ok()) { + return grpc_util::WrapError("Failed to get element", s); + } + *end_of_sequence = resp.end_of_sequence(); + if (!*end_of_sequence) { + const CompressedElement& compressed = resp.compressed_element(); + TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, element)); + } + return Status::OK(); +} + +Status CheckWorkerOutput(const std::string& worker_address, int64 task_id, + std::vector> expected_output) { 
+ auto worker_channel = grpc::CreateChannel( + worker_address, grpc::experimental::LocalCredentials(LOCAL_TCP)); + std::unique_ptr worker_stub = + WorkerService::NewStub(worker_channel); + for (std::vector& expected : expected_output) { + bool end_of_sequence; + std::vector element; + TF_RETURN_IF_ERROR( + GetElement(worker_stub.get(), task_id, &element, &end_of_sequence)); + if (end_of_sequence) { + return errors::Internal("Reached end of sequence too early."); + } + TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected, + /*compare_order=*/true)); + } + // Call GetElement a couple more times to verify tha end_of_sequence keeps + // returning true. + bool end_of_sequence; + std::vector element; + TF_RETURN_IF_ERROR( + GetElement(worker_stub.get(), task_id, &element, &end_of_sequence)); + if (!end_of_sequence) { + return errors::Internal("Expected end_of_sequence to be true"); + } + TF_RETURN_IF_ERROR( + GetElement(worker_stub.get(), task_id, &element, &end_of_sequence)); + if (!end_of_sequence) { + return errors::Internal("Expected end_of_sequence to be true"); + } + return Status::OK(); +} + +} // namespace + +TEST(DataService, IterateDatasetOneWorker) { + TestCluster cluster(1); + TF_ASSERT_OK(cluster.Initialize()); + test_util::GraphDefTestCase test_case; + TF_ASSERT_OK(test_util::map_test_case(&test_case)); + auto master_channel = grpc::CreateChannel( + cluster.MasterAddress(), grpc::experimental::LocalCredentials(LOCAL_TCP)); + std::unique_ptr master_stub = + MasterService::NewStub(master_channel); + + int64 dataset_id; + TF_ASSERT_OK( + RegisterDataset(master_stub.get(), test_case.graph_def, &dataset_id)); + int64 epoch_id; + TF_ASSERT_OK(BeginEpoch(master_stub.get(), dataset_id, &epoch_id)); + std::vector tasks; + TF_ASSERT_OK(GetTasks(master_stub.get(), epoch_id, &tasks)); + ASSERT_EQ(tasks.size(), 1); + ASSERT_EQ(tasks[0].worker_address(), cluster.WorkerAddress(0)); + + TF_ASSERT_OK(CheckWorkerOutput(tasks[0].worker_address(), tasks[0].id(), + test_case.output)); +} + +TEST(DataService, IterateDatasetTwoWorkers) { + TestCluster cluster(2); + TF_ASSERT_OK(cluster.Initialize()); + test_util::GraphDefTestCase test_case; + TF_ASSERT_OK(test_util::map_test_case(&test_case)); + auto master_channel = grpc::CreateChannel( + cluster.MasterAddress(), grpc::experimental::LocalCredentials(LOCAL_TCP)); + std::unique_ptr master_stub = + MasterService::NewStub(master_channel); + + int64 dataset_id; + TF_ASSERT_OK( + RegisterDataset(master_stub.get(), test_case.graph_def, &dataset_id)); + int64 epoch_id; + TF_ASSERT_OK(BeginEpoch(master_stub.get(), dataset_id, &epoch_id)); + std::vector tasks; + TF_ASSERT_OK(GetTasks(master_stub.get(), epoch_id, &tasks)); + ASSERT_EQ(tasks.size(), 2); + + // Each worker produces the full dataset. 
+ for (TaskInfo task : tasks) { + TF_ASSERT_OK( + CheckWorkerOutput(task.worker_address(), task.id(), test_case.output)); + } +} + +TEST(DataService, AddWorkerMidEpoch) { + TestCluster cluster(1); + TF_ASSERT_OK(cluster.Initialize()); + test_util::GraphDefTestCase test_case; + TF_ASSERT_OK(test_util::map_test_case(&test_case)); + auto master_channel = grpc::CreateChannel( + cluster.MasterAddress(), grpc::experimental::LocalCredentials(LOCAL_TCP)); + std::unique_ptr master_stub = + MasterService::NewStub(master_channel); + + int64 dataset_id; + TF_ASSERT_OK( + RegisterDataset(master_stub.get(), test_case.graph_def, &dataset_id)); + int64 epoch_id; + TF_ASSERT_OK(BeginEpoch(master_stub.get(), dataset_id, &epoch_id)); + std::vector tasks; + TF_ASSERT_OK(GetTasks(master_stub.get(), epoch_id, &tasks)); + ASSERT_EQ(tasks.size(), 1); + TF_ASSERT_OK(cluster.AddWorker()); + TF_ASSERT_OK(GetTasks(master_stub.get(), epoch_id, &tasks)); + ASSERT_EQ(tasks.size(), 2); + + // Each worker produces the full dataset. + for (TaskInfo task : tasks) { + TF_ASSERT_OK( + CheckWorkerOutput(task.worker_address(), task.id(), test_case.output)); + } +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/grpc_master_impl.cc b/tensorflow/core/data/service/grpc_master_impl.cc new file mode 100644 index 00000000000..d7f21bfc406 --- /dev/null +++ b/tensorflow/core/data/service/grpc_master_impl.cc @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/grpc_master_impl.h" + +#include "grpcpp/server_context.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace tensorflow { +namespace data { + +using ::grpc::ServerBuilder; +using ::grpc::ServerContext; +using ::grpc::Status; + +GrpcMasterImpl::GrpcMasterImpl(ServerBuilder* server_builder, + const std::string& protocol) + : impl_(protocol) { + server_builder->RegisterService(this); + VLOG(1) << "Registered data service master"; +} + +#define HANDLER(method) \ + Status GrpcMasterImpl::method(ServerContext* context, \ + const method##Request* request, \ + method##Response* response) { \ + return ToGrpcStatus(impl_.method(request, response)); \ + } +HANDLER(RegisterWorker); +HANDLER(GetOrRegisterDataset); +HANDLER(BeginEpoch); +HANDLER(GetTasks); +#undef HANDLER + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/grpc_master_impl.h b/tensorflow/core/data/service/grpc_master_impl.h new file mode 100644 index 00000000000..cd4ffe30d79 --- /dev/null +++ b/tensorflow/core/data/service/grpc_master_impl.h @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_ + +#include "grpcpp/server_builder.h" +#include "tensorflow/core/data/service/master.grpc.pb.h" +#include "tensorflow/core/data/service/master_impl.h" + +namespace tensorflow { +namespace data { + +// This class is a wrapper that handles communication for gRPC. +// +// Example usage: +// +// ::grpc::ServerBuilder builder; +// // configure builder +// GrpcMasterImpl data_service(&builder); +// builder.BuildAndStart() +// +class GrpcMasterImpl : public MasterService::Service { + public: + explicit GrpcMasterImpl(grpc::ServerBuilder* server_builder, + const std::string& protocol); + ~GrpcMasterImpl() override {} + + private: +#define HANDLER(method) \ + grpc::Status method(grpc::ServerContext* context, \ + const method##Request* request, \ + method##Response* response) override; + HANDLER(RegisterWorker); + HANDLER(GetOrRegisterDataset); + HANDLER(BeginEpoch); + HANDLER(GetTasks); +#undef HANDLER + + DataServiceMasterImpl impl_; + + TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterImpl); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_ diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc new file mode 100644 index 00000000000..a5d005d6c6e --- /dev/null +++ b/tensorflow/core/data/service/grpc_worker_impl.cc @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/data/service/grpc_worker_impl.h" + +#include "grpcpp/server_context.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace tensorflow { +namespace data { + +using ::grpc::ServerBuilder; +using ::grpc::ServerContext; +using ::grpc::Status; + +GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder, + const std::string& master_address, + const std::string& protocol) + : impl_(master_address, protocol) { + server_builder->RegisterService(this); + LOG(INFO) << "GrpcWorkerImpl: master address is " << master_address; + VLOG(1) << "Registered data service worker"; +} + +void GrpcWorkerImpl::Start(const std::string& worker_address) { + impl_.Start(worker_address); +} + +#define HANDLER(method) \ + Status GrpcWorkerImpl::method(ServerContext* context, \ + const method##Request* request, \ + method##Response* response) { \ + return ToGrpcStatus(impl_.method(request, response)); \ + } +HANDLER(ProcessTask); +HANDLER(GetElement); +#undef HANDLER + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h new file mode 100644 index 00000000000..b7ece2a7738 --- /dev/null +++ b/tensorflow/core/data/service/grpc_worker_impl.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_WORKER_IMPL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_WORKER_IMPL_H_ + +#include "grpcpp/server_builder.h" +#include "tensorflow/core/data/service/worker.grpc.pb.h" +#include "tensorflow/core/data/service/worker_impl.h" + +namespace tensorflow { +namespace data { + +// This class is a wrapper that handles communication for gRPC. 
+// +// Example usage: +// +// ::grpc::ServerBuilder builder; +// // configure builder +// GrpcWorkerImpl data_service(&builder); +// builder.BuildAndStart() +// +class GrpcWorkerImpl : public WorkerService::Service { + public: + explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder, + const std::string& master_address, + const std::string& protocol); + ~GrpcWorkerImpl() override {} + + void Start(const std::string& worker_address); + +#define HANDLER(method) \ + grpc::Status method(grpc::ServerContext* context, \ + const method##Request* request, \ + method##Response* response) override; + HANDLER(ProcessTask); + HANDLER(GetElement); +#undef HANDLER + + private: + DataServiceWorkerImpl impl_; + + TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerImpl); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_WORKER_IMPL_H_ diff --git a/tensorflow/core/data/service/local_credentials_factory.cc b/tensorflow/core/data/service/local_credentials_factory.cc new file mode 100644 index 00000000000..136bf49df9b --- /dev/null +++ b/tensorflow/core/data/service/local_credentials_factory.cc @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/credentials_factory.h" + +namespace tensorflow { +namespace data { + +class LocalCredentialsFactory : public CredentialsFactory { + public: + std::string Protocol() override { return "grpc+local"; } + + Status CreateServerCredentials( + std::shared_ptr<::grpc::ServerCredentials>* out) override { + *out = grpc::experimental::LocalServerCredentials(LOCAL_TCP); + return Status::OK(); + } + + Status CreateClientCredentials( + std::shared_ptr<::grpc::ChannelCredentials>* out) override { + *out = grpc::experimental::LocalCredentials(LOCAL_TCP); + return Status::OK(); + } +}; + +class LocalCredentialsRegistrar { + public: + LocalCredentialsRegistrar() { + auto factory = new LocalCredentialsFactory(); + CredentialsFactory::Register(factory); + } +}; +static LocalCredentialsRegistrar registrar; + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/master_impl.cc b/tensorflow/core/data/service/master_impl.cc new file mode 100644 index 00000000000..033b28c03a8 --- /dev/null +++ b/tensorflow/core/data/service/master_impl.cc @@ -0,0 +1,208 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/master_impl.h" + +#include "grpcpp/create_channel.h" +#include "grpcpp/impl/codegen/server_context.h" +#include "grpcpp/security/credentials.h" +#include "absl/memory/memory.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/credentials_factory.h" +#include "tensorflow/core/data/service/grpc_util.h" +#include "tensorflow/core/data/service/master.pb.h" +#include "tensorflow/core/data/service/worker.grpc.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace data { + +namespace { +Status CreateWorkerStub(const std::string& address, + const std::string& protocol_, + std::unique_ptr* stub) { + ::grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(-1); + std::shared_ptr<::grpc::ChannelCredentials> credentials; + TF_RETURN_IF_ERROR( + CredentialsFactory::CreateClientCredentials(protocol_, &credentials)); + auto channel = ::grpc::CreateCustomChannel(address, credentials, args); + *stub = WorkerService::NewStub(channel); + return Status::OK(); +} +} // namespace + +DataServiceMasterImpl::DataServiceMasterImpl(const std::string protocol) + : protocol_(protocol) {} + +Status DataServiceMasterImpl::RegisterWorker( + const RegisterWorkerRequest* request, RegisterWorkerResponse* response) { + VLOG(3) << "Received register worker request"; + mutex_lock l(mu_); + int64 worker_id = next_worker_id_++; + workers_.emplace_back(); + workers_.back().address = request->worker_address(); + workers_.back().id = worker_id; + response->set_worker_id(worker_id); + + // Allocate tasks to the worker. + for (auto& entry : epochs_) { + Epoch& epoch = entry.second; + int64 task_id = next_task_id_++; + DCHECK(!tasks_.contains(task_id)); + Task& task = tasks_[task_id]; + task.id = task_id; + task.dataset_id = epoch.dataset_id; + task.worker_address = request->worker_address(); + epoch.task_ids.push_back(task_id); + + TaskDef* task_def = response->add_tasks(); + *task_def->mutable_dataset() = + datasets_by_id_[task.dataset_id]->dataset_def; + task_def->set_dataset_id(task.dataset_id); + task_def->set_epoch_id(epoch.id); + task_def->set_task_id(task.id); + } + + VLOG(1) << "Registered worker " << workers_.back().DebugString(); + return Status::OK(); +} + +Status DataServiceMasterImpl::GetOrRegisterDataset( + const GetOrRegisterDatasetRequest* request, + GetOrRegisterDatasetResponse* response) { + uint64 fingerprint; + TF_RETURN_IF_ERROR(HashGraph(request->dataset().graph(), &fingerprint)); + mutex_lock l(mu_); + VLOG(3) << "Registering dataset graph: " + << request->dataset().graph().DebugString(); + if (datasets_by_fingerprint_.contains(fingerprint)) { + int64 id = datasets_by_fingerprint_[fingerprint]->id; + VLOG(3) << "Received duplicate RegisterDataset request with fingerprint " + << fingerprint << ". 
Returning id " << id; + response->set_dataset_id(id); + return Status::OK(); + } + int64 id = RegisterDataset(fingerprint, request->dataset()); + + response->set_dataset_id(id); + VLOG(3) << "Registered new dataset with id " << id; + return Status::OK(); +} + +int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint, + const DatasetDef& dataset) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto new_dataset = std::make_shared(); + int64 dataset_id = next_dataset_id_++; + new_dataset->id = dataset_id; + new_dataset->fingerprint = fingerprint; + new_dataset->dataset_def = dataset; + + DCHECK(!datasets_by_id_.contains(dataset_id)); + datasets_by_id_[dataset_id] = new_dataset; + DCHECK(!datasets_by_fingerprint_.contains(fingerprint)); + datasets_by_fingerprint_[dataset_id] = new_dataset; + return dataset_id; +} + +Status DataServiceMasterImpl::BeginEpoch(const BeginEpochRequest* request, + BeginEpochResponse* response) { + VLOG(3) << "Received begin epoch request for dataset id " + << request->dataset_id(); + mutex_lock l(mu_); + if (!datasets_by_id_.contains(request->dataset_id())) { + return errors::NotFound("BeginEpoch failed. Dataset id: <", + request->dataset_id(), "> not found."); + } + + int64 epoch_id = next_epoch_id_++; + DCHECK(!epochs_.contains(epoch_id)); + Epoch& epoch = epochs_[epoch_id]; + epoch.id = epoch_id; + epoch.dataset_id = request->dataset_id(); + response->set_epoch_id(epoch_id); + + for (auto& worker : workers_) { + int64 task_id = next_task_id_++; + DCHECK(!tasks_.contains(task_id)); + Task& task = tasks_[task_id]; + task.id = task_id; + task.dataset_id = request->dataset_id(); + task.worker_address = worker.address; + epoch.task_ids.push_back(task_id); + + std::unique_ptr stub; + TF_RETURN_IF_ERROR(CreateWorkerStub(worker.address, protocol_, &stub)); + // TODO(aaudibert): perform these calls asynchronously. + TF_RETURN_IF_ERROR(AllocateTaskToWorker(task, &worker)); + } + + VLOG(3) << "Beginning epoch " << epoch_id << " for dataset " + << request->dataset_id(); + return Status::OK(); +} + +Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task, + WorkerInfo* worker) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!worker->stub) { + TF_RETURN_IF_ERROR( + CreateWorkerStub(worker->address, protocol_, &worker->stub)); + } + grpc::ClientContext client_ctx; + ProcessTaskRequest req; + req.mutable_task()->set_dataset_id(task.dataset_id); + DCHECK(datasets_by_id_.contains(task.dataset_id)); + *req.mutable_task()->mutable_dataset() = + datasets_by_id_[task.dataset_id]->dataset_def; + req.mutable_task()->set_task_id(task.id); + ProcessTaskResponse resp; + grpc::Status s = worker->stub->ProcessTask(&client_ctx, req, &resp); + if (!s.ok()) { + return grpc_util::WrapError( + absl::StrCat("Failed to submit task to worker ", worker->address), s); + } + return Status::OK(); +} + +Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request, + GetTasksResponse* response) { + mutex_lock l(mu_); + VLOG(3) << "Looking up tasks for epoch id " << request->epoch_id(); + auto it = epochs_.find(request->epoch_id()); + if (it == epochs_.end()) { + return errors::NotFound("GetTasks failed. 
Epoch id <", request->epoch_id(), + "> not found."); + } + Epoch& epoch = it->second; + for (const auto& task_id : epoch.task_ids) { + auto task_iter = tasks_.find(task_id); + DCHECK(task_iter != tasks_.end()); + Task& task = task_iter->second; + TaskInfo* task_info = response->mutable_task_info()->Add(); + task_info->set_worker_address(task.worker_address); + task_info->set_id(task.id); + } + VLOG(3) << "Found " << response->task_info_size() << " tasks for epoch id " + << request->epoch_id(); + return Status::OK(); +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/master_impl.h b/tensorflow/core/data/service/master_impl.h new file mode 100644 index 00000000000..c2bbd36d2a0 --- /dev/null +++ b/tensorflow/core/data/service/master_impl.h @@ -0,0 +1,123 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/master.pb.h" +#include "tensorflow/core/data/service/worker.grpc.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace data { + +// A service which coordinates a pool of workers to serve dataset elements over +// RPC. +// +// Glossary: +// * Dataset: A definition of how to generate a potentially large collection of +// elements. +// * Epoch: A single pass over a dataset. There may be multiple epochs +// for the same dataset, and they can be iterated over independently +// * Task: An epoch is broken into multiple tasks, which each represent +// iterating over all of or part of the dataset. Workers process tasks. We +// don't currently implement dataset splitting, so every task represents a +// full iteration over the dataset. In the future, we will partition the data +// across all tasks for the same epoch. +class DataServiceMasterImpl { + public: + explicit DataServiceMasterImpl(const std::string protocol); + + // See master.proto for API documentation. + + /// Worker-facing API. + Status RegisterWorker(const RegisterWorkerRequest* request, + RegisterWorkerResponse* response); + + /// Client-facing API. 
+ Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request, + GetOrRegisterDatasetResponse* response); + Status BeginEpoch(const BeginEpochRequest* request, + BeginEpochResponse* response); + Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response); + + private: + typedef struct WorkerInfo { + std::string address; + int64 id; + std::unique_ptr stub; + + std::string DebugString() { + return absl::StrCat("id: ", id, "address: ", address); + } + } WorkerInfo; + + typedef struct Dataset { + int64 id; + int64 fingerprint; + DatasetDef dataset_def; + } Dataset; + + typedef struct Epoch { + int64 id; + int64 dataset_id; + std::vector task_ids; + } Epoch; + + typedef struct Task { + int64 id; + int64 dataset_id; + std::string worker_address; + } Task; + + // Registers a dataset with the given fingerprint, returning a new dataset id. + int64 RegisterDataset(uint64 fingerprint, const DatasetDef& dataset); + // Instructs a worker to begin processing a task. + Status AllocateTaskToWorker(const Task& task_id, WorkerInfo* worker); + + // Protocol to use for communicating with workers. + const std::string protocol_; + + mutex mu_; + + int64 next_worker_id_ TF_GUARDED_BY(mu_) = 0; + int64 next_dataset_id_ TF_GUARDED_BY(mu_) = 0; + int64 next_epoch_id_ TF_GUARDED_BY(mu_) = 0; + int64 next_task_id_ TF_GUARDED_BY(mu_) = 0; + + // Registered workers. + std::vector workers_ TF_GUARDED_BY(mu_); + // Registered datasets, keyed by dataset ids. + absl::flat_hash_map> datasets_by_id_ + TF_GUARDED_BY(mu_); + // Registered datasets, keyed by dataset fingerprints. + absl::flat_hash_map> datasets_by_fingerprint_ + TF_GUARDED_BY(mu_); + // Information about epochs, keyed by epoch ids. + absl::flat_hash_map epochs_ TF_GUARDED_BY(mu_); + // Information about tasks, keyed by task ids. + absl::flat_hash_map tasks_ TF_GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_ diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc new file mode 100644 index 00000000000..d4ec8dd0a9d --- /dev/null +++ b/tensorflow/core/data/service/server_lib.cc @@ -0,0 +1,88 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/data/service/server_lib.h" + +#include "tensorflow/core/data/service/credentials_factory.h" +#include "tensorflow/core/data/service/grpc_master_impl.h" +#include "tensorflow/core/data/service/grpc_worker_impl.h" + +namespace tensorflow { +namespace data { + +GrpcDataServer::GrpcDataServer(int port, const std::string& protocol, + bool is_master, + const std::string& master_address) + : requested_port_(port), + protocol_(protocol), + is_master_(is_master), + master_address_(master_address) {} + +Status GrpcDataServer::Start() { + ::grpc::ServerBuilder builder; + std::shared_ptr<::grpc::ServerCredentials> credentials; + TF_RETURN_IF_ERROR( + CredentialsFactory::CreateServerCredentials(protocol_, &credentials)); + builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_), + credentials, &bound_port_); + builder.SetMaxReceiveMessageSize(-1); + + if (is_master_) { + service_ = absl::make_unique(&builder, protocol_); + } else { + service_ = + absl::make_unique(&builder, master_address_, protocol_); + } + + server_ = builder.BuildAndStart(); + if (!server_) { + return errors::Internal("Could not start gRPC server"); + } + + if (!is_master_) { + static_cast(service_.get()) + ->Start(strings::StrCat("localhost:", bound_port_)); + } + + LOG(INFO) << "Started data service " << (is_master_ ? "master" : "worker") + << " running at " << Target(); + return Status::OK(); +} + +void GrpcDataServer::Stop() { server_->Shutdown(); } + +void GrpcDataServer::Join() { server_->Wait(); } + +std::string GrpcDataServer::Target() { + return strings::StrCat(protocol_, "://localhost:", bound_port_); +} + +Status NewMasterServer(int port, const std::string& protocol, + std::unique_ptr* out_server) { + *out_server = absl::make_unique( + port, protocol, /*is_master=*/true, /*master_address=*/""); + return Status::OK(); +} + +Status NewWorkerServer(int port, const std::string& protocol, + const std::string& master_address, + std::unique_ptr* out_server) { + *out_server = absl::make_unique( + port, protocol, /*is_master=*/false, master_address); + return Status::OK(); +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h new file mode 100644 index 00000000000..753dd5ddfbf --- /dev/null +++ b/tensorflow/core/data/service/server_lib.h @@ -0,0 +1,73 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ + +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/server.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace data { + +// A grpc server for the dataset service. 
+class GrpcDataServer { + public: + // Constructs a dataset server with the specified port. If the port is 0, the + // server will find an available port in `Start()`. The chosen port can be + // found in the output of `Target()`. + // + // master_address is only needed for worker data servers. + explicit GrpcDataServer(int requested_port, const std::string& protocol, + bool is_master, const std::string& master_address); + + // Starts the server running asynchronously. + Status Start(); + + // Stops the server. This will block until all outstanding requests complete. + void Stop(); + + // Blocks until the server stops. + void Join(); + + // Returns the target string for the server. Only valid after calling Start(). + std::string Target(); + + private: + const int requested_port_; + const std::string protocol_; + const bool is_master_; + const std::string master_address_; + + int bound_port_; + + std::unique_ptr service_; + std::unique_ptr server_; +}; + +// Creates a master dataset server and stores it in `*out_server`. +Status NewMasterServer(int port, const std::string& protocol, + std::unique_ptr* out_server); + +// Creates a worker dataset server and stores it in `*out_server`. +Status NewWorkerServer(int port, const std::string& protocol, + const std::string& master_address, + std::unique_ptr* out_server); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc new file mode 100644 index 00000000000..dde51ab77a9 --- /dev/null +++ b/tensorflow/core/data/service/worker_impl.cc @@ -0,0 +1,157 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/data/service/worker_impl.h" + +#include "grpcpp/create_channel.h" +#include "absl/memory/memory.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/service/credentials_factory.h" +#include "tensorflow/core/data/service/grpc_util.h" +#include "tensorflow/core/data/service/master.grpc.pb.h" +#include "tensorflow/core/data/service/master.pb.h" +#include "tensorflow/core/data/standalone.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/zlib_outputbuffer.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/platform/snappy.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace data { + +const constexpr uint64 kHeartbeatIntervalMicros = 5ull * 1000 * 1000; + +namespace { +auto* tf_data_service_created = + monitoring::Gauge::New("/tensorflow/data/service/created", + "Whether a tf.data service server " + "has been created."); +} // namespace + +DataServiceWorkerImpl::DataServiceWorkerImpl(const std::string& master_address, + const std::string& protocol) + : master_address_(master_address), protocol_(protocol) { + tf_data_service_created->GetCell()->Set(true); +} + +void DataServiceWorkerImpl::Start(const std::string& worker_address) { + VLOG(3) << "Starting tf.data service worker at address " << worker_address; + mutex_lock l(mu_); + worker_address_ = worker_address; + + Status s = Register(); + while (!s.ok()) { + LOG(WARNING) << "Failed to register with master at " << master_address_ + << ": " << s; + Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros); + s = Register(); + } +} + +Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + VLOG(3) << "Registering with master at " << master_address_; + if (!master_stub_) { + ::grpc::ChannelArguments args; + std::shared_ptr<::grpc::ChannelCredentials> credentials; + TF_RETURN_IF_ERROR( + CredentialsFactory::CreateClientCredentials(protocol_, &credentials)); + auto channel = + ::grpc::CreateCustomChannel(master_address_, credentials, args); + master_stub_ = MasterService::NewStub(channel); + } + RegisterWorkerRequest req; + req.set_worker_address(worker_address_); + RegisterWorkerResponse resp; + + grpc::ClientContext ctx; + grpc::Status s = master_stub_->RegisterWorker(&ctx, req, &resp); + if (!s.ok()) { + return grpc_util::WrapError("Failed to register worker", s); + } + for (const TaskDef& task : resp.tasks()) { + TF_RETURN_IF_ERROR(ProcessTaskInternal(task)); + } + VLOG(3) << "Registered worker with id " << resp.worker_id(); + return Status::OK(); +} + +Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request, + ProcessTaskResponse* response) { + mutex_lock l(mu_); + const TaskDef& task = request->task(); + VLOG(3) << "Received request to process task " << task.task_id(); + return ProcessTaskInternal(task); +} + +Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + standalone::Dataset::Params params; + std::unique_ptr dataset; + TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph( + params, task_def.dataset().graph(), &dataset)); + + std::unique_ptr iterator; + TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator)); + + if (tasks_.contains(task_def.task_id())) { + 
+    return errors::AlreadyExists("A task with id ", task_def.task_id(),
+                                 " already exists.");
+  }
+  Task& task = tasks_[task_def.task_id()];
+  task.id = task_def.task_id();
+  task.dataset = std::move(dataset);
+  task.iterator = std::move(iterator);
+  return Status::OK();
+}
+
+Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
+                                         GetElementResponse* response) {
+  VLOG(3) << "Received GetElement request for task " << request->task_id();
+  bool end_of_sequence = false;
+  std::vector<Tensor> outputs;
+  {
+    mutex_lock l(mu_);
+    auto it = tasks_.find(request->task_id());
+    if (it == tasks_.end()) {
+      return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
+                              "Task id ", request->task_id(), " not found");
+    }
+    std::unique_ptr<standalone::Iterator>& iter = it->second.iterator;
+    if (iter == nullptr) {
+      response->set_end_of_sequence(true);
+      return Status::OK();
+    }
+    TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence));
+    if (end_of_sequence) {
+      // Release iterator memory and leave a null entry as a tombstone.
+      iter.reset();
+    }
+  }
+
+  if (!end_of_sequence) {
+    TF_RETURN_IF_ERROR(service_util::Compress(
+        outputs, response->mutable_compressed_element()));
+  }
+  response->set_end_of_sequence(end_of_sequence);
+
+  return Status::OK();
+}
+
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
new file mode 100644
index 00000000000..9595702f5d7
--- /dev/null
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -0,0 +1,81 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/data/service/common.pb.h"
+#include "tensorflow/core/data/service/master.grpc.pb.h"
+#include "tensorflow/core/data/service/worker.pb.h"
+#include "tensorflow/core/data/standalone.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace data {
+
+// A TensorFlow DataService serves dataset elements over RPC.
+class DataServiceWorkerImpl {
+ public:
+  explicit DataServiceWorkerImpl(const std::string& master_address,
+                                 const std::string& protocol);
+  virtual ~DataServiceWorkerImpl() {}
+
+  // Starts the worker. The worker needs to know its own address so that it can
+  // register with the master.
+  void Start(const std::string& worker_address);
+
+  // See worker.proto for API documentation.
+
+  /// Master-facing API.
+  Status ProcessTask(const ProcessTaskRequest* request,
+                     ProcessTaskResponse* response);
+
+  /// Client-facing API.
+  Status GetElement(const GetElementRequest* request,
+                    GetElementResponse* response);
+
+ private:
+  // Registers the worker with the master.
+  Status Register();
+  // Creates an iterator to process a task.
+  Status ProcessTaskInternal(const TaskDef& task);
+
+  typedef struct Task {
+    int64 id;
+    // TODO(aaudibert): Have standalone::Iterator own a reference to
+    // standalone::Dataset so that we don't need to store the dataset here.
+    std::unique_ptr<standalone::Dataset> dataset;
+    std::unique_ptr<standalone::Iterator> iterator;
+  } Task;
+
+  const std::string master_address_;
+  // Protocol for communicating with the master.
+  const std::string protocol_;
+  // The worker's own address.
+  std::string worker_address_;
+
+  mutex mu_;
+  std::unique_ptr<MasterService::Stub> master_stub_ TF_GUARDED_BY(mu_);
+  // Information about tasks, keyed by task ids.
+  absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl);
+};
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_

From 8d82addeab1af2c58a45a9aa85061fc192283b3c Mon Sep 17 00:00:00 2001
From: mdfaijul
Date: Thu, 19 Mar 2020 17:01:35 -0700
Subject: [PATCH 366/492] Upgraded to DNNL-1.2.2 release.

---
 tensorflow/workspace.bzl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 1066479823a..c9cd2ddec66 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -175,11 +175,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "mkl_dnn_v1",
         build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"),
-        sha256 = "30979a09753e8e35d942446c3778c9f0eba543acf2fb0282af8b9c89355d0ddf",
-        strip_prefix = "mkl-dnn-1.2",
+        sha256 = "a71ec1f27c30b8a176605e8a78444f1f12301a3c313b70ff93290926c140509c",
+        strip_prefix = "mkl-dnn-1.2.2",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.2.tar.gz",
-            "https://github.com/intel/mkl-dnn/archive/v1.2.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.2.2.tar.gz",
+            "https://github.com/intel/mkl-dnn/archive/v1.2.2.tar.gz",
         ],
     )

From 96fe849d46ef0f061c58c71082ffc07b9ecbb4f2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 20 Mar 2020 19:01:09 -0700
Subject: [PATCH 367/492] Internal TFRT changes.

PiperOrigin-RevId: 302143492
Change-Id: I33aa81b260d8adfd2a83400544466019c077494b
---
 tensorflow/compiler/mlir/tensorflow/BUILD | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 3bed4e753e0..aea06a349f5 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -17,6 +17,13 @@ package_group(
     ],
 )
 
+exports_files([
+    "ir/tf_generated_ops.td",
+    "ir/tf_op_base.td",
+    "ir/tf_op_interfaces.td",
+    "ir/tf_ops.td",
+])
+
 filegroup(
     name = "tensorflow_ops_td_files",
     srcs = [

From c228728a2f6b294dc0c28f13e4a91d7d5947d4fe Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 20 Mar 2020 19:46:22 -0700
Subject: [PATCH 368/492] Go: Update generated wrapper functions for TensorFlow ops.
PiperOrigin-RevId: 302147539 Change-Id: Ie8604778d4d100f7ec45970b6f7205bfd03c36d6 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 68bb1dc49f5..75d86f71b78 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From ad188d0ade414e19e67877ddc545e37f92cce31b Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Fri, 20 Mar 2020 20:08:30 -0700 Subject: [PATCH 369/492] Minor refactoring of some snapshot_util functions. PiperOrigin-RevId: 302149438 Change-Id: I6b5a7837296f8793bbd20700634d2c32eb007a41 --- .../data/experimental/snapshot_dataset_op.cc | 16 +++++++----- .../data/experimental/snapshot_util.cc | 26 ++++++++++--------- .../kernels/data/experimental/snapshot_util.h | 6 ++--- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index e9873fd226e..a0349033519 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -378,9 +378,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); if (iterator_ == nullptr) { experimental::SnapshotMetadataRecord metadata; - Status s = snapshot_util::ReadMetadataFile(hash_dir_, &metadata); + bool file_exists; + TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile( + hash_dir_, &metadata, &file_exists)); TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState( - dataset()->mode_, s, &metadata, + dataset()->mode_, file_exists, &metadata, dataset()->pending_snapshot_expiry_seconds_, &state_)); VLOG(2) << "Snapshot state: " << state_; TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata)); @@ -417,8 +419,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { state_ = snapshot_util::Mode(temp); } experimental::SnapshotMetadataRecord metadata; - TF_RETURN_IF_ERROR( - snapshot_util::ReadMetadataFile(hash_dir_, &metadata)); + bool file_exists; + TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir_, &metadata, + &file_exists)); TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata)); VLOG(2) << "Restoring Snapshot iterator: " << state_; return RestoreInput(ctx, reader, iterator_); @@ -1336,8 +1339,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); if (!written_final_metadata_file_) { experimental::SnapshotMetadataRecord metadata; - TF_RETURN_IF_ERROR( - snapshot_util::ReadMetadataFile(hash_dir_, &metadata)); + bool file_exists; + TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile( + 
hash_dir_, &metadata, &file_exists)); if (metadata.run_id() == run_id_) { metadata.set_finalized(true); diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc index 9c2b30736e7..72d2c5cddd9 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc @@ -406,10 +406,17 @@ Status WriteMetadataFile(const string& hash_dir, } Status ReadMetadataFile(const string& hash_dir, - experimental::SnapshotMetadataRecord* metadata) { + experimental::SnapshotMetadataRecord* metadata, + bool* file_exists) { string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename); - TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename)); - return ReadBinaryProto(Env::Default(), metadata_filename, metadata); + Status s = Env::Default()->FileExists(metadata_filename); + *file_exists = s.ok(); + + if (*file_exists) { + return ReadBinaryProto(Env::Default(), metadata_filename, metadata); + } else { + return Status::OK(); + } } Status DumpDatasetGraph(const std::string& path, uint64 hash, @@ -424,15 +431,14 @@ Status DumpDatasetGraph(const std::string& path, uint64 hash, return WriteTextProto(Env::Default(), graph_file, *graph); } -Status DetermineOpState(const std::string& mode_string, - const Status& file_status, +Status DetermineOpState(const std::string& mode_string, bool file_exists, const experimental::SnapshotMetadataRecord* metadata, const uint64 pending_snapshot_expiry_seconds, Mode* mode) { if (mode_string == kModeRead) { // In read mode, we should expect a metadata file is written. - if (errors::IsNotFound(file_status)) { - return file_status; + if (!file_exists) { + return errors::NotFound("Metadata file does not exist."); } LOG(INFO) << "Overriding mode to reader."; *mode = READER; @@ -451,15 +457,11 @@ Status DetermineOpState(const std::string& mode_string, return Status::OK(); } - if (errors::IsNotFound(file_status)) { + if (!file_exists) { *mode = WRITER; return Status::OK(); } - if (!file_status.ok()) { - return file_status; - } - if (metadata->finalized()) { // File found, snapshot has been finalized. 
*mode = READER; diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h index c141cb0bbb0..e962bb56380 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_util.h +++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h @@ -134,13 +134,13 @@ Status WriteMetadataFile(const string& hash_dir, const experimental::SnapshotMetadataRecord* metadata); Status ReadMetadataFile(const string& hash_dir, - experimental::SnapshotMetadataRecord* metadata); + experimental::SnapshotMetadataRecord* metadata, + bool* file_exists); Status DumpDatasetGraph(const std::string& path, uint64 hash, const GraphDef* graph); -Status DetermineOpState(const std::string& mode_string, - const Status& file_status, +Status DetermineOpState(const std::string& mode_string, bool file_exists, const experimental::SnapshotMetadataRecord* metadata, const uint64 pending_snapshot_expiry_seconds, Mode* mode); From d17ad62d8a774ea14159b2d5a7c8e22b8f570146 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 20 Mar 2020 20:26:07 -0700 Subject: [PATCH 370/492] Add _HAS_ALL_REDUCE_SUM_GRAD to SGD PiperOrigin-RevId: 302150669 Change-Id: I9c48c608ab6930da77a2800e00418c2ff559f111 --- tensorflow/python/keras/optimizer_v2/gradient_descent.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py index 2f0bccb8355..539443aef60 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py @@ -74,6 +74,8 @@ class SGD(optimizer_v2.OptimizerV2): http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). """ + _HAS_ALL_REDUCE_SUM_GRAD = True + def __init__(self, learning_rate=0.01, momentum=0.0, From c7768ea238928a03b3236a7cba4fce31c990489c Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Fri, 20 Mar 2020 20:34:38 -0700 Subject: [PATCH 371/492] Register the TensorFlow Eager constant folding hook in 'compile_mlir_util'. 'tf_dialect_passes' is added as a dependency to 'compile_mlir_util' to register the TF Eager constant folding hook. Certain ops, like 'tf.BroadcastGradientArgs', fail to constant fold prior via the canonicalization, preventing legalization from TF to HLO. 
PiperOrigin-RevId: 302151385
Change-Id: I5938743c00ab4d9aba4706b9c7b7a620e288eb01
---
 tensorflow/compiler/mlir/tensorflow/BUILD       |  1 +
 .../utils/compile_mlir_util_test.cc             | 37 +++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index aea06a349f5..bb8e3c19e9c 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1058,6 +1058,7 @@ cc_library(
         ":error_util",
         ":tensorflow_dialect_registration",
         ":tensorflow_passes",
+        ":tf_dialect_passes",
        ":translate_utils",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
index b258dd68ae1..0caf1752cfb 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -248,5 +248,42 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
               ::testing::HasSubstr(expected_signature));
 }
 
+constexpr llvm::StringRef kBroadcastGradientArgsModule = R"(
+module attributes {tf.versions = {producer = 179 : i32}} {
+  func @main() -> (tensor<0xi32>, tensor<0xi32>) {
+    %0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+    %r0, %r1 = "tf.BroadcastGradientArgs"(%0, %0) {T = i32} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<0xi32>, tensor<0xi32>)
+    return %r0, %r1 : tensor<0xi32>, tensor<0xi32>
+  }
+}
+)";
+
+TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
+  std::vector<TensorShape> arg_shapes(2, TensorShape());
+  XlaCompiler::CompilationResult compilation_result;
+
+  Status s = CompileSerializedMlirToXlaHlo(
+      kBroadcastGradientArgsModule, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
+  ASSERT_TRUE(s.ok());
+
+  const xla::HloModuleConfig module_config(
+      compilation_result.computation->GetProgramShape().ValueOrDie());
+  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
+      compilation_result.computation->proto(), module_config);
+  ASSERT_TRUE(status_or_hlo_module.ok());
+  string expected_hlo_module_string = R"(HloModule main.4
+
+ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
+  %arg_tuple.1 = () parameter(0)
+  %constant.2 = s32[0]{0} constant({})
+  ROOT %tuple.3 = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} %constant.2, s32[0]{0} %constant.2)
+}
+
+)";
+  EXPECT_EQ(expected_hlo_module_string,
+            status_or_hlo_module.ValueOrDie()->ToString());
+}
+
 } // namespace
 } // namespace tensorflow

From f138d653010da6cb5620853cc052c5febd818819 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Sat, 21 Mar 2020 02:03:21 -0700
Subject: [PATCH 372/492] compat: Update forward compatibility horizon to 2020-03-21

PiperOrigin-RevId: 302175867
Change-Id: Ide8600036d09a1eace724145dc16bd9bc62a9339
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 3195e9ce5b9..288c0670968 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL.
It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 20) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 21) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From cd3eec5795530f9c76ab06f6c32205a3936feabd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Mar 2020 03:46:11 -0700 Subject: [PATCH 373/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302181996 Change-Id: I8b06f8d5405d7669b2a244b2aea24e4e60d31e99 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 75d86f71b78..68bb1dc49f5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. 
-// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From a99f66ee34ae492122d16799bc493acf7a71dc1f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 21 Mar 2020 10:48:38 -0500 Subject: [PATCH 374/492] Expose eager c_api_experimental to other language binding. --- tensorflow/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 55406a5686a..114787116df 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -702,6 +702,7 @@ tf_cc_shared_object( "//tensorflow/c:exported_symbols.lds", "//tensorflow/c:version_script.lds", "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", "//tensorflow/core:tensorflow", "//tensorflow/core/distributed_runtime/rpc:grpc_session", ], From 33ebd0d4f2dd5fd9a71d06c624d71c47b50541fa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Mar 2020 09:46:17 -0700 Subject: [PATCH 375/492] Go: Update generated wrapper functions for TensorFlow ops. 
PiperOrigin-RevId: 302205333 Change-Id: I4946406d196a19678c66cb8ab951e7f27d66042f --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 68bb1dc49f5..75d86f71b78 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 6c6f57ca2ae4c9c423885025dc7aa23a5f2b579d Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Sat, 21 Mar 2020 09:46:32 -0700 Subject: [PATCH 376/492] Roll forward (with a fix) of previously rolled back PR #37400. PiperOrigin-RevId: 302205361 Change-Id: I4d5ce5e9f3179da7de2c0f26c716dc2df9a4d5fd --- .../bucket_by_sequence_length_test.py | 7 +- .../data/kernel_tests/from_generator_test.py | 63 +++++- .../python/data/kernel_tests/iterator_test.py | 4 +- tensorflow/python/data/ops/dataset_ops.py | 183 +++++++++++------- tensorflow/python/data/util/structure.py | 16 +- tensorflow/python/ops/ragged/ragged_tensor.py | 5 + tensorflow/python/ops/script_ops.py | 49 ++++- .../v1/tensorflow.-ragged-tensor-spec.pbtxt | 4 + .../golden/v1/tensorflow.data.-dataset.pbtxt | 2 +- ...ow.data.-fixed-length-record-dataset.pbtxt | 2 +- .../tensorflow.data.-t-f-record-dataset.pbtxt | 2 +- .../tensorflow.data.-text-line-dataset.pbtxt | 2 +- ...rflow.data.experimental.-csv-dataset.pbtxt | 2 +- ...ow.data.experimental.-random-dataset.pbtxt | 2 +- ...rflow.data.experimental.-sql-dataset.pbtxt | 2 +- .../v2/tensorflow.-ragged-tensor-spec.pbtxt | 4 + .../golden/v2/tensorflow.data.-dataset.pbtxt | 2 +- ...ow.data.-fixed-length-record-dataset.pbtxt | 2 +- .../tensorflow.data.-t-f-record-dataset.pbtxt | 2 +- .../tensorflow.data.-text-line-dataset.pbtxt | 2 +- ...rflow.data.experimental.-csv-dataset.pbtxt | 2 +- ...ow.data.experimental.-random-dataset.pbtxt | 2 +- ...rflow.data.experimental.-sql-dataset.pbtxt | 2 +- 23 files changed, 253 insertions(+), 110 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py index 0dd7ae1f083..d23bbbe615a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py @@ -48,7 +48,7 @@ def _format_record(array, sparse): return { "values": array, "indices": [[i] for i in range(len(array))], - "dense_shape": (len(array),) + "dense_shape": [len(array),] } return array @@ -402,13 +402,16 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase, bucket_size = 10 def _build_dataset(): - input_data = [range(i+1) for i in 
range(min_len, max_len)] + input_data = [list(range(i + 1)) for i in range(min_len, max_len)] + def generator_fn(): for record in input_data: yield _format_record(record, sparse=True) + dataset = dataset_ops.Dataset.from_generator( generator=generator_fn, output_types=_get_record_type(sparse=True)) + dataset = dataset.map(_to_sparse_tensor) return dataset diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py index d320b281136..288d0e694f2 100644 --- a/tensorflow/python/data/kernel_tests/from_generator_test.py +++ b/tensorflow/python/data/kernel_tests/from_generator_test.py @@ -28,7 +28,12 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test @@ -241,7 +246,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaisesOpError("The expected type was int64"): + with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) self.assertAllEqual([7, 8, 9], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -261,7 +266,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaisesOpError(r"element of shape \(3,\) was expected"): + with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) self.assertAllEqual([11, 12, 13], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -282,11 +287,9 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual((1, 2), self.evaluate(get_next())) self.assertEqual((3, 4), self.evaluate(get_next())) - with self.assertRaisesOpError( - r"The expected structure was \(tf\.int64, tf\.int64\)"): + with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) - with self.assertRaisesOpError( - r"The expected structure was \(tf\.int64, tf\.int64\)"): + with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) self.assertEqual((9, 10), self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -405,8 +408,12 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): stateful=True) dummy = constant_op.constant(37) - dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x, - finalize_fn).take(2) + + dataset = dataset_ops._GeneratorDataset( + dummy, lambda x: x, lambda x: x, finalize_fn, + tensor_spec.TensorSpec((), dtypes.int32)) + + dataset = dataset.take(2) get_next = self.getNext(dataset) self.assertAllEqual(37, self.evaluate(get_next())) @@ -428,6 +435,46 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([20], self.evaluate(get_next())) + @combinations.generate(test_base.default_test_combinations()) + def 
testFromGeneratorRaggedTensor(self): + + def generator(): + yield ragged_factory_ops.constant([[1, 2], [3]], + dtype=dtypes.int64, + ragged_rank=1) + + dataset = dataset_ops.Dataset.from_generator( + generator, + output_signature=ragged_tensor.RaggedTensorSpec( + shape=(2, None), dtype=dtypes.int64)) + get_next = self.getNext(dataset) + + ret = get_next() + + self.assertIsInstance(ret, ragged_tensor.RaggedTensor) + self.assertAllEqual([1, 2, 3], ret.values) + + @combinations.generate(test_base.default_test_combinations()) + def testFromGeneratorSparseTensor(self): + + def generator(): + yield sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], + values=constant_op.constant([1, 2], dtype=dtypes.int64), + dense_shape=[3, 4]) + + dataset = dataset_ops.Dataset.from_generator( + generator, + output_signature=sparse_tensor.SparseTensorSpec([3, 4], dtypes.int64)) + + get_next = self.getNext(dataset) + + ret = get_next() + + self.assertIsInstance(ret, sparse_tensor.SparseTensor) + self.assertAllEqual([[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]], + sparse_ops.sparse_tensor_to_dense(ret)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 36689ed75fb..94b50a7864d 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -946,7 +946,9 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): @def_function.function def fn(): - dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn) + output_spec = tensor_spec.TensorSpec((), dtypes.int64) + dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn, + output_spec) iterator = iter(dataset) next(iterator) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 32ab469363e..9eb38bfc0d1 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -408,8 +408,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): def element_spec(self): """The type specification of an element of this dataset. - >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) - >>> dataset.element_spec + >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).element_spec TensorSpec(shape=(), dtype=tf.int32, name=None) Returns: @@ -675,27 +674,48 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): del self._iterators[iterator_id] @staticmethod - def from_generator(generator, output_types, output_shapes=None, args=None): + @deprecation.deprecated_args(None, "Use output_signature instead", + "output_types", "output_shapes") + def from_generator(generator, + output_types=None, + output_shapes=None, + args=None, + output_signature=None): """Creates a `Dataset` whose elements are generated by `generator`. The `generator` argument must be a callable object that returns an object that supports the `iter()` protocol (e.g. a generator function). - The elements generated by `generator` must be compatible with the given - `output_types` and (optional) `output_shapes` arguments. - >>> import itertools - >>> + The elements generated by `generator` must be compatible with either the + given `output_signature` argument or with the given `output_types` and + (optionally) `output_shapes` arguments whichiver was specified. + + The recommended way to call `from_generator` is to use the + `output_signature` argument. 
In this case the output will be assumed to + consist of objects with the classes, shapes and types defined by + `tf.TypeSpec` objects from `output_signature` argument: + >>> def gen(): - ... for i in itertools.count(1): - ... yield (i, [1] * i) + ... ragged_tensor = tf.ragged.constant([[1, 2], [3]], + ... ragged_rank=1, + ... dtype=tf.int64) + ... yield 42, ragged_tensor >>> >>> dataset = tf.data.Dataset.from_generator( ... gen, - ... (tf.int64, tf.int64), - ... (tf.TensorShape([]), tf.TensorShape([None]))) + ... output_signature=( + ... tf.TensorSpec(shape=(), dtype=tf.int64), + ... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int64))) >>> - >>> list(dataset.take(3).as_numpy_iterator()) - [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))] + >>> list(dataset.take(1)) + [(, + )] + + There is also a deprecated way to call `from_generator` by either with + `output_types` argument alone or together with `output_shapes` argument. + In this case the output of the function will be assumed to consist of + `tf.Tensor` objects with with the types defined by `output_types` and with + the shapes which are either unknown or defined by `output_shapes`. Note: The current implementation of `Dataset.from_generator()` uses `tf.numpy_function` and inherits the same constraints. In particular, it @@ -719,31 +739,56 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): `iter()` protocol. If `args` is not specified, `generator` must take no arguments; otherwise it must take as many arguments as there are values in `args`. - output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element yielded by `generator`. + output_types: (Optional.) A nested structure of `tf.DType` objects + corresponding to each component of an element yielded by `generator`. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element yielded by `generator`. args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated and passed to `generator` as NumPy-array arguments. + output_signature: (Optional.) A nested structure of `tf.TypeSpec` objects + corresponding to each component of an element yielded by `generator`. Returns: Dataset: A `Dataset`. 
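As a quick side-by-side reference for the two calling conventions described above, a minimal usage sketch (not part of the patch; the generator and its values are made up for illustration, and it assumes a TensorFlow build that includes this change):

import tensorflow as tf

def gen():
  # Yields a scalar label and a variable-length feature vector.
  yield 0, [1.0, 2.0, 3.0]
  yield 1, [4.0]

# Recommended: describe each component with a tf.TypeSpec.
ds_new = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.int64),
        tf.TensorSpec(shape=(None,), dtype=tf.float32)))

# Deprecated: plain dtypes, optionally accompanied by shapes.
ds_old = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.int64, tf.float32),
    output_shapes=(tf.TensorShape([]), tf.TensorShape([None])))

# Both datasets report the same element_spec.
print(ds_new.element_spec)
print(ds_old.element_spec)

Iterating either dataset yields the same elements; only the way the element structure is declared differs, and only the `output_signature` form can describe composite elements such as ragged or sparse tensors.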
""" if not callable(generator): raise TypeError("`generator` must be callable.") - if output_shapes is None: - output_shapes = nest.map_structure( - lambda _: tensor_shape.TensorShape(None), output_types) + + if output_signature is not None: + if output_types is not None: + raise TypeError("`output_types` can not be used together with " + "`output_signature`") + if output_shapes is not None: + raise TypeError("`output_shapes` can not be used together with " + "`output_signature`") + if not all( + isinstance(_, type_spec.TypeSpec) + for _ in nest.flatten(output_signature)): + raise TypeError("All the elements of `output_siganture` must be " + "a `tf.TypeSpec` objects.") else: - output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) + if output_types is None and output_shapes is not None: + raise TypeError("`output_shapes` can not be used alone without " + "`output_types`") + + if output_signature is None: + if output_shapes is None: + output_shapes = nest.map_structure( + lambda _: tensor_shape.TensorShape(None), output_types) + else: + output_shapes = nest.map_structure_up_to(output_types, + tensor_shape.as_shape, + output_shapes) + output_signature = nest.map_structure_up_to(output_types, + tensor_spec.TensorSpec, + output_shapes, output_types) + if args is None: args = () else: args = tuple(ops.convert_n_to_tensor(args, name="args")) - flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)] - flattened_shapes = nest.flatten(output_shapes) + flat_output_types = structure.get_flat_tensor_types(output_signature) generator_state = DatasetV2._GeneratorState(generator) @@ -781,56 +826,41 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): """A `py_func` that will be called to invoke the iterator.""" # `next()` raises `StopIteration` when there are no more # elements remaining to be generated. - values = next(generator_state.get_iterator(iterator_id)) + values = next(generator_state.get_iterator(iterator_id.numpy())) + + def serialize_structure(s): + return nest.map_structure(lambda ts: ts._serialize(), s) # pylint: disable=protected-access - # Use the same _convert function from the py_func() implementation to - # convert the returned values to arrays early, so that we can inspect - # their values. try: - flattened_values = nest.flatten_up_to(output_types, values) + output_dtypes = nest.map_structure(lambda t: t.dtype, + output_signature) + values = structure.normalize_element(values, dtypes=output_dtypes) except (TypeError, ValueError): - six.reraise(TypeError, TypeError( - "`generator` yielded an element that did not match the expected " - "structure. The expected structure was %s, but the yielded " - "element was %s." % (output_types, values)), sys.exc_info()[2]) - ret_arrays = [] - for ret, dtype in zip(flattened_values, flattened_types): - try: - ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access - ret, dtype=dtype.as_numpy_dtype)) - except (TypeError, ValueError): - six.reraise(TypeError, TypeError( - "`generator` yielded an element that could not be converted to " - "the expected type. The expected type was %s, but the yielded " - "element was %s." % (dtype.name, ret)), sys.exc_info()[2]) + six.reraise( + TypeError, + TypeError( + "`generator` yielded an element that did not match the " + "expected structure. The expected structure was %s, but the " + "yielded element was %s." 
% + (serialize_structure(output_signature), values)), + sys.exc_info()[2]) - # Additional type and shape checking to ensure that the components - # of the generated element match the `output_types` and `output_shapes` - # arguments. - for (ret_array, expected_dtype, expected_shape) in zip( - ret_arrays, flattened_types, flattened_shapes): - if ret_array.dtype != expected_dtype.as_numpy_dtype: - raise TypeError( - "`generator` yielded an element of type %s where an element " - "of type %s was expected." % (ret_array.dtype, - expected_dtype.as_numpy_dtype)) - if not expected_shape.is_compatible_with(ret_array.shape): - raise ValueError( - "`generator` yielded an element of shape %s where an element " - "of shape %s was expected." % (ret_array.shape, expected_shape)) + values_spec = structure.type_spec_from_value(values) - return ret_arrays + if not structure.are_compatible(values_spec, output_signature): + raise TypeError( + "`generator` yielded an element of TypeSpec%s where an element " + "of TypeSpec%s was expected." % + (serialize_structure(values_spec), + serialize_structure(output_signature))) - flat_values = script_ops.numpy_function(generator_py_func, - [iterator_id_t], flattened_types) + return structure.to_tensor_list(output_signature, values) - # The `py_func()` op drops the inferred shapes, so we add them back in - # here. - if output_shapes is not None: - for ret_t, shape in zip(flat_values, flattened_shapes): - ret_t.set_shape(shape) - - return nest.pack_sequence_as(output_types, flat_values) + return script_ops._eager_py_func( # pylint: disable=protected-access + generator_py_func, + inp=[iterator_id_t], + Tout=flat_output_types, + use_tape_cache=False) def finalize_fn(iterator_id_t): """Releases host-side state for the iterator with ID `iterator_id_t`.""" @@ -856,7 +886,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): # given ID, and raises StopIteration when that iterator contains no # more elements. return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn, - finalize_fn) + finalize_fn, output_signature) # A single-element dataset that, each time it is evaluated, contains a # freshly-generated and unique (for the returned dataset) int64 @@ -2278,9 +2308,14 @@ class DatasetV1(DatasetV2): @staticmethod @functools.wraps(DatasetV2.from_generator) - def from_generator(generator, output_types, output_shapes=None, args=None): - return DatasetV1Adapter(DatasetV2.from_generator( - generator, output_types, output_shapes, args)) + def from_generator(generator, + output_types=None, + output_shapes=None, + args=None, + output_signature=None): + return DatasetV1Adapter( + DatasetV2.from_generator(generator, output_types, output_shapes, args, + output_signature)) @staticmethod @functools.wraps(DatasetV2.range) @@ -3261,7 +3296,8 @@ class StructuredFunctionWrapper(object): class _GeneratorDataset(DatasetSource): """A `Dataset` that generates elements by invoking a function.""" - def __init__(self, init_args, init_func, next_func, finalize_func): + def __init__(self, init_args, init_func, next_func, finalize_func, + output_signature): """Constructs a `_GeneratorDataset`. Args: @@ -3275,6 +3311,8 @@ class _GeneratorDataset(DatasetSource): finalize_func: A TensorFlow function that will be called on the result of `init_func` immediately before a C++ iterator over this dataset is destroyed. The return value is ignored. + output_signature: A nested structure of `tf.TypeSpec` objects describing + the output of `next_func`. 
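The signature documented here is the same `output_signature` that the rewritten `generator_py_func` above validates every yielded element against: components are first converted using the dtypes taken from the signature, and the element is then rejected if its inferred spec is incompatible. A rough public-API sketch of that check (illustrative only, not the actual code path; the signature and sample values are made up):

import numpy as np
import tensorflow as tf

signature = (tf.TensorSpec(shape=(), dtype=tf.int64),
             tf.TensorSpec(shape=(3,), dtype=tf.float32))

def matches(element):
  # Step 1: convert each component using the dtype from the signature.
  tensors = tf.nest.map_structure(
      lambda spec, value: tf.convert_to_tensor(value, dtype=spec.dtype),
      signature, element)
  # Step 2: check the converted components against the signature.
  return all(spec.is_compatible_with(t)
             for spec, t in zip(tf.nest.flatten(signature),
                                tf.nest.flatten(tensors)))

print(matches((7, np.array([1.0, 2.0, 3.0], np.float32))))  # True
print(matches((7, np.array([1.0, 2.0], np.float32))))       # False (bad shape)

In the real implementation the mismatch is raised inside the wrapped py_func, so it reaches the caller as `errors.InvalidArgumentError`; the updated tests above assert exactly that instead of matching specific error strings.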
""" self._init_args = init_args @@ -3294,6 +3332,9 @@ class _GeneratorDataset(DatasetSource): finalize_func, self._transformation_name(), input_structure=self._init_func.output_structure) + + self._output_signature = output_signature + variant_tensor = gen_dataset_ops.generator_dataset( structure.to_tensor_list(self._init_structure, self._init_args) + self._init_func.function.captured_inputs, @@ -3307,7 +3348,7 @@ class _GeneratorDataset(DatasetSource): @property def element_spec(self): - return self._next_func.output_structure + return self._output_signature def _transformation_name(self): return "Dataset.from_generator()" diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 87825005069..ee6151742f6 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -67,7 +67,7 @@ def _RaggedTensorStructure(dtype, shape, ragged_rank): # TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once # it is a subclass of `CompositeTensor`. -def normalize_element(element): +def normalize_element(element, dtypes=None): """Normalizes a nested structure of element components. * Components matching `SparseTensorSpec` are converted to `SparseTensor`. @@ -78,6 +78,10 @@ def normalize_element(element): Args: element: A nested structure of individual components. + dtypes: (Optional.) A nested structure of `tf.DType` objects corresponding + to each component of `element`. If specified, it will be used to set the + exact type of output tensor when converting input components which + are not tensors themselves (e.g. numpy arrays, native python types, etc.) Returns: A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`, @@ -85,17 +89,21 @@ def normalize_element(element): """ components = nest.flatten(element) normalized_components = [] + if dtypes is None: + flattened_dtypes = [None] * len(components) + else: + flattened_dtypes = nest.flatten(dtypes) with ops.name_scope("normalize_element"): # Imported here to avoid circular dependency. from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top - for i, t in enumerate(components): + for i, (t, dtype) in enumerate(zip(components, flattened_dtypes)): try: spec = type_spec_from_value(t, use_fallback=False) except TypeError: # TypeError indicates it was not possible to compute a `TypeSpec` for # the value. As a fallback try converting the value to a tensor. 
normalized_components.append( - ops.convert_to_tensor(t, name="component_%d" % i)) + ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) else: if isinstance(spec, sparse_tensor.SparseTensorSpec): normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) @@ -112,7 +120,7 @@ def normalize_element(element): normalized_components.append(t) else: normalized_components.append( - ops.convert_to_tensor(t, name="component_%d" % i)) + ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) return nest.pack_sequence_as(element, normalized_components) diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 78be28b7ec6..6d365210308 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -2085,6 +2085,11 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): else: return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value) + @property + def dtype(self): + """The `tf.dtypes.DType` specified by this type for the RaggedTensor.""" + return self._dtype + def _serialize(self): return (self._shape, self._dtype, self._ragged_rank, self._row_splits_dtype) diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index bee85dc4a5b..dd53b388bd4 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -70,7 +70,7 @@ def _maybe_copy_to_context_device(tensor, device_name): class EagerFunc(object): """A wrapper for a function owned by an EagerPyFunc.""" - def __init__(self, func, Tout, is_grad_func): + def __init__(self, func, Tout, is_grad_func, use_tape_cache=True): """Constructs an EagerFunc. Args: @@ -79,10 +79,12 @@ class EagerFunc(object): None. is_grad_func: Whether this EagerFunc is the gradient of another EagerPyFunc. + use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`. """ self._func = func self._out_dtypes = Tout self._is_grad_func = is_grad_func + self._use_tape_cache = use_tape_cache def _convert(self, value, dtype): """Converts `value` to a tensor of type `dtype`, with error checking. @@ -146,7 +148,8 @@ class EagerFunc(object): else: outputs = _maybe_copy_to_context_device( self._convert(ret, dtype=self._out_dtypes[0]), device_name) - tape_cache[compat.as_bytes(token)] = (tape, args, outputs) + if self._use_tape_cache: + tape_cache[compat.as_bytes(token)] = (tape, args, outputs) return outputs @@ -276,7 +279,8 @@ def _internal_py_func(func, stateful=None, eager=False, is_grad_func=False, - name=None): + name=None, + use_tape_cache=True): """See documentation for py_func and eager_py_func.""" if not callable(func): raise ValueError("Expected func to be callable, got func of type {}".format( @@ -292,7 +296,7 @@ def _internal_py_func(func, Tout = [Tout] if eager: - func = EagerFunc(func, Tout, is_grad_func) + func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache) # Tying the registered function's lifetime with the current default graph is # not reliable. For example, Estimator-based binaries may switch graphs in @@ -369,6 +373,35 @@ def _EagerPyFuncGrad(op, *dy): is_grad_func=True) +# NOTE(lithuak): this function as a layer of indirection was added with one +# specific purpose: as a workaround for github issue #35084. +# It does all the same as `eager_py_func` used to do with one difference: +# it can be used to instruct underlying EagerFunc not to use `tape_cache` +# to avoid memory leak. 
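The cache in question stores the gradient tape together with the call's arguments and outputs, which is what keeps the public `tf.py_function` differentiable; `Dataset.from_generator` has no need to differentiate through its generator, so it can safely opt out with `use_tape_cache=False`. A small snippet of the behaviour the cache preserves for the public API (illustrative only, not part of the patch):

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  # tf.py_function remains differentiable because the eager call records a
  # tape that the cached entry keeps alive for the backward pass.
  y = tf.py_function(func=lambda v: v * v, inp=[x], Tout=tf.float32)
print(tape.gradient(y, x))  # 2 * x = 6.0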
When the issue #35084 is fixed - this function should +# be removed, its body should be moved back to become the body of +# `eager_py_func` and all the call sites should be reverted to +# using `eager_py_func` without `use_tape_cache` argument of any value. +def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True): + """Wraps a python function into a TensorFlow op that executes it eagerly.""" + if ops.executing_eagerly_outside_functions(): + with ops.device(context.context().host_address_space()): + return _internal_py_func( + func=func, + inp=inp, + Tout=Tout, + eager=True, + name=name, + use_tape_cache=use_tape_cache) + + return _internal_py_func( + func=func, + inp=inp, + Tout=Tout, + eager=True, + name=name, + use_tape_cache=use_tape_cache) + + @tf_export("py_function") def eager_py_func(func, inp, Tout, name=None): """Wraps a python function into a TensorFlow op that executes it eagerly. @@ -449,12 +482,8 @@ def eager_py_func(func, inp, Tout, name=None): A list of `Tensor` or a single `Tensor` which `func` computes; an empty list if `func` returns None. """ - if ops.executing_eagerly_outside_functions(): - with ops.device(context.context().host_address_space()): - return _internal_py_func( - func=func, inp=inp, Tout=Tout, eager=True, name=name) - - return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name) + return _eager_py_func( + func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True) def py_func_common(func, inp, Tout, stateful=True, name=None): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt index 2ec5bb46ed1..029d04fee9b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "dtype" + mtype: "" + } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 872d03770ed..841b142c082 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -63,7 +63,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index a84c5aa3caf..42225d3f566 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', 
\'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index a3862ae2a19..81a1c7fbd9c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index baaaf7ea7be..e9e3962a498 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt index afdeea5d018..20712fb14a7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt index 76113c5e01d..c139c6b9cc8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt index 1a11026fd19..41a67db62dc 
100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt index 2ec5bb46ed1..029d04fee9b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "dtype" + mtype: "" + } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index d9414c31e7d..3cb50feac2d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -46,7 +46,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 28efdb6e855..9e2fa7255fd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index c9553efb58c..1bd43d28bc4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -47,7 +47,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 16a878144ae..2e295c44b5f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index d1d2db041e0..91175909f77 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index 18a6b8cbd1b..09ed74d3460 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index 0cf3d94ba68..c245d563e9e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_tensor_slices" From ca19cb9b6a781ecd1252938db205b794a898c25d Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Sat, 21 Mar 2020 10:05:16 -0700 Subject: [PATCH 377/492] [tf.data] Completing migration to new internal APIs that make it possible to overriding policy for handling external state during 
iterator checkpointing. PiperOrigin-RevId: 302206854 Change-Id: I9c3f55015e22dcab51af816d727ae243beda314f --- tensorflow/core/framework/dataset.h | 19 +------------------ .../core/kernels/data/cache_dataset_ops.cc | 3 ++- .../experimental/matching_files_dataset_op.cc | 3 ++- 3 files changed, 5 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 25cc8fd759e..9cabcb08490 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -628,14 +628,6 @@ class IteratorBase { return input->SaveInternal(ctx, writer); } - // TODO(jsimsa): Remove this override when all callers are migrated to the - // override that uses SerializationContext. - Status SaveInput(IteratorStateWriter* writer, - const std::unique_ptr& input) { - SerializationContext ctx(/*params=*/{}); - return input->SaveInternal(&ctx, writer); - } - // This is needed so that sub-classes of IteratorBase can call // `RestoreInternal` on their input iterators. Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader, @@ -648,16 +640,7 @@ class IteratorBase { // This method is used to store the state of the iterator in a checkpoint. // implementations have an override. virtual Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) { - return SaveInternal(writer); - } - - // TODO(jsimsa): Remove this override when all subclasses are migrated to the - // override that accepts SerializationContext and make that override pure - // virtual. - virtual Status SaveInternal(IteratorStateWriter* writer) { - return errors::Unimplemented("checkpointing is not supported"); - } + IteratorStateWriter* writer) = 0; // Restores the state of this iterator. // diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index f99ac114dc2..707800bc896 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -944,7 +944,8 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index 9ba44aaf909..90a61d72597 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -192,7 +192,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name("current_pattern_index"), current_pattern_index_)); From 874adafd91b0719b89ed1b77cfc472251a82d36b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Mar 2020 10:28:14 -0700 Subject: [PATCH 378/492] [tf.data] Completing migration to new internal APIs that make it possible to overriding policy for handling external state during iterator checkpointing. 
PiperOrigin-RevId: 302208411 Change-Id: Iaa4f18fe6c9f02bab7b440a34ecbe61f7186201a --- tensorflow/core/framework/dataset.h | 19 ++++++++++++++++++- .../core/kernels/data/cache_dataset_ops.cc | 3 +-- .../experimental/matching_files_dataset_op.cc | 3 +-- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 9cabcb08490..25cc8fd759e 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -628,6 +628,14 @@ class IteratorBase { return input->SaveInternal(ctx, writer); } + // TODO(jsimsa): Remove this override when all callers are migrated to the + // override that uses SerializationContext. + Status SaveInput(IteratorStateWriter* writer, + const std::unique_ptr& input) { + SerializationContext ctx(/*params=*/{}); + return input->SaveInternal(&ctx, writer); + } + // This is needed so that sub-classes of IteratorBase can call // `RestoreInternal` on their input iterators. Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader, @@ -640,7 +648,16 @@ class IteratorBase { // This method is used to store the state of the iterator in a checkpoint. // implementations have an override. virtual Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) = 0; + IteratorStateWriter* writer) { + return SaveInternal(writer); + } + + // TODO(jsimsa): Remove this override when all subclasses are migrated to the + // override that accepts SerializationContext and make that override pure + // virtual. + virtual Status SaveInternal(IteratorStateWriter* writer) { + return errors::Unimplemented("checkpointing is not supported"); + } // Restores the state of this iterator. // diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 707800bc896..f99ac114dc2 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -944,8 +944,7 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index 90a61d72597..9ba44aaf909 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -192,8 +192,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name("current_pattern_index"), current_pattern_index_)); From 1210b521aa2226b01ee1bd9528a8f247b7283efb Mon Sep 17 00:00:00 2001 From: Yi Situ Date: Sat, 21 Mar 2020 12:35:00 -0700 Subject: [PATCH 379/492] Remove macro guards. 
PiperOrigin-RevId: 302216728 Change-Id: Ia9ba9c21351b63c36665a611b0543bd1e0fe2450 --- .../profiler/convert/xplane_to_kernel_stats_db.cc | 15 ++++++++------- .../core/profiler/utils/kernel_stats_utils.cc | 7 +++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc index 0deb4309ff8..8d2f95d8fc3 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" #include "tensorflow/core/profiler/utils/event_span.h" @@ -66,13 +67,13 @@ KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb( kernel.set_op_name(tf_op.name.data(), tf_op.name.size()); bool tensor_core_eligible = IsEinsumTensorCoreEligible(equation) || IsOpTensorCoreEligible(kernel.op_name()); -#if defined(VLOG_IF) - VLOG_IF(1, - !tensor_core_eligible && kernel.is_kernel_using_tensor_core()) - << "Detected new Op using TensorCores: " << kernel.op_name() - << std::endl; -#endif // defined(VLOG_IF) - tensor_core_eligible |= kernel.is_kernel_using_tensor_core(); + + if (!tensor_core_eligible && kernel.is_kernel_using_tensor_core()) { + VLOG(1) << "Detected new Op using TensorCores: " << kernel.op_name() + << std::endl; + tensor_core_eligible = true; + } + kernel.set_is_op_tensor_core_eligible(tensor_core_eligible); } } diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc index 4721047a856..14038d5c177 100644 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc +++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc @@ -79,10 +79,9 @@ bool IsKernelUsingTensorCore(absl::string_view kernel_name) { // turing_fp16_s1688cudnn_fp16 bool possible_tensor_kernel = absl::StrContains(kernel_name, "884") || absl::StrContains(kernel_name, "1688"); -#if defined(VLOG_IF) - VLOG_IF(1, possible_tensor_kernel) - << "Possible tensor kernel: " << kernel_name << "\n"; -#endif // defined(VLOG_IF) + if (possible_tensor_kernel) { + VLOG(1) << "Possible tensor kernel: " << kernel_name << "\n"; + } return (absl::StartsWith(kernel_name, "volta_i884") || absl::StartsWith(kernel_name, "volta_h884") || From 827943f3d5b6ab8d3afce05c2d12fa981130eba5 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 21 Mar 2020 17:11:18 -0700 Subject: [PATCH 380/492] Emit unconverted ops with attributes Also dedup based on op an attributes. Don't report large element attributes. PiperOrigin-RevId: 302236543 Change-Id: I0c09293a69a57d18723614a2b2adc334260f18f6 --- .../mlir/lite/flatbuffer_translate.cc | 50 +++++++++++++++---- .../tests/mlir2flatbuffer/disable_custom.mlir | 5 +- .../tests/mlir2flatbuffer/disable_flex.mlir | 5 +- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index a75c1b3bab2..e8337d4a79f 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project @@ -527,8 +528,8 @@ class Translator { const Dialect* tfl_dialect_; // The failed ops during legalization. - std::vector failed_flex_ops_; - std::vector failed_custom_ops_; + std::set failed_flex_ops_; + std::set failed_custom_ops_; }; std::string Translator::UniqueName(mlir::Value val) { @@ -1083,11 +1084,39 @@ Optional> Translator::BuildOperator( return llvm::None; } } else { + // Create description of operation that could not be converted. + const int kLargeElementsAttr = 16; + std::string op_str; + llvm::raw_string_ostream os(op_str); + inst->getName().print(os); + // Print out attributes except for large elementsattributes (which should + // rarely be the cause why the legalization didn't happen). + if (!inst->getAttrList().getAttrs().empty()) { + os << " {"; + bool first = true; + for (auto& named_attr : inst->getAttrList().getDictionary()) { + os << (!first ? ", " : ""); + first = false; + named_attr.first.print(os); + os << " = "; + if (auto element_attr = named_attr.second.dyn_cast()) { + if (element_attr.getNumElements() <= kLargeElementsAttr) { + element_attr.print(os); + } else { + os << ""; + } + } else { + named_attr.second.print(os); + } + } + os << "}"; + } + // Insert failed op to `flex_ops` or `custom_ops`. if (IsWhitelistedFlexOp(node_def->op())) { - failed_flex_ops_.push_back(node_def->op()); + failed_flex_ops_.insert(os.str()); } else { - failed_custom_ops_.push_back(node_def->op()); + failed_custom_ops_.insert(os.str()); } return inst->emitOpError("is neither a custom op nor a flex op"), llvm::None; @@ -1385,19 +1414,20 @@ Optional Translator::TranslateInternal() { } if (first_failed_func != -1) { - std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, ","); - std::string failed_custom_ops_list = absl::StrJoin(failed_custom_ops_, ","); + std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t"); + std::string failed_custom_ops_list = + absl::StrJoin(failed_custom_ops_, "\n\t"); std::string err; if (!failed_flex_ops_list.empty()) err += "Ops that can be supported by the flex runtime (enabled via setting " - "the -emit-select-tf-ops flag): " + - failed_flex_ops_list + "."; + "the -emit-select-tf-ops flag):\n\t" + + failed_flex_ops_list; if (!failed_custom_ops_list.empty()) err += "Ops that need custom implementation (enabled via setting the " - "-emit-custom-ops flag): " + - failed_custom_ops_list + "."; + "-emit-custom-ops flag):\n\t" + + failed_custom_ops_list; auto& failed_region = named_regions[first_failed_func]; return failed_region.second->getParentOp()->emitError() diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir index 046fe6ac9ef..23d04c40ec4 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir @@ -1,8 +1,9 @@ -// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s +// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s --dump-input-on-failure // CHECK: error: 'tf.MyCustomOp' op is neither a custom op nor a flex op // CHECK: error: 
failed while converting: 'main' -// CHECK: Ops that need custom implementation (enabled via setting the -emit-custom-ops flag): MyCustomOp. +// CHECK: Ops that need custom implementation (enabled via setting the -emit-custom-ops flag): +// CHECK: tf.MyCustomOp {name = "MyCustomOp"} func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir index e77cd69cbc7..1c2e918f61e 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir @@ -1,8 +1,9 @@ -// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s +// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s --dump-input-on-failure // CHECK: error: 'tf.Div' op is neither a custom op nor a flex op // CHECK: error: failed while converting: 'main' -// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): Div. +// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): +// CHECK: tf.Div {name = "div"} func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): From 4e6bf4ea54ae98d3f02e17cf91920f1589665291 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Sat, 21 Mar 2020 17:49:36 -0700 Subject: [PATCH 381/492] Add explicit build dependencies for imports in graph_optimization_pass. PiperOrigin-RevId: 302239301 Change-Id: I9d1c4beaef528ce6e16e33ec1826b8b677c854c0 --- tensorflow/compiler/mlir/tensorflow/BUILD | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index bb8e3c19e9c..1cc26c9bb4d 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -196,8 +196,8 @@ cc_library( "transforms/legalize_hlo.cc", ], deps = [ - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + ":lower_tf_lib", + ":tensorflow", "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/core:framework", "@llvm-project//llvm:support", @@ -472,8 +472,13 @@ cc_library( srcs = ["transforms/graph_optimization_pass.cc"], hdrs = ["transforms/graph_optimization_pass.h"], deps = [ + ":error_util", ":tensorflow_passes", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) From f1ba29bea2b26becf17c790043815aac5571e2e2 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Sat, 21 Mar 2020 19:07:31 -0700 Subject: [PATCH 382/492] Make flatbuffer_translate_lib dynamic linked To do this, some static registered translated functions are moved to a seperated c++ file and target. Only the binaries requires these translates functions needs to link them statically. This cl also removes the tensorflow/core:lib dependence from the quantize_model target. 
PiperOrigin-RevId: 302245203 Change-Id: Ic33d4dfd8c5fe4fb5fb1f1f4232cd406f8d4e705 --- tensorflow/compiler/mlir/lite/BUILD | 46 +- .../compiler/mlir/lite/flatbuffer_export.cc | 1455 ++++++++++++++++ ...buffer_translate.h => flatbuffer_export.h} | 6 +- ...late_flags.h => flatbuffer_export_flags.h} | 6 +- .../compiler/mlir/lite/flatbuffer_import.cc | 84 +- .../mlir/lite/flatbuffer_translate.cc | 1495 +---------------- .../compiler/mlir/lite/mlir_tflite_runner.cc | 4 +- .../lite/quantization/lite/quantize_model.cc | 2 +- .../mlir/lite/sparsity/sparsify_model.cc | 2 +- .../compiler/mlir/lite/tf_tfl_translate.cc | 4 +- .../mlir/lite/tf_to_tfl_flatbuffer.cc | 2 +- tensorflow/compiler/mlir/tensorflow/BUILD | 3 +- .../mlir/tensorflow/utils/error_util.cc | 2 +- .../mlir/tensorflow/utils/error_util.h | 2 +- 14 files changed, 1586 insertions(+), 1527 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/flatbuffer_export.cc rename tensorflow/compiler/mlir/lite/{flatbuffer_translate.h => flatbuffer_export.h} (90%) rename tensorflow/compiler/mlir/lite/{flatbuffer_translate_flags.h => flatbuffer_export_flags.h} (84%) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 03cf9265f3b..446ba89a3f1 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -224,7 +224,6 @@ cc_library( deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/lite/schema:schema_fbs", "@llvm-project//llvm:support", @@ -554,14 +553,14 @@ cc_library( cc_library( name = "flatbuffer_translate_lib", srcs = [ + "flatbuffer_export.cc", "flatbuffer_import.cc", - "flatbuffer_translate.cc", "utils/convert_type.cc", ], hdrs = [ + "flatbuffer_export.h", + "flatbuffer_export_flags.h", "flatbuffer_import.h", - "flatbuffer_translate.h", - "flatbuffer_translate_flags.h", "utils/convert_type.h", ], deps = [ @@ -579,8 +578,10 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:status", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string_util", @@ -601,15 +602,32 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", ], +) + +cc_library( + name = "flatbuffer_translate_registeration", + srcs = [ + "flatbuffer_translate.cc", + ], + deps = [ + ":flatbuffer_translate_lib", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LoopOpsTransforms", + "@llvm-project//mlir:MlirTranslateMain", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], alwayslink = 1, ) tf_cc_binary( name = "flatbuffer_translate", deps = [ - ":flatbuffer_translate_lib", - "@llvm-project//mlir:LoopOpsTransforms", - "@llvm-project//mlir:MlirTranslateMain", + ":flatbuffer_translate_registeration", ], ) @@ -647,10 +665,13 @@ filegroup( tf_cc_binary( name = "tf_tfl_translate", - srcs = [":tf_tfl_translate_main"], + srcs = [ + ":tf_tfl_translate_main", + ], deps = [ ":common", ":flatbuffer_translate_lib", + ":flatbuffer_translate_registeration", ":tensorflow_lite", ":tf_tfl_passes", 
":tf_tfl_translate_cl_options", @@ -672,15 +693,18 @@ tf_cc_binary( tf_cc_binary( name = "mlir-tflite-runner", - srcs = ["mlir_tflite_runner.cc"], + srcs = [ + "mlir_tflite_runner.cc", + ], deps = [ ":flatbuffer_translate_lib", + ":flatbuffer_translate_registeration", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc new file mode 100644 index 00000000000..72e9b8c742a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -0,0 +1,1455 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Translation.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" +#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include 
"tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/tools/versioning/op_version.h" +#include "tensorflow/lite/tools/versioning/runtime_version.h" +#include "tensorflow/lite/version.h" + +using llvm::dyn_cast; +using llvm::formatv; +using llvm::isa; +using llvm::Optional; +using llvm::StringRef; +using llvm::Twine; +using mlir::Dialect; +using mlir::ElementsAttr; +using mlir::FuncOp; +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::NoneType; +using mlir::Operation; +using mlir::Region; +using mlir::StringAttr; +using mlir::TensorType; +using mlir::Type; +using mlir::UnknownLoc; +using mlir::Value; +using tensorflow::OpOrArgLocNameMapper; +using tensorflow::OpOrArgNameMapper; +using tensorflow::Status; +using tflite::flex::IsWhitelistedFlexOp; +using xla::StatusOr; + +template +using BufferOffset = flatbuffers::Offset; + +template +using VectorBufferOffset = flatbuffers::Offset>; + +using CustomOptionsOffset = VectorBufferOffset; + +namespace error = tensorflow::error; +namespace tfl = mlir::TFL; + +ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; + +// Use initial buffer size in flatbuffer builder to be same as the initial size +// used by the TOCO export. (It does not explain rationale for this choice.) +constexpr size_t kInitialBufferSize = 10240; + +// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. +// Since tflite doesn't support unsigned for other types, returns error if +// `isSigned` is set to false for other types. +static StatusOr GetTFLiteType(Type type, + bool is_signed = true) { + if (!is_signed && type.isSignlessInteger(8)) { + return tflite::TensorType_UINT8; + } + if (!is_signed) { + return Status(error::INVALID_ARGUMENT, + "'isSigned' can only be set for 8-bits integer type"); + } + switch (type.getKind()) { + case mlir::StandardTypes::F32: + return tflite::TensorType_FLOAT32; + case mlir::StandardTypes::F16: + return tflite::TensorType_FLOAT16; + case mlir::TF::TensorFlowTypes::STRING: + return tflite::TensorType_STRING; + case mlir::TF::TensorFlowTypes::QUINT8: + return tflite::TensorType_UINT8; + case mlir::StandardTypes::Complex: { + auto ftype = type.cast().getElementType(); + if (ftype && ftype.isF32()) { + return tflite::TensorType_COMPLEX64; + } + return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } + case mlir::StandardTypes::Integer: { + const auto& itype = type.cast(); + switch (itype.getWidth()) { + case 1: + return tflite::TensorType_BOOL; + case 8: + return itype.isUnsigned() ? 
tflite::TensorType_UINT8 + : tflite::TensorType_INT8; + case 16: + return tflite::TensorType_INT16; + case 32: + return tflite::TensorType_INT32; + case 64: + return tflite::TensorType_INT64; + } + } + case mlir::quant::QuantizationTypes::UniformQuantized: { + auto qtype = type.cast(); + return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + } + case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { + auto qtype = type.cast(); + return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + } + case mlir::TF::TensorFlowTypes::RESOURCE: { + // Treat tf.resource values as integer values in flatbuffer. + // TODO(b/146131919): Maybe need to have a detailed design for supporting + // other resource types beyonds hash table resources and resource + // variables. + return tflite::TensorType_INT32; + } + default: + // TFLite export fills FLOAT32 for unknown data types. Returning an error + // for now for safety and this could be revisited when required. + return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } +} + +static bool IsConst(Operation* op) { + return isa(op) || isa(op) || + isa(op) || isa(op); +} + +template +static bool HasValidTFLiteType(Value value, T& error_handler) { + // None type is allowed to represent unspecified operands. + if (value.getType().isa()) return true; + + auto type = value.getType().dyn_cast(); + if (!type) { + if (auto op = value.getDefiningOp()) { + error_handler.emitError() + << '\'' << op << "' should produce value of tensor type instead of " + << value.getType(); + return false; + } + error_handler.emitError("expected tensor type, got ") << value.getType(); + return false; + } + + Type element_type = type.getElementType(); + auto status = GetTFLiteType(element_type); + if (!status.ok()) { + return error_handler.emitError( + formatv("Failed to convert element type '{0}': {1}", + element_type, status.status().error_message())), + false; + } + return true; +} + +// Returns true if the module holds all the invariants expected by the +// Translator class. +// TODO(hinsu): Now that translation is done by making a single pass over the +// MLIR module, consider inlining these validation checks at the place where +// these invariants are assumed instead of checking upfront. +static bool IsValidTFLiteMlirModule(ModuleOp module) { + MLIRContext* context = module.getContext(); + + // Verify that module has a function named main. + FuncOp main_fn = module.lookupSymbol("main"); + if (!main_fn) { + return emitError(UnknownLoc::get(context), + "should have a function named 'main'"), + false; + } + + for (auto fn : module.getOps()) { + if (fn.getBlocks().size() != 1) { + return fn.emitError("should have exactly one basic block"), false; + } + auto& bb = fn.getBlocks().front(); + + for (auto arg : bb.getArguments()) { + if (!HasValidTFLiteType(arg, fn)) + return fn.emitError("invalid TFLite type: ") << arg.getType(), false; + } + + // Verify that all operations except the terminator have exactly one + // result of type supported by TFLite. + for (auto& inst : bb) { + if (inst.isKnownTerminator()) break; + + for (auto result : inst.getResults()) { + if (!HasValidTFLiteType(result, inst)) + return fn.emitError("invalid TFLite type: ") << result.getType(), + false; + } + } + } + + return true; +} + +static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( + ::mlir::Operation* inst) { + // We pass empty string for the original node_def name since Flex runtime + // does not care about this being set correctly on node_def. 
There is no + // "easy" (see b/120948529) way yet to get this from MLIR inst. + auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( + inst, /*name=*/"", /*ignore_unregistered_attrs=*/true); + if (!status_or_node_def.ok()) { + inst->emitOpError( + Twine("failed to obtain TensorFlow nodedef with status: " + + status_or_node_def.status().ToString())); + return {}; + } + return std::move(status_or_node_def.ValueOrDie()); +} + +// Converts a mlir padding StringRef to TfLitePadding. +// Returns llvm::None if conversion fails. +static Optional GetTflitePadding(Operation* inst, + llvm::StringRef padding) { + const tflite::Padding padding_attr = + std::move(llvm::StringSwitch(padding) + .Case("SAME", tflite::Padding_SAME) + .Case("VALID", tflite::Padding_VALID)); + if (padding_attr == tflite::Padding_SAME) { + return kTfLitePaddingSame; + } + if (padding_attr == tflite::Padding_VALID) { + return kTfLitePaddingValid; + } + + return inst->emitOpError() << "Invalid padding attribute: " << padding, + llvm::None; +} + +// Extracts TfLitePoolParams from a TFL custom op. +// Template parameter, TFLOp, should be a TFL custom op containing attributes +// generated from TfLitePoolParams. +// Returns llvm::None if conversion fails. +template +static Optional GetTflitePoolParams(Operation* inst, + TFLOp op) { + TfLitePoolParams pool_params; + pool_params.stride_height = op.stride_h().getSExtValue(); + pool_params.stride_width = op.stride_w().getSExtValue(); + pool_params.filter_height = op.filter_h().getSExtValue(); + pool_params.filter_width = op.filter_w().getSExtValue(); + const auto padding = GetTflitePadding(inst, op.padding()); + if (padding) { + pool_params.padding = *padding; + pool_params.activation = kTfLiteActNone; + pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; + return pool_params; + } + + return llvm::None; +} + +namespace { + +// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. +class Translator { + public: + // Translates the given MLIR module into TFLite FlatBuffer format and returns + // the serialized output. Returns llvm::None on unsupported, invalid inputs or + // internal error. + static Optional Translate( + ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); + + private: + enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; + explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, + bool emit_select_tf_ops, bool emit_custom_ops, + OpOrArgNameMapper* op_or_arg_name_mapper) + : module_(module), + name_mapper_(*op_or_arg_name_mapper), + builder_(kInitialBufferSize) { + // The first buffer must be empty according to the schema definition. + empty_buffer_ = tflite::CreateBuffer(builder_); + buffers_.push_back(empty_buffer_); + if (emit_builtin_tflite_ops) { + enabled_op_types_.emplace(OpType::kTfliteBuiltin); + } + if (emit_select_tf_ops) { + enabled_op_types_.emplace(OpType::kSelectTf); + } + if (emit_custom_ops) { + enabled_op_types_.emplace(OpType::kCustomOp); + } + tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); + tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); + } + + Optional TranslateInternal(); + + // Returns TFLite buffer populated with constant value if the operation is + // TFLite constant operation. Otherwise, returns an empty buffer. Emits error + // and returns llvm::None on failure. + Optional> BuildBuffer(Operation* inst); + + // Build TFLite tensor from the given type. 
This function is for tfl.lstm + // intermediates, which should have UniformQuantizedType. + Optional> BuildTensorFromType( + mlir::Type type, const std::string& name); + + // Builds TFLite tensor from the given value. `buffer_idx` is index of the + // corresponding buffer. Emits error and returns llvm::None on failure. + Optional> BuildTensor(Value value, + const std::string& name, + unsigned buffer_idx); + + // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove + // these 2 functions here. + BufferOffset BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results); + BufferOffset BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results); + + // Build while operator where cond & body are regions. + Optional> BuildWhileOperator( + mlir::TFL::WhileOp op, const std::vector& operands, + const std::vector& results); + + // Builds custom operators. + // Templated on a) data type of custom_option to be stored into flatbuffer, + // and b) TFL custom op type. + template + BufferOffset BuildCustomOperator( + const CustomOptionType& custom_option, const std::string& opcode_name, + TFLOp op, const std::vector& operands, + const std::vector& results); + + BufferOffset BuildNumericVerifyOperator( + mlir::TFL::NumericVerifyOp op, const std::vector& operands, + const std::vector& results); + Optional> + BuildConvolution2DTransposeBiasOperator( + Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, + const std::vector& operands, + const std::vector& results); + Optional> BuildMaxPoolingWithArgMax2DOperator( + Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, + const std::vector& operands, + const std::vector& results); + Optional> BuildMaxUnpooling2DOperator( + Operation* inst, mlir::TFL::MaxUnpooling2DOp op, + const std::vector& operands, + const std::vector& results); + + Optional CreateFlexOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + Optional CreateCustomOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + std::unique_ptr CreateFlexBuilderWithNodeAttrs( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + // Returns opcode index for op identified by the op_name, if already + // available. Otherwise, creates a new OperatorCode using the given `builtin` + // operator and associates it with `op_name`. + uint32_t GetOpcodeIndex(const std::string& op_name, + tflite::BuiltinOperator builtin); + + // Builds operator for the given operation with specified operand and result + // tensor indices. Emits an error and returns llvm::None on failure. + Optional> BuildOperator( + Operation* inst, const std::vector& operands, + const std::vector& results, + const std::vector& intermediates); + + // Build a subgraph with a given name out of the region either corresponding + // to a function's body or while op. + Optional> BuildSubGraph( + const std::string& name, Region* region); + + // Builds Metadata with the given `name` and buffer `content`. + BufferOffset BuildMetadata(StringRef name, + StringRef content); + + // Encodes the `tfl.metadata` dictionary attribute of the module to the + // metadata section in the final model. + Optional>> + CreateMetadataVector(); + + // Uses the tf.entry_function attribute (if set) to initialize the op to name + // mapping. 
+ void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); + + // Determines if the specified operation op's operand at operand_index + // is marked as a stateful operand. + bool IsStatefulOperand(mlir::Operation* op, int operand_index); + + // Returns a unique name for `val`. + std::string UniqueName(mlir::Value val); + + ModuleOp module_; + + tensorflow::OpOrArgNameMapper& name_mapper_; + + flatbuffers::FlatBufferBuilder builder_; + BufferOffset empty_buffer_; + + std::vector> buffers_; + + // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. + absl::flat_hash_map opcode_index_map_; + std::vector> opcodes_; + + // Maps function name to index of the corresponding subgraph in the FlatBuffer + // model. + absl::flat_hash_map subgraph_index_map_; + absl::flat_hash_set enabled_op_types_; + + // Points to TensorFlow and TFLite dialects, respectively. nullptr if the + // dialect is not registered. + const Dialect* tf_dialect_; + const Dialect* tfl_dialect_; + + // The failed ops during legalization. + std::set failed_flex_ops_; + std::set failed_custom_ops_; +}; + +std::string Translator::UniqueName(mlir::Value val) { + return std::string(name_mapper_.GetUniqueName(val)); +} + +Optional> Translator::BuildBuffer( + Operation* inst) { + ElementsAttr attr; + if (auto cst = dyn_cast(inst)) { + // ConstantOp have ElementAttr at this point due to validation of the TFLite + // module. + attr = cst.getValue().cast(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else { + return empty_buffer_; + } + + tensorflow::Tensor tensor; + auto status = tensorflow::ConvertToTensor(attr, &tensor); + if (!status.ok()) { + inst->emitError( + Twine("failed to convert value attribute to tensor with error: " + + status.ToString())); + return llvm::None; + } + + // TensorFlow and TensorFlow Lite use different string encoding formats. + // Convert to TensorFlow Lite format is it's a constant string tensor. 
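For reference, the DT_STRING branch below relies on tflite::DynamicBuffer to repack strings into the TFLite string-tensor layout. A minimal standalone sketch of the same round trip; the helper name PackStrings and its input vector are assumed purely for illustration:

    #include <cstdlib>
    #include <string>
    #include <vector>
    #include "tensorflow/lite/string_util.h"

    // Packs `values` into the TFLite string-tensor encoding. WriteToBuffer
    // allocates the packed buffer, so it is copied out and freed here.
    std::vector<char> PackStrings(const std::vector<std::string>& values) {
      ::tflite::DynamicBuffer dynamic_buffer;
      for (const std::string& s : values) {
        dynamic_buffer.AddString(s.data(), s.size());
      }
      char* packed = nullptr;
      const int num_bytes = dynamic_buffer.WriteToBuffer(&packed);
      std::vector<char> bytes(packed, packed + num_bytes);
      free(packed);
      return bytes;
    }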
+ if (tensor.dtype() == tensorflow::DT_STRING) { + ::tflite::DynamicBuffer dynamic_buffer; + auto flat = tensor.flat<::tensorflow::tstring>(); + for (int i = 0; i < flat.size(); ++i) { + const auto& str = flat(i); + dynamic_buffer.AddString(str.c_str(), str.length()); + } + char* tensor_buffer; + int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer); + auto buffer_data = + builder_.CreateVector(reinterpret_cast(tensor_buffer), bytes); + free(tensor_buffer); + return tflite::CreateBuffer(builder_, buffer_data); + } + + absl::string_view tensor_data = tensor.tensor_data(); + auto buffer_data = builder_.CreateVector( + reinterpret_cast(tensor_data.data()), tensor_data.size()); + return tflite::CreateBuffer(builder_, buffer_data); +} + +Optional> Translator::BuildTensorFromType( + mlir::Type type, const std::string& name) { + auto tensor_type = type.cast(); + + if (!tensor_type.hasStaticShape()) { + return llvm::None; + } + llvm::ArrayRef shape_ref = tensor_type.getShape(); + std::vector shape(shape_ref.begin(), shape_ref.end()); + + auto element_type = tensor_type.getElementType(); + tflite::TensorType tflite_element_type = + GetTFLiteType(tensor_type.getElementType()).ValueOrDie(); + BufferOffset q_params; + auto qtype = element_type.dyn_cast(); + if (!qtype) { + return llvm::None; + } + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector({static_cast(qtype.getScale())}), + builder_.CreateVector({qtype.getZeroPoint()})); + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + /*buffer=*/0, builder_.CreateString(name), q_params, + /*is_variable=*/false); +} + +Optional> Translator::BuildTensor( + Value value, const std::string& name, unsigned buffer_idx) { + auto type = value.getType().cast(); + + // TFLite requires tensor shape only for the inputs and constants. + // However, we output all known shapes for better round-tripping + auto check_shape = + [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { + auto is_out_of_range = [](int64_t dim) { + return dim > std::numeric_limits::max(); + }; + + if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) + return mlir::emitError( + value.getLoc(), + "result shape dimensions out of 32 bit int type range"); + + return mlir::success(); + }; + + std::vector shape; + std::vector shape_signature; + if (type.hasStaticShape()) { + llvm::ArrayRef shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape = std::vector(shape_ref.begin(), shape_ref.end()); + } else if (auto* inst = value.getDefiningOp()) { + if (IsConst(inst)) { + // Const op can have a result of dynamic shaped type (e.g. due to constant + // folding), but we can still derive the shape of a constant tensor for + // its attribute type. + mlir::Attribute tensor_attr = inst->getAttr("value"); + llvm::ArrayRef shape_ref = + tensor_attr.getType().cast().getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape = std::vector(shape_ref.begin(), shape_ref.end()); + } + } else if (type.hasRank()) { + llvm::ArrayRef shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape.reserve(shape_ref.size()); + for (auto& dim : shape_ref) { + shape.push_back(dim == -1 ? 
1 : dim); + } + shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); + } + + if (auto* inst = value.getDefiningOp()) { + if (auto cst = dyn_cast(inst)) { + // CreateSparsityParameters(cst.s_param()); + } else if (auto cst = dyn_cast(inst)) { + // CreateSparsityParameters(cst.s_param()); + } + } + + Type element_type = type.getElementType(); + tflite::TensorType tflite_element_type = + GetTFLiteType(type.getElementType()).ValueOrDie(); + + BufferOffset q_params; + if (auto qtype = element_type.dyn_cast()) { + q_params = tflite::CreateQuantizationParameters( + // TODO(fengliuai): min and max values are not stored in the + // quantized type, so both are set to 0. The model couldn't be imported + // to TensorFlow because of this. + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector({static_cast(qtype.getScale())}), + builder_.CreateVector({qtype.getZeroPoint()})); + } else if (auto qtype = + element_type + .dyn_cast()) { + std::vector scales(qtype.getScales().begin(), + qtype.getScales().end()); + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), + builder_.CreateVector(qtype.getZeroPoints()), + tflite::QuantizationDetails_NONE, /*details=*/0, + qtype.getQuantizedDimension()); + } else { + q_params = tflite::CreateQuantizationParameters(builder_); + } + // Check if the value's uses includes an op and usage at an operand index + // marked as a stateful. If so, set the tensor's is_variable as true + // This is v1 ref variable semantics in the TFLite runtime. + bool is_variable = false; + for (auto& use : value.getUses()) { + is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); + if (is_variable) { + break; + } + } + + if (shape_signature.empty()) { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable); + } else { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 
0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable, /*sparsity=*/0, + /*shape_signature=*/builder_.CreateVector(shape_signature)); + } +} + +BufferOffset Translator::BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); + int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); + int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); + auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, + else_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_IfOptions, + builtin_options); +} + +BufferOffset Translator::BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); + int body_subgraph_index = subgraph_index_map_.at(op.body().str()); + auto builtin_options = tflite::CreateWhileOptions( + builder_, cond_subgraph_index, body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); +} + +Optional> Translator::BuildWhileOperator( + mlir::TFL::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + auto get_call_index = [&](mlir::Block& b) -> Optional { + if (b.getOperations().size() != 2) return llvm::None; + if (auto call_op = dyn_cast(b.front())) + return subgraph_index_map_.at(call_op.callee().str()); + return llvm::None; + }; + auto body_subgraph_index = get_call_index(op.body().front()); + auto cond_subgraph_index = get_call_index(op.cond().front()); + if (!body_subgraph_index || !cond_subgraph_index) + return op.emitOpError("only single call cond/body while export supported"), + llvm::None; + auto builtin_options = + tflite::CreateWhileOptions(builder_, *cond_subgraph_index, + *body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); +} + +template +BufferOffset Translator::BuildCustomOperator( + const CustomOptionType& custom_option, const std::string& opcode_name, + TFLOp op, const std::vector& operands, + const std::vector& results) { + std::vector custom_option_vector(sizeof(CustomOptionType)); + memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType)); + auto opcode_index = + GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM); + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + builder_.CreateVector(custom_option_vector), + tflite::CustomOptionsFormat_FLEXBUFFERS); +} + +BufferOffset Translator::BuildNumericVerifyOperator( + mlir::TFL::NumericVerifyOp op, const std::vector& operands, + const std::vector& results) { + float tolerance = op.tolerance().convertToFloat(); 
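The return that follows hands `tolerance` to the BuildCustomOperator template defined above, which byte-copies the plain-old-data option into the operator's custom_options. A minimal sketch of that packing step; the helper name PackCustomOption is assumed for illustration:

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Copies a POD option value into the byte vector that becomes the
    // operator's custom_options payload, mirroring the memcpy above.
    template <typename OptionT>
    std::vector<uint8_t> PackCustomOption(const OptionT& option) {
      std::vector<uint8_t> bytes(sizeof(OptionT));
      std::memcpy(bytes.data(), &option, sizeof(OptionT));
      return bytes;
    }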
+ return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); +} + +Optional> +Translator::BuildConvolution2DTransposeBiasOperator( + Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, + const std::vector& operands, const std::vector& results) { + TfLiteTransposeConvParams conv_params; + conv_params.stride_height = op.stride_h().getSExtValue(); + conv_params.stride_width = op.stride_w().getSExtValue(); + const auto padding = GetTflitePadding(inst, op.padding()); + if (padding) { + conv_params.padding = *padding; + return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, + operands, results); + } + + return llvm::None; +} + +Optional> +Translator::BuildMaxPoolingWithArgMax2DOperator( + Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, + const std::vector& operands, const std::vector& results) { + const auto pool_params = GetTflitePoolParams(inst, op); + if (pool_params) { + return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, + operands, results); + } + + return llvm::None; +} + +Optional> +Translator::BuildMaxUnpooling2DOperator(Operation* inst, + mlir::TFL::MaxUnpooling2DOp op, + const std::vector& operands, + const std::vector& results) { + const auto pool_params = GetTflitePoolParams(inst, op); + if (pool_params) { + return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, + results); + } + + return llvm::None; +} + +Optional Translator::CreateFlexOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + std::string node_def_str; + if (!node_def.SerializeToString(&node_def_str)) { + return emitError(loc, "failed to serialize tensorflow node_def"), + llvm::None; + } + + auto flex_builder = absl::make_unique(); + flex_builder->Vector([&]() { + flex_builder->String(node_def.op()); + flex_builder->String(node_def_str); + }); + flex_builder->Finish(); + return builder_.CreateVector(flex_builder->GetBuffer()); +} + +Optional Translator::CreateCustomOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + std::string node_def_str; + if (!node_def.SerializeToString(&node_def_str)) { + return emitError(loc, "failed to serialize tensorflow node_def"), + llvm::None; + } + auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); + return builder_.CreateVector(flex_builder->GetBuffer()); +} + +std::unique_ptr +Translator::CreateFlexBuilderWithNodeAttrs( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + auto flex_builder = absl::make_unique(); + size_t map_start = flex_builder->StartMap(); + for (const auto& pair : node_def.attr()) { + const char* key = pair.first.c_str(); + const auto& attr = pair.second; + switch (attr.value_case()) { + case ::tensorflow::AttrValue::kS: + flex_builder->String(key, attr.s()); + break; + case ::tensorflow::AttrValue::kType: { + auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); + if (status_or_tfl_type.ok()) { + flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); + } else { + emitWarning(loc, "ignoring unsupported tensorflow type: ") + << std::to_string(attr.type()); + } + break; + } + case ::tensorflow::AttrValue::kI: + flex_builder->Int(key, attr.i()); + break; + case ::tensorflow::AttrValue::kF: + flex_builder->Float(key, attr.f()); + break; + case ::tensorflow::AttrValue::kB: + flex_builder->Bool(key, attr.b()); + break; + case tensorflow::AttrValue::kList: + if (attr.list().s_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const std::string& v : 
attr.list().s()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else if (attr.list().i_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const int64_t v : attr.list().i()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else if (attr.list().f_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const float v : attr.list().f()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else { + emitWarning(loc, + "ignoring unsupported type in list attribute with key: ") + << key; + } + break; + default: + emitWarning(loc, "ignoring unsupported attribute type with key: ") + << key; + break; + } + } + flex_builder->EndMap(map_start); + flex_builder->Finish(); + return flex_builder; +} + +uint32_t Translator::GetOpcodeIndex(const std::string& op_name, + tflite::BuiltinOperator builtin) { + auto it = opcode_index_map_.insert({op_name, 0}); + + // If the insert succeeded, the opcode has not been created already. Create a + // new operator code and update its index value in the map. + if (it.second) { + it.first->second = opcodes_.size(); + auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM + ? builder_.CreateString(op_name) + : BufferOffset(); + // Use version 0 for builtin op. This is a way to serialize version field to + // flatbuffer (since 0 is non default) and it will be corrected later. + int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; + opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, + custom_code, op_version)); + } + return it.first->second; +} + +Optional> Translator::BuildOperator( + Operation* inst, const std::vector& operands, + const std::vector& results, + const std::vector& intermediates) { + const auto* dialect = inst->getDialect(); + if (!dialect) { + inst->emitOpError("dialect is not registered"); + return llvm::None; + } + + // If TFLite built in op, create operator as a builtin op. + if (dialect == tfl_dialect_) { + // Only if built-in TFLite op emission is enabled, would legalization have + // converted any TF->TFL. 
+ if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { + return inst->emitOpError( + "is a TFLite builtin op but builtin emission is not enabled"), + llvm::None; + } + + auto builtin_code = GetBuiltinOpCode(inst); + if (!builtin_code) { + if (auto verify_op = dyn_cast(inst)) { + return BuildNumericVerifyOperator(verify_op, operands, results); + } + if (auto conv_transpose_bias_op = + dyn_cast(inst)) { + return BuildConvolution2DTransposeBiasOperator( + inst, conv_transpose_bias_op, operands, results); + } + if (auto max_pooling_with_arg_max_op = + dyn_cast(inst)) { + return BuildMaxPoolingWithArgMax2DOperator( + inst, max_pooling_with_arg_max_op, operands, results); + } + if (auto max_unpooling_op = dyn_cast(inst)) { + return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, + results); + } + if (auto whileOp = dyn_cast(inst)) { + if (inst->getNumOperands() != inst->getNumResults()) { + inst->emitOpError( + "number of operands and results don't match, only canonical " + "TFL While supported"); + return llvm::None; + } + return BuildWhileOperator(whileOp, operands, results); + } + + inst->emitOpError("is not a supported TFLite op"); + return llvm::None; + } + + std::string op_name = inst->getName().getStringRef().str(); + uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); + auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, + results, intermediates, &builder_); + if (!offset) { + inst->emitOpError("is not a supported TFLite op"); + } + return offset; + } + + if (dialect == tf_dialect_) { + std::string op_name; + if (auto ifOp = dyn_cast(inst)) { + return BuildIfOperator(ifOp, operands, results); + } else if (auto whileOp = dyn_cast(inst)) { + return BuildWhileOperator(whileOp, operands, results); + } + + CustomOptionsOffset custom_options; + + // Ops in TF dialect can either be custom ops or flex ops. + // The reason we go directly from TensorFlow dialect MLIR to tensorflow + // node instead of going to TF table gen'd ops via generated code is that + // we do not want to restrict custom and flex op conversion support to + // only those TF ops that are currently registered in MLIR. The current + // model is of an open op system. + // + // The following algorithm is followed: + // if flex is enabled and the op is whitelisted as flex + // we emit op as flex. + // if custom is enabled + // we emit the op as custom. + auto node_def = GetTensorFlowNodeDef(inst); + if (!node_def) { + return llvm::None; + } + + // Flex op case + // Eventually, the whitelist will go away and we will rely on some TF op + // trait (e.g. No side effect) to determine if it is a supported "Flex" + // op or not. + if (enabled_op_types_.contains(OpType::kSelectTf) && + IsWhitelistedFlexOp(node_def->op())) { + // Construct ops as flex op encoding TensorFlow node definition + // as custom options. + // Flex ops are named with the kFlexOpNamePrefix prefix to the actual + // TF op name. + op_name = std::string(kFlexOpNamePrefix) + node_def->op(); + if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { + return llvm::None; + } + } else if (enabled_op_types_.contains(OpType::kCustomOp)) { + // Generic case of custom ops - write using flex buffers since that + // is the only custom options supported by TFLite today. 
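This branch calls CreateCustomOpCustomOptions, defined above, which serializes the NodeDef attributes through a flexbuffers::Builder map before attaching them to the operator. A minimal sketch of that encoding; the function name and the two attributes are assumed for illustration:

    #include <cstdint>
    #include <vector>
    #include "flatbuffers/flexbuffers.h"  // TF:flatbuffers

    // Builds a small attribute map the same way the flex builder above does,
    // then finishes it into the raw custom-options bytes.
    std::vector<uint8_t> ExampleCustomOptions() {
      flexbuffers::Builder fbb;
      size_t map_start = fbb.StartMap();
      fbb.Int("num_bits", 8);           // hypothetical integer attribute
      fbb.Bool("narrow_range", false);  // hypothetical boolean attribute
      fbb.EndMap(map_start);
      fbb.Finish();
      return fbb.GetBuffer();
    }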
+ op_name = node_def->op(); + if (auto options = + CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { + return llvm::None; + } + } else { + // Create description of operation that could not be converted. + const int kLargeElementsAttr = 16; + std::string op_str; + llvm::raw_string_ostream os(op_str); + inst->getName().print(os); + // Print out attributes except for large elementsattributes (which should + // rarely be the cause why the legalization didn't happen). + if (!inst->getAttrList().getAttrs().empty()) { + os << " {"; + bool first = true; + for (auto& named_attr : inst->getAttrList().getDictionary()) { + os << (!first ? ", " : ""); + first = false; + named_attr.first.print(os); + os << " = "; + if (auto element_attr = named_attr.second.dyn_cast()) { + if (element_attr.getNumElements() <= kLargeElementsAttr) { + element_attr.print(os); + } else { + os << ""; + } + } else { + named_attr.second.print(os); + } + } + os << "}"; + } + + // Insert failed op to `flex_ops` or `custom_ops`. + if (IsWhitelistedFlexOp(node_def->op())) { + failed_flex_ops_.insert(os.str()); + } else { + failed_custom_ops_.insert(os.str()); + } + return inst->emitOpError("is neither a custom op nor a flex op"), + llvm::None; + } + + uint32_t opcode_index = + GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + /*custom_options=*/custom_options, + tflite::CustomOptionsFormat_FLEXBUFFERS, + /*mutating_variable_inputs=*/0); + } + + return inst->emitOpError( + "is not any of a builtin TFLite op, a flex TensorFlow op or a " + "custom TensorFlow op"), + llvm::None; +} + +void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { + auto dict_attr = fn.getAttrOfType("tf.entry_function"); + if (!dict_attr) return; + + llvm::SmallVector input_names; + llvm::SmallVector output_names; + if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { + str.getValue().split(input_names, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + if (input_names.size() != fn.getNumArguments()) { + fn.emitWarning() << "invalid entry function specification"; + return; + } + for (auto it : llvm::enumerate(fn.getArguments())) { + name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); + } + *has_input_attr = true; + } + + if (auto str = + dict_attr.get("outputs").dyn_cast_or_null()) { + str.getValue().split(output_names, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + auto term = fn.getBlocks().back().getTerminator(); + if (output_names.size() != term->getNumOperands()) { + fn.emitWarning() << "output names (" << output_names.size() + << ") != terminator operands (" << term->getNumOperands() + << ")"; + return; + } + for (const auto& it : llvm::enumerate(term->getOperands())) { + name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); + } + } +} + +bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { + std::vector operand_indices; + if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; + return absl::c_find(operand_indices, operand_index) != operand_indices.end(); +} + +Optional> Translator::BuildSubGraph( + const std::string& name, Region* region) { + bool has_input_attr = false; + if (auto fn = dyn_cast(region->getParentOp())) { + InitializeNamesFromAttribute(fn, &has_input_attr); + 
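InitializeNamesFromAttribute, defined above, splits the comma-separated `inputs` and `outputs` strings of the tf.entry_function attribute with llvm::StringRef::split. A minimal sketch of that split; the helper name and the sample string are assumed for illustration:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/ADT/StringRef.h"

    // Splits a comma-separated name list, dropping empty pieces to match the
    // KeepEmpty=false behaviour used above. The returned StringRefs alias the
    // input, so `csv` must outlive them.
    llvm::SmallVector<llvm::StringRef, 4> SplitEntryNames(llvm::StringRef csv) {
      llvm::SmallVector<llvm::StringRef, 4> names;
      csv.split(names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
      return names;  // SplitEntryNames("img,mask") yields {"img", "mask"}.
    }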
} + std::vector> tensors; + llvm::DenseMap tensor_index_map; + + // Builds tensor and buffer for argument or operation result. Returns false + // on failure. + auto build_tensor_and_buffer = [&](Value value, const std::string& name) { + // NoneType represents optional and may be skipped here. + if (value.getType().isa()) { + return true; + } + + tensor_index_map.insert({value, tensors.size()}); + auto tensor_or = BuildTensor(value, name, buffers_.size()); + if (!tensor_or) return false; + tensors.push_back(*tensor_or); + + // TODO(ashwinm): Check if for stateful tensors, if it is also needed to + // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. + // This does not seem to affect runtime behavior for RNN/LSTM, but would be + // good for reducing memory footprint. + if (auto* inst = value.getDefiningOp()) { + auto buffer_or = BuildBuffer(inst); + if (!buffer_or) return false; + buffers_.push_back(*buffer_or); + } else { + buffers_.push_back(empty_buffer_); + } + return true; + }; + + std::vector> operators; + auto& bb = region->front(); + + // Main function's arguments are first passed to `input` op so they don't + // have associated tensor and buffer. Build FlatBuffer tensor and buffer for + // other functions. + for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { + mlir::BlockArgument arg = bb.getArgument(i); + std::string name; + if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg)); + if (name.empty()) name = absl::StrCat("arg", i); + if (!build_tensor_and_buffer(arg, name)) return llvm::None; + } + + bool failed_once = false; + for (auto& inst : bb) { + if (inst.isKnownTerminator()) break; + std::vector intermediates; + // Build intermediate tensors for tfl.lstm and insert these tensors into + // flatbuffer. + if (llvm::isa(inst)) { + std::vector intermediate_names = { + "input_to_input_intermediate", "input_to_forget_intermediate", + "input_to_cell_intermediate", "input_to_output_intermediate", + "effective_hidden_scale_intermediate"}; + for (const std::string& intermediate : intermediate_names) { + auto intermediate_attr = inst.getAttr(intermediate); + if (auto attr = intermediate_attr.dyn_cast_or_null()) { + Type qtype = attr.getValue(); + auto tensor_or = BuildTensorFromType( + qtype, name_mapper_.GetUniqueName(intermediate).str()); + if (!tensor_or.hasValue()) { + continue; + } else { + intermediates.push_back(tensors.size()); + tensors.push_back(tensor_or.getValue()); + } + } + } + } + + for (auto val : inst.getResults()) { + std::string name = UniqueName(val); + if (!build_tensor_and_buffer(val, name)) return llvm::None; + } + + // Skip constant ops as they don't represent a TFLite operator. + if (IsConst(&inst)) continue; + + // Fetch operand and result tensor indices. + std::vector operands; + operands.reserve(inst.getNumOperands()); + for (auto operand : inst.getOperands()) { + if (operand.getType().isa()) + operands.push_back(kTfLiteOptionalTensor); + else + operands.push_back(tensor_index_map.lookup(operand)); + } + std::vector results; + results.reserve(inst.getNumOperands()); + for (auto result : inst.getResults()) { + results.push_back(tensor_index_map.lookup(result)); + } + + if (auto tfl_operator = + BuildOperator(&inst, operands, results, intermediates)) + operators.push_back(*tfl_operator); + else + failed_once = true; + } + + if (failed_once) return llvm::None; + + // Get input and output tensor indices for the subgraph. 
+ std::vector inputs, outputs; + for (auto arg : bb.getArguments()) { + inputs.push_back(tensor_index_map[arg]); + } + for (auto result : bb.getTerminator()->getOperands()) { + outputs.push_back(tensor_index_map[result]); + } + + return tflite::CreateSubGraph( + builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), + builder_.CreateVector(outputs), builder_.CreateVector(operators), + /*name=*/builder_.CreateString(name)); +} + +BufferOffset Translator::BuildMetadata(StringRef name, + StringRef content) { + auto buffer_index = buffers_.size(); + auto buffer_data = builder_.CreateVector( + reinterpret_cast(content.data()), content.size()); + buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); + return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); +} + +Optional>> +Translator::CreateMetadataVector() { + auto dict_attr = module_.getAttrOfType("tfl.metadata"); + std::vector> metadata; + if (dict_attr) { + for (const auto& named_attr : dict_attr) { + StringRef name = named_attr.first; + mlir::Attribute attr = named_attr.second; + if (auto content = attr.dyn_cast()) { + metadata.push_back(BuildMetadata(name, content.getValue())); + } else { + module_.emitError( + "all values in tfl.metadata's dictionary key-value pairs should be " + "string attributes"); + return llvm::None; + } + } + } + // Runtime version string is generated after we update the op + // versions. Here we put a 16-byte dummy string as a placeholder. We choose + // 16-byte because it's the alignment of buffers in flatbuffer, so it won't + // cause any waste of space if the actual string is shorter than 16 bytes. + metadata.push_back( + BuildMetadata("min_runtime_version", std::string(16, '\0'))); + return builder_.CreateVector(metadata); +} + +bool UpdateEntryFunction(ModuleOp module) { + if (module.lookupSymbol("main") != nullptr) { + // We already have an entry function. + return true; + } + + int entry_func_count = 0; + FuncOp entry_func = nullptr; + for (auto fn : module.getOps()) { + auto attrs = fn.getAttrOfType("tf.entry_function"); + if (attrs && !attrs.empty()) { + entry_func_count++; + entry_func = fn; + } + } + + // We should have one & only have one entry function. + if (entry_func_count != 1) return false; + + // Update the entry func to main. + entry_func.setName("main"); + return true; +} + +Optional Translator::Translate( + ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { + if (!UpdateEntryFunction(module)) return llvm::None; + if (!IsValidTFLiteMlirModule(module)) return llvm::None; + Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, + emit_custom_ops, op_or_arg_name_mapper); + return translator.TranslateInternal(); +} + +Optional Translator::TranslateInternal() { + // A list of named regions in the module with main function being the first in + // the list. The main function is required as the first subgraph in the model + // is entry point for the model. + std::vector> named_regions; + named_regions.reserve(std::distance(module_.begin(), module_.end())); + + int subgraph_idx = 0; + FuncOp main_fn = module_.lookupSymbol("main"); + subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; + named_regions.emplace_back("main", &main_fn.getBody()); + // Walk over the module collection ops with functions and while ops. 
+ module_.walk([&](FuncOp fn) { + if (fn != main_fn) { + subgraph_index_map_[fn.getName().str()] = subgraph_idx++; + named_regions.emplace_back(fn.getName().str(), &fn.getBody()); + } + }); + + // Build subgraph for each of the named regions. + std::vector> subgraphs; + subgraphs.reserve(named_regions.size()); + int first_failed_func = -1; + for (auto it : llvm::enumerate(named_regions)) { + auto subgraph_or = BuildSubGraph(it.value().first, it.value().second); + if (!subgraph_or) { + if (first_failed_func == -1) + // Record the index of the first region that cannot be converted. + // Keep looping through all subgraphs in the module to make sure that + // we collect the list of missing ops from the entire module. + first_failed_func = it.index(); + } else { + subgraphs.push_back(*subgraph_or); + } + } + + if (first_failed_func != -1) { + std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t"); + std::string failed_custom_ops_list = + absl::StrJoin(failed_custom_ops_, "\n\t"); + std::string err; + if (!failed_flex_ops_list.empty()) + err += + "Ops that can be supported by the flex runtime (enabled via setting " + "the -emit-select-tf-ops flag):\n\t" + + failed_flex_ops_list; + if (!failed_custom_ops_list.empty()) + err += + "Ops that need custom implementation (enabled via setting the " + "-emit-custom-ops flag):\n\t" + + failed_custom_ops_list; + + auto& failed_region = named_regions[first_failed_func]; + return failed_region.second->getParentOp()->emitError() + << "failed while converting: '" << failed_region.first + << "': " << err, + llvm::None; + } + + std::string model_description; + if (auto attr = module_.getAttrOfType("tfl.description")) { + model_description = attr.getValue().str(); + } else { + model_description = "MLIR Converted."; + } + + // Build the model and finish the model building process. + auto description = builder_.CreateString(model_description.data()); + VectorBufferOffset metadata_buffer = 0; // Deprecated + auto metadata = CreateMetadataVector(); + if (!metadata) return llvm::None; + + auto model = tflite::CreateModel( + builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), + builder_.CreateVector(subgraphs), description, + builder_.CreateVector(buffers_), metadata_buffer, *metadata); + tflite::FinishModelBuffer(builder_, model); + tflite::UpdateOpVersion(builder_.GetBufferPointer()); + tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); + + // Return serialized string for the built FlatBuffer. + return std::string(reinterpret_cast(builder_.GetBufferPointer()), + builder_.GetSize()); +} + +} // namespace + +// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer +// format. Returns false on success. 
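For orientation, a minimal caller sketch for the translate function declared in flatbuffer_export.h and defined just below; the wrapper name ExportModule, the ready-made `module`, and the choice of OpOrArgLocNameMapper are assumptions made only for this example:

    #include <string>
    #include "mlir/IR/Module.h"  // TF:llvm-project
    #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
    #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"

    // Serializes `module` (already legalized to the TFLite dialect) and
    // returns true on success; the underlying API returns false on success,
    // hence the negation.
    bool ExportModule(mlir::ModuleOp module, std::string* serialized) {
      tensorflow::OpOrArgLocNameMapper mapper;
      const bool failed = tflite::MlirToFlatBufferTranslateFunction(
          module, serialized, /*emit_builtin_tflite_ops=*/true,
          /*emit_select_tf_ops=*/false, /*emit_custom_ops=*/false, &mapper);
      return !failed;
    }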
+// +// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting +// the following: +// +// * Quantization +// * Ops with variable tensors +// +bool tflite::MlirToFlatBufferTranslateFunction( + ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + OpOrArgNameMapper* op_or_arg_name_mapper) { + auto maybe_translated = + Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, + emit_custom_ops, op_or_arg_name_mapper); + if (!maybe_translated) return true; + *serialized_flatbuffer = std::move(*maybe_translated); + return false; +} + +bool tflite::MlirToFlatBufferTranslateFunction( + ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops) { + OpOrArgLocNameMapper op_or_arg_name_mapper; + return MlirToFlatBufferTranslateFunction( + module, serialized_flatbuffer, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); +} diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.h b/tensorflow/compiler/mlir/lite/flatbuffer_export.h similarity index 90% rename from tensorflow/compiler/mlir/lite/flatbuffer_translate.h rename to tensorflow/compiler/mlir/lite/flatbuffer_export.h index 03f92ddbf03..f89893d5c87 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ #include @@ -40,4 +40,4 @@ bool MlirToFlatBufferTranslateFunction( tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); } // namespace tflite -#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h b/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h similarity index 84% rename from tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h rename to tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h index 6c8f80d4e05..4e891a5b266 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ #include @@ -28,4 +28,4 @@ extern bool lower_tensor_list_ops; // The flag to control whether debug info gets stripped on export. 
extern bool strip_debug_info; -#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 4b888764053..1eec402d35a 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -63,20 +63,16 @@ limitations under the License. #include "mlir/Support/LLVM.h" // TF:llvm-project #include "mlir/Translation.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -100,45 +96,6 @@ using xla::StatusOr; namespace errors = tensorflow::errors; namespace tfl = mlir::TFL; -using llvm::cl::opt; - -// Commandline flag to enable the control of flatbuffer import. -bool use_external_constant; - -// Commandline flag to enable graph pruning. -bool experimental_prune_unreachable_nodes_unconditionally; - -// NOLINTNEXTLINE -static opt use_external_constant_flag( - "use-external-constant", - llvm::cl::desc("Use external constant during flatbuffer import"), - llvm::cl::location(use_external_constant), llvm::cl::init(false)); - -// TODO(b/147111261): After the importer supports generic custom ops, we should -// change the flag to a more lightwise flag, e.g. -// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune -// the operations. 
-// NOLINTNEXTLINE -static opt experimental_prune_unreachable_nodes_unconditionally_flg( - "experimental-prune-unreachable-nodes-unconditionally", - llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."), - llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static opt input_arrays_flag( - "input-arrays", - llvm::cl::desc( - "List of input tensors, if different from the default inputs"), - llvm::cl::init("")); - -// NOLINTNEXTLINE -static opt output_arrays_flag( - "output-arrays", - llvm::cl::desc( - "List of output tensors, if different from the default outputs"), - llvm::cl::init("")); - namespace { bool IsScalar(const TensorT& tensor) { // TODO(b/138222071) We can't distinguish scalars and unranked tensors @@ -1063,42 +1020,3 @@ OwningModuleRef tflite::FlatBufferToMlir( return OwningModuleRef(module); } - -static OwningModuleRef FlatBufferFileToMlirTrans( - llvm::SourceMgr* source_mgr, MLIRContext* context, - bool use_external_constant, - bool experimental_prune_unreachable_nodes_unconditionally) { - const llvm::MemoryBuffer* input = - source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); - std::string error; - auto loc = - mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context); - - // Parses input/output names from command line options. - std::vector inputs; - std::vector outputs; - // Use output parser since we only have tensor names. - if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) { - return emitError(loc, "parsing input array info failed ") - << input_arrays_flag, - nullptr; - } - if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) { - return emitError(loc, "parsing output array info failed ") - << output_arrays_flag, - nullptr; - } - - return tflite::FlatBufferToMlir( - absl::string_view(input->getBufferStart(), input->getBufferSize()), - context, loc, use_external_constant, inputs, outputs, - experimental_prune_unreachable_nodes_unconditionally); -} - -static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( - "tflite-flatbuffer-to-mlir", - [](llvm::SourceMgr& source_mgr, MLIRContext* context) { - return FlatBufferFileToMlirTrans( - &source_mgr, context, use_external_constant, - experimental_prune_unreachable_nodes_unconditionally); - }); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index e8337d4a79f..ee7ac81dce9 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -13,31 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" @@ -56,67 +31,48 @@ limitations under the License. #include "mlir/IR/Value.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "mlir/Translation.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" -#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/string_util.h" -#include "tensorflow/lite/tools/versioning/op_version.h" -#include "tensorflow/lite/tools/versioning/runtime_version.h" -#include "tensorflow/lite/version.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -using llvm::dyn_cast; -using llvm::formatv; -using llvm::isa; -using llvm::Optional; -using llvm::StringRef; -using llvm::Twine; -using mlir::Dialect; -using mlir::ElementsAttr; -using mlir::FuncOp; -using mlir::MLIRContext; -using mlir::ModuleOp; -using mlir::NoneType; -using mlir::Operation; -using mlir::Region; -using mlir::StringAttr; -using mlir::TensorType; -using mlir::TranslateFromMLIRRegistration; -using mlir::Type; -using mlir::UnknownLoc; -using mlir::Value; -using tensorflow::OpOrArgLocNameMapper; -using tensorflow::OpOrArgNameMapper; -using tensorflow::Status; -using tflite::flex::IsWhitelistedFlexOp; -using xla::StatusOr; +using llvm::cl::opt; -template -using BufferOffset = flatbuffers::Offset; +// Commandline flag to enable the control of flatbuffer import. +bool use_external_constant; -template -using VectorBufferOffset = flatbuffers::Offset>; +// Commandline flag to enable graph pruning. 
+bool experimental_prune_unreachable_nodes_unconditionally; -using CustomOptionsOffset = VectorBufferOffset; +// NOLINTNEXTLINE +static opt use_external_constant_flag( + "use-external-constant", + llvm::cl::desc("Use external constant during flatbuffer import"), + llvm::cl::location(use_external_constant), llvm::cl::init(false)); -namespace error = tensorflow::error; -namespace tfl = mlir::TFL; +// TODO(b/147111261): After the importer supports generic custom ops, we should +// change the flag to a more lightwise flag, e.g. +// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune +// the operations. +// NOLINTNEXTLINE +static opt experimental_prune_unreachable_nodes_unconditionally_flg( + "experimental-prune-unreachable-nodes-unconditionally", + llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."), + llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally), + llvm::cl::init(false)); +// NOLINTNEXTLINE +static opt input_arrays_flag( + "input-arrays", + llvm::cl::desc( + "List of input tensors, if different from the default inputs"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +static opt output_arrays_flag( + "output-arrays", + llvm::cl::desc( + "List of output tensors, if different from the default outputs"), + llvm::cl::init("")); using llvm::cl::opt; // These command line flags enable control of the translation implementation. @@ -157,1353 +113,48 @@ static opt strip_debug_info_flag( "strip-debug-info", llvm::cl::desc("Strip debug info during export"), llvm::cl::location(strip_debug_info), llvm::cl::init(false)); -ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; - -// Use initial buffer size in flatbuffer builder to be same as the initial size -// used by the TOCO export. (It does not explain rationale for this choice.) -constexpr size_t kInitialBufferSize = 10240; - -// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. -// Since tflite doesn't support unsigned for other types, returns error if -// `isSigned` is set to false for other types. -static StatusOr GetTFLiteType(Type type, - bool is_signed = true) { - if (!is_signed && type.isSignlessInteger(8)) { - return tflite::TensorType_UINT8; - } - if (!is_signed) { - return Status(error::INVALID_ARGUMENT, - "'isSigned' can only be set for 8-bits integer type"); - } - switch (type.getKind()) { - case mlir::StandardTypes::F32: - return tflite::TensorType_FLOAT32; - case mlir::StandardTypes::F16: - return tflite::TensorType_FLOAT16; - case mlir::TF::TensorFlowTypes::STRING: - return tflite::TensorType_STRING; - case mlir::TF::TensorFlowTypes::QUINT8: - return tflite::TensorType_UINT8; - case mlir::StandardTypes::Complex: { - auto ftype = type.cast().getElementType(); - if (ftype && ftype.isF32()) { - return tflite::TensorType_COMPLEX64; - } - return Status(error::INVALID_ARGUMENT, "Unsupported type"); - } - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return tflite::TensorType_BOOL; - case 8: - return itype.isUnsigned() ? 
tflite::TensorType_UINT8 - : tflite::TensorType_INT8; - case 16: - return tflite::TensorType_INT16; - case 32: - return tflite::TensorType_INT32; - case 64: - return tflite::TensorType_INT64; - } - } - case mlir::quant::QuantizationTypes::UniformQuantized: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::TF::TensorFlowTypes::RESOURCE: { - // Treat tf.resource values as integer values in flatbuffer. - // TODO(b/146131919): Maybe need to have a detailed design for supporting - // other resource types beyonds hash table resources and resource - // variables. - return tflite::TensorType_INT32; - } - default: - // TFLite export fills FLOAT32 for unknown data types. Returning an error - // for now for safety and this could be revisited when required. - return Status(error::INVALID_ARGUMENT, "Unsupported type"); - } -} - -static bool IsConst(Operation* op) { - return isa(op) || isa(op) || - isa(op) || isa(op); -} - -template -static bool HasValidTFLiteType(Value value, T& error_handler) { - // None type is allowed to represent unspecified operands. - if (value.getType().isa()) return true; - - auto type = value.getType().dyn_cast(); - if (!type) { - if (auto op = value.getDefiningOp()) { - error_handler.emitError() - << '\'' << op << "' should produce value of tensor type instead of " - << value.getType(); - return false; - } - error_handler.emitError("expected tensor type, got ") << value.getType(); - return false; - } - - Type element_type = type.getElementType(); - auto status = GetTFLiteType(element_type); - if (!status.ok()) { - return error_handler.emitError( - formatv("Failed to convert element type '{0}': {1}", - element_type, status.status().error_message())), - false; - } - return true; -} - -// Returns true if the module holds all the invariants expected by the -// Translator class. -// TODO(hinsu): Now that translation is done by making a single pass over the -// MLIR module, consider inlining these validation checks at the place where -// these invariants are assumed instead of checking upfront. -static bool IsValidTFLiteMlirModule(ModuleOp module) { - MLIRContext* context = module.getContext(); - - // Verify that module has a function named main. - FuncOp main_fn = module.lookupSymbol("main"); - if (!main_fn) { - return emitError(UnknownLoc::get(context), - "should have a function named 'main'"), - false; - } - - for (auto fn : module.getOps()) { - if (fn.getBlocks().size() != 1) { - return fn.emitError("should have exactly one basic block"), false; - } - auto& bb = fn.getBlocks().front(); - - for (auto arg : bb.getArguments()) { - if (!HasValidTFLiteType(arg, fn)) - return fn.emitError("invalid TFLite type: ") << arg.getType(), false; - } - - // Verify that all operations except the terminator have exactly one - // result of type supported by TFLite. - for (auto& inst : bb) { - if (inst.isKnownTerminator()) break; - - for (auto result : inst.getResults()) { - if (!HasValidTFLiteType(result, inst)) - return fn.emitError("invalid TFLite type: ") << result.getType(), - false; - } - } - } - - return true; -} - -static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( - ::mlir::Operation* inst) { - // We pass empty string for the original node_def name since Flex runtime - // does not care about this being set correctly on node_def. 
There is no - // "easy" (see b/120948529) way yet to get this from MLIR inst. - auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( - inst, /*name=*/"", /*ignore_unregistered_attrs=*/true); - if (!status_or_node_def.ok()) { - inst->emitOpError( - Twine("failed to obtain TensorFlow nodedef with status: " + - status_or_node_def.status().ToString())); - return {}; - } - return std::move(status_or_node_def.ValueOrDie()); -} - -// Converts a mlir padding StringRef to TfLitePadding. -// Returns llvm::None if conversion fails. -static Optional GetTflitePadding(Operation* inst, - llvm::StringRef padding) { - const tflite::Padding padding_attr = - std::move(llvm::StringSwitch(padding) - .Case("SAME", tflite::Padding_SAME) - .Case("VALID", tflite::Padding_VALID)); - if (padding_attr == tflite::Padding_SAME) { - return kTfLitePaddingSame; - } - if (padding_attr == tflite::Padding_VALID) { - return kTfLitePaddingValid; - } - - return inst->emitOpError() << "Invalid padding attribute: " << padding, - llvm::None; -} - -// Extracts TfLitePoolParams from a TFL custom op. -// Template parameter, TFLOp, should be a TFL custom op containing attributes -// generated from TfLitePoolParams. -// Returns llvm::None if conversion fails. -template -static Optional GetTflitePoolParams(Operation* inst, - TFLOp op) { - TfLitePoolParams pool_params; - pool_params.stride_height = op.stride_h().getSExtValue(); - pool_params.stride_width = op.stride_w().getSExtValue(); - pool_params.filter_height = op.filter_h().getSExtValue(); - pool_params.filter_width = op.filter_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - pool_params.padding = *padding; - pool_params.activation = kTfLiteActNone; - pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; - return pool_params; - } - - return llvm::None; -} - +namespace mlir { namespace { +static OwningModuleRef FlatBufferFileToMlirTrans( + llvm::SourceMgr* source_mgr, MLIRContext* context, + bool use_external_constant, + bool experimental_prune_unreachable_nodes_unconditionally) { + const llvm::MemoryBuffer* input = + source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); + std::string error; + auto loc = + mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context); -// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. -class Translator { - public: - // Translates the given MLIR module into TFLite FlatBuffer format and returns - // the serialized output. Returns llvm::None on unsupported, invalid inputs or - // internal error. - static Optional Translate( - ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); - - private: - enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; - explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, - bool emit_select_tf_ops, bool emit_custom_ops, - OpOrArgNameMapper* op_or_arg_name_mapper) - : module_(module), - name_mapper_(*op_or_arg_name_mapper), - builder_(kInitialBufferSize) { - // The first buffer must be empty according to the schema definition. 
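(Editor's note, not part of the patch: the Translator constructor in the removed block above seeds the buffer table with an empty buffer before anything else, per the "first buffer must be empty" comment. A minimal sketch of that schema convention follows, using only the generated TFLite schema API; the helper name is hypothetical.)

#include <vector>

#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/schema/schema_generated.h"

// Sketch only: the TFLite schema expects buffers[0] to be empty, so every
// tensor that carries no constant data can simply reference buffer index 0.
flatbuffers::Offset<tflite::Buffer> AddSentinelBuffer(
    flatbuffers::FlatBufferBuilder& builder,
    std::vector<flatbuffers::Offset<tflite::Buffer>>& buffers) {
  auto empty = tflite::CreateBuffer(builder);  // No data vector: empty buffer.
  buffers.push_back(empty);                    // Must be the first entry.
  return empty;
}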
- empty_buffer_ = tflite::CreateBuffer(builder_); - buffers_.push_back(empty_buffer_); - if (emit_builtin_tflite_ops) { - enabled_op_types_.emplace(OpType::kTfliteBuiltin); - } - if (emit_select_tf_ops) { - enabled_op_types_.emplace(OpType::kSelectTf); - } - if (emit_custom_ops) { - enabled_op_types_.emplace(OpType::kCustomOp); - } - tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); - tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); + // Parses input/output names from command line options. + std::vector inputs; + std::vector outputs; + // Use output parser since we only have tensor names. + if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) { + return emitError(loc, "parsing input array info failed ") + << input_arrays_flag, + nullptr; } - - Optional TranslateInternal(); - - // Returns TFLite buffer populated with constant value if the operation is - // TFLite constant operation. Otherwise, returns an empty buffer. Emits error - // and returns llvm::None on failure. - Optional> BuildBuffer(Operation* inst); - - // Build TFLite tensor from the given type. This function is for tfl.lstm - // intermediates, which should have UniformQuantizedType. - Optional> BuildTensorFromType( - mlir::Type type, const std::string& name); - - // Builds TFLite tensor from the given value. `buffer_idx` is index of the - // corresponding buffer. Emits error and returns llvm::None on failure. - Optional> BuildTensor(Value value, - const std::string& name, - unsigned buffer_idx); - - // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove - // these 2 functions here. - BufferOffset BuildIfOperator( - mlir::TF::IfOp op, const std::vector& operands, - const std::vector& results); - BufferOffset BuildWhileOperator( - mlir::TF::WhileOp op, const std::vector& operands, - const std::vector& results); - - // Build while operator where cond & body are regions. - Optional> BuildWhileOperator( - mlir::TFL::WhileOp op, const std::vector& operands, - const std::vector& results); - - // Builds custom operators. - // Templated on a) data type of custom_option to be stored into flatbuffer, - // and b) TFL custom op type. - template - BufferOffset BuildCustomOperator( - const CustomOptionType& custom_option, const std::string& opcode_name, - TFLOp op, const std::vector& operands, - const std::vector& results); - - BufferOffset BuildNumericVerifyOperator( - mlir::TFL::NumericVerifyOp op, const std::vector& operands, - const std::vector& results); - Optional> - BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxUnpooling2DOperator( - Operation* inst, mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results); - - Optional CreateFlexOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - Optional CreateCustomOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - std::unique_ptr CreateFlexBuilderWithNodeAttrs( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - // Returns opcode index for op identified by the op_name, if already - // available. 
Otherwise, creates a new OperatorCode using the given `builtin` - // operator and associates it with `op_name`. - uint32_t GetOpcodeIndex(const std::string& op_name, - tflite::BuiltinOperator builtin); - - // Builds operator for the given operation with specified operand and result - // tensor indices. Emits an error and returns llvm::None on failure. - Optional> BuildOperator( - Operation* inst, const std::vector& operands, - const std::vector& results, - const std::vector& intermediates); - - // Build a subgraph with a given name out of the region either corresponding - // to a function's body or while op. - Optional> BuildSubGraph( - const std::string& name, Region* region); - - // Builds Metadata with the given `name` and buffer `content`. - BufferOffset BuildMetadata(StringRef name, - StringRef content); - - // Encodes the `tfl.metadata` dictionary attribute of the module to the - // metadata section in the final model. - Optional>> - CreateMetadataVector(); - - // Uses the tf.entry_function attribute (if set) to initialize the op to name - // mapping. - void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); - - // Determines if the specified operation op's operand at operand_index - // is marked as a stateful operand. - bool IsStatefulOperand(mlir::Operation* op, int operand_index); - - // Returns a unique name for `val`. - std::string UniqueName(mlir::Value val); - - ModuleOp module_; - - tensorflow::OpOrArgNameMapper& name_mapper_; - - flatbuffers::FlatBufferBuilder builder_; - BufferOffset empty_buffer_; - - std::vector> buffers_; - - // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. - absl::flat_hash_map opcode_index_map_; - std::vector> opcodes_; - - // Maps function name to index of the corresponding subgraph in the FlatBuffer - // model. - absl::flat_hash_map subgraph_index_map_; - absl::flat_hash_set enabled_op_types_; - - // Points to TensorFlow and TFLite dialects, respectively. nullptr if the - // dialect is not registered. - const Dialect* tf_dialect_; - const Dialect* tfl_dialect_; - - // The failed ops during legalization. - std::set failed_flex_ops_; - std::set failed_custom_ops_; -}; - -std::string Translator::UniqueName(mlir::Value val) { - return std::string(name_mapper_.GetUniqueName(val)); + if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) { + return emitError(loc, "parsing output array info failed ") + << output_arrays_flag, + nullptr; + } + return tflite::FlatBufferToMlir( + absl::string_view(input->getBufferStart(), input->getBufferSize()), + context, loc, use_external_constant, inputs, outputs, + experimental_prune_unreachable_nodes_unconditionally); } -Optional> Translator::BuildBuffer( - Operation* inst) { - ElementsAttr attr; - if (auto cst = dyn_cast(inst)) { - // ConstantOp have ElementAttr at this point due to validation of the TFLite - // module. 
- attr = cst.getValue().cast(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else { - return empty_buffer_; - } - - tensorflow::Tensor tensor; - auto status = tensorflow::ConvertToTensor(attr, &tensor); - if (!status.ok()) { - inst->emitError( - Twine("failed to convert value attribute to tensor with error: " + - status.ToString())); - return llvm::None; - } - - // TensorFlow and TensorFlow Lite use different string encoding formats. - // Convert to TensorFlow Lite format is it's a constant string tensor. - if (tensor.dtype() == tensorflow::DT_STRING) { - ::tflite::DynamicBuffer dynamic_buffer; - auto flat = tensor.flat<::tensorflow::tstring>(); - for (int i = 0; i < flat.size(); ++i) { - const auto& str = flat(i); - dynamic_buffer.AddString(str.c_str(), str.length()); - } - char* tensor_buffer; - int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer); - auto buffer_data = - builder_.CreateVector(reinterpret_cast(tensor_buffer), bytes); - free(tensor_buffer); - return tflite::CreateBuffer(builder_, buffer_data); - } - - absl::string_view tensor_data = tensor.tensor_data(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(tensor_data.data()), tensor_data.size()); - return tflite::CreateBuffer(builder_, buffer_data); -} - -Optional> Translator::BuildTensorFromType( - mlir::Type type, const std::string& name) { - auto tensor_type = type.cast(); - - if (!tensor_type.hasStaticShape()) { - return llvm::None; - } - llvm::ArrayRef shape_ref = tensor_type.getShape(); - std::vector shape(shape_ref.begin(), shape_ref.end()); - - auto element_type = tensor_type.getElementType(); - tflite::TensorType tflite_element_type = - GetTFLiteType(tensor_type.getElementType()).ValueOrDie(); - BufferOffset q_params; - auto qtype = element_type.dyn_cast(); - if (!qtype) { - return llvm::None; - } - q_params = tflite::CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, - builder_.CreateVector({static_cast(qtype.getScale())}), - builder_.CreateVector({qtype.getZeroPoint()})); - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - /*buffer=*/0, builder_.CreateString(name), q_params, - /*is_variable=*/false); -} - -Optional> Translator::BuildTensor( - Value value, const std::string& name, unsigned buffer_idx) { - auto type = value.getType().cast(); - - // TFLite requires tensor shape only for the inputs and constants. - // However, we output all known shapes for better round-tripping - auto check_shape = - [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { - auto is_out_of_range = [](int64_t dim) { - return dim > std::numeric_limits::max(); - }; - - if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) - return mlir::emitError( - value.getLoc(), - "result shape dimensions out of 32 bit int type range"); - - return mlir::success(); - }; - - std::vector shape; - std::vector shape_signature; - if (type.hasStaticShape()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (auto* inst = value.getDefiningOp()) { - if (IsConst(inst)) { - // Const op can have a result of dynamic shaped type (e.g. 
due to constant - // folding), but we can still derive the shape of a constant tensor for - // its attribute type. - mlir::Attribute tensor_attr = inst->getAttr("value"); - llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } - } else if (type.hasRank()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape.reserve(shape_ref.size()); - for (auto& dim : shape_ref) { - shape.push_back(dim == -1 ? 1 : dim); - } - shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); - } - - if (auto* inst = value.getDefiningOp()) { - if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); - } else if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); - } - } - - Type element_type = type.getElementType(); - tflite::TensorType tflite_element_type = - GetTFLiteType(type.getElementType()).ValueOrDie(); - - BufferOffset q_params; - if (auto qtype = element_type.dyn_cast()) { - q_params = tflite::CreateQuantizationParameters( - // TODO(fengliuai): min and max values are not stored in the - // quantized type, so both are set to 0. The model couldn't be imported - // to TensorFlow because of this. - builder_, /*min=*/0, /*max=*/0, - builder_.CreateVector({static_cast(qtype.getScale())}), - builder_.CreateVector({qtype.getZeroPoint()})); - } else if (auto qtype = - element_type - .dyn_cast()) { - std::vector scales(qtype.getScales().begin(), - qtype.getScales().end()); - q_params = tflite::CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), - builder_.CreateVector(qtype.getZeroPoints()), - tflite::QuantizationDetails_NONE, /*details=*/0, - qtype.getQuantizedDimension()); - } else { - q_params = tflite::CreateQuantizationParameters(builder_); - } - // Check if the value's uses includes an op and usage at an operand index - // marked as a stateful. If so, set the tensor's is_variable as true - // This is v1 ref variable semantics in the TFLite runtime. - bool is_variable = false; - for (auto& use : value.getUses()) { - is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); - if (is_variable) { - break; - } - } - - if (shape_signature.empty()) { - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable); - } else { - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 
0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable, /*sparsity=*/0, - /*shape_signature=*/builder_.CreateVector(shape_signature)); - } -} - -BufferOffset Translator::BuildIfOperator( - mlir::TF::IfOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); - int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); - int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); - auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, - else_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_IfOptions, - builtin_options); -} - -BufferOffset Translator::BuildWhileOperator( - mlir::TF::WhileOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); - int body_subgraph_index = subgraph_index_map_.at(op.body().str()); - auto builtin_options = tflite::CreateWhileOptions( - builder_, cond_subgraph_index, body_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_WhileOptions, - builtin_options); -} - -Optional> Translator::BuildWhileOperator( - mlir::TFL::WhileOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - auto get_call_index = [&](mlir::Block& b) -> Optional { - if (b.getOperations().size() != 2) return llvm::None; - if (auto call_op = dyn_cast(b.front())) - return subgraph_index_map_.at(call_op.callee().str()); - return llvm::None; - }; - auto body_subgraph_index = get_call_index(op.body().front()); - auto cond_subgraph_index = get_call_index(op.cond().front()); - if (!body_subgraph_index || !cond_subgraph_index) - return op.emitOpError("only single call cond/body while export supported"), - llvm::None; - auto builtin_options = - tflite::CreateWhileOptions(builder_, *cond_subgraph_index, - *body_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_WhileOptions, - builtin_options); -} - -template -BufferOffset Translator::BuildCustomOperator( - const CustomOptionType& custom_option, const std::string& opcode_name, - TFLOp op, const std::vector& operands, - const std::vector& results) { - std::vector custom_option_vector(sizeof(CustomOptionType)); - memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType)); - auto opcode_index = - GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM); - return tflite::CreateOperator( - builder_, opcode_index, builder_.CreateVector(operands), - builder_.CreateVector(results), tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - builder_.CreateVector(custom_option_vector), - tflite::CustomOptionsFormat_FLEXBUFFERS); -} - -BufferOffset Translator::BuildNumericVerifyOperator( - mlir::TFL::NumericVerifyOp op, const std::vector& operands, - const std::vector& results) { - float tolerance = op.tolerance().convertToFloat(); 
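(Editor's note, not part of the patch: the BuildCustomOperator template above byte-copies a plain-old-data options struct, such as the NumericVerify tolerance or a TfLitePoolParams, into the operator's custom_options vector. A self-contained sketch of that packing step; the helper name is hypothetical.)

#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

// Sketch only: serialize a trivially-copyable options struct into the raw
// byte vector that is stored as the flatbuffer operator's custom options.
template <typename OptionsT>
std::vector<uint8_t> PackCustomOptions(const OptionsT& options) {
  static_assert(std::is_trivially_copyable<OptionsT>::value,
                "custom options must be trivially copyable");
  std::vector<uint8_t> bytes(sizeof(OptionsT));
  std::memcpy(bytes.data(), &options, sizeof(OptionsT));
  return bytes;
}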
- return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); -} - -Optional> -Translator::BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, const std::vector& results) { - TfLiteTransposeConvParams conv_params; - conv_params.stride_height = op.stride_h().getSExtValue(); - conv_params.stride_width = op.stride_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - conv_params.padding = *padding; - return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxUnpooling2DOperator(Operation* inst, - mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, - results); - } - - return llvm::None; -} - -Optional Translator::CreateFlexOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } - - auto flex_builder = absl::make_unique(); - flex_builder->Vector([&]() { - flex_builder->String(node_def.op()); - flex_builder->String(node_def_str); - }); - flex_builder->Finish(); - return builder_.CreateVector(flex_builder->GetBuffer()); -} - -Optional Translator::CreateCustomOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } - auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); - return builder_.CreateVector(flex_builder->GetBuffer()); -} - -std::unique_ptr -Translator::CreateFlexBuilderWithNodeAttrs( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - auto flex_builder = absl::make_unique(); - size_t map_start = flex_builder->StartMap(); - for (const auto& pair : node_def.attr()) { - const char* key = pair.first.c_str(); - const auto& attr = pair.second; - switch (attr.value_case()) { - case ::tensorflow::AttrValue::kS: - flex_builder->String(key, attr.s()); - break; - case ::tensorflow::AttrValue::kType: { - auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); - if (status_or_tfl_type.ok()) { - flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); - } else { - emitWarning(loc, "ignoring unsupported tensorflow type: ") - << std::to_string(attr.type()); - } - break; - } - case ::tensorflow::AttrValue::kI: - flex_builder->Int(key, attr.i()); - break; - case ::tensorflow::AttrValue::kF: - flex_builder->Float(key, attr.f()); - break; - case ::tensorflow::AttrValue::kB: - flex_builder->Bool(key, attr.b()); - break; - case tensorflow::AttrValue::kList: - if (attr.list().s_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const std::string& v : 
attr.list().s()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else if (attr.list().i_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const int64_t v : attr.list().i()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else if (attr.list().f_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const float v : attr.list().f()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else { - emitWarning(loc, - "ignoring unsupported type in list attribute with key: ") - << key; - } - break; - default: - emitWarning(loc, "ignoring unsupported attribute type with key: ") - << key; - break; - } - } - flex_builder->EndMap(map_start); - flex_builder->Finish(); - return flex_builder; -} - -uint32_t Translator::GetOpcodeIndex(const std::string& op_name, - tflite::BuiltinOperator builtin) { - auto it = opcode_index_map_.insert({op_name, 0}); - - // If the insert succeeded, the opcode has not been created already. Create a - // new operator code and update its index value in the map. - if (it.second) { - it.first->second = opcodes_.size(); - auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM - ? builder_.CreateString(op_name) - : BufferOffset(); - // Use version 0 for builtin op. This is a way to serialize version field to - // flatbuffer (since 0 is non default) and it will be corrected later. - int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; - opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, - custom_code, op_version)); - } - return it.first->second; -} - -Optional> Translator::BuildOperator( - Operation* inst, const std::vector& operands, - const std::vector& results, - const std::vector& intermediates) { - const auto* dialect = inst->getDialect(); - if (!dialect) { - inst->emitOpError("dialect is not registered"); - return llvm::None; - } - - // If TFLite built in op, create operator as a builtin op. - if (dialect == tfl_dialect_) { - // Only if built-in TFLite op emission is enabled, would legalization have - // converted any TF->TFL. 
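(Editor's note, not part of the patch: GetOpcodeIndex above relies on the insert-or-lookup behaviour of absl::flat_hash_map to register each operator code exactly once and reuse its index afterwards. A small standalone sketch of that idiom; the class and member names here are hypothetical.)

#include <cstdint>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"

class OpcodeTable {
 public:
  uint32_t GetIndex(const std::string& op_name) {
    // insert() returns {iterator, inserted}; the index is assigned only on
    // the first insertion, later calls return the cached value.
    auto it = index_map_.insert({op_name, 0});
    if (it.second) {
      it.first->second = names_.size();
      names_.push_back(op_name);
    }
    return it.first->second;
  }

 private:
  absl::flat_hash_map<std::string, uint32_t> index_map_;
  std::vector<std::string> names_;
};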
- if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { - return inst->emitOpError( - "is a TFLite builtin op but builtin emission is not enabled"), - llvm::None; - } - - auto builtin_code = GetBuiltinOpCode(inst); - if (!builtin_code) { - if (auto verify_op = dyn_cast(inst)) { - return BuildNumericVerifyOperator(verify_op, operands, results); - } - if (auto conv_transpose_bias_op = - dyn_cast(inst)) { - return BuildConvolution2DTransposeBiasOperator( - inst, conv_transpose_bias_op, operands, results); - } - if (auto max_pooling_with_arg_max_op = - dyn_cast(inst)) { - return BuildMaxPoolingWithArgMax2DOperator( - inst, max_pooling_with_arg_max_op, operands, results); - } - if (auto max_unpooling_op = dyn_cast(inst)) { - return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, - results); - } - if (auto whileOp = dyn_cast(inst)) { - if (inst->getNumOperands() != inst->getNumResults()) { - inst->emitOpError( - "number of operands and results don't match, only canonical " - "TFL While supported"); - return llvm::None; - } - return BuildWhileOperator(whileOp, operands, results); - } - - inst->emitOpError("is not a supported TFLite op"); - return llvm::None; - } - - std::string op_name = inst->getName().getStringRef().str(); - uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); - auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, - results, intermediates, &builder_); - if (!offset) { - inst->emitOpError("is not a supported TFLite op"); - } - return offset; - } - - if (dialect == tf_dialect_) { - std::string op_name; - if (auto ifOp = dyn_cast(inst)) { - return BuildIfOperator(ifOp, operands, results); - } else if (auto whileOp = dyn_cast(inst)) { - return BuildWhileOperator(whileOp, operands, results); - } - - CustomOptionsOffset custom_options; - - // Ops in TF dialect can either be custom ops or flex ops. - // The reason we go directly from TensorFlow dialect MLIR to tensorflow - // node instead of going to TF table gen'd ops via generated code is that - // we do not want to restrict custom and flex op conversion support to - // only those TF ops that are currently registered in MLIR. The current - // model is of an open op system. - // - // The following algorithm is followed: - // if flex is enabled and the op is whitelisted as flex - // we emit op as flex. - // if custom is enabled - // we emit the op as custom. - auto node_def = GetTensorFlowNodeDef(inst); - if (!node_def) { - return llvm::None; - } - - // Flex op case - // Eventually, the whitelist will go away and we will rely on some TF op - // trait (e.g. No side effect) to determine if it is a supported "Flex" - // op or not. - if (enabled_op_types_.contains(OpType::kSelectTf) && - IsWhitelistedFlexOp(node_def->op())) { - // Construct ops as flex op encoding TensorFlow node definition - // as custom options. - // Flex ops are named with the kFlexOpNamePrefix prefix to the actual - // TF op name. - op_name = std::string(kFlexOpNamePrefix) + node_def->op(); - if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else if (enabled_op_types_.contains(OpType::kCustomOp)) { - // Generic case of custom ops - write using flex buffers since that - // is the only custom options supported by TFLite today. 
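(Editor's note, not part of the patch: for Flex ops, CreateFlexOpCustomOptions above wraps the TensorFlow op name and the serialized NodeDef in a FlexBuffer vector and stores it as custom options. A minimal sketch of that encoding, using only the flexbuffers API the file already includes; the helper name is hypothetical.)

#include <cstdint>
#include <string>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers

// Sketch only: element 0 holds the TensorFlow op name, element 1 the
// serialized NodeDef proto, mirroring what CreateFlexOpCustomOptions emits.
std::vector<uint8_t> EncodeFlexCustomOptions(const std::string& op_name,
                                             const std::string& node_def_str) {
  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String(op_name);
    fbb.String(node_def_str);
  });
  fbb.Finish();
  return fbb.GetBuffer();
}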
- op_name = node_def->op(); - if (auto options = - CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else { - // Create description of operation that could not be converted. - const int kLargeElementsAttr = 16; - std::string op_str; - llvm::raw_string_ostream os(op_str); - inst->getName().print(os); - // Print out attributes except for large elementsattributes (which should - // rarely be the cause why the legalization didn't happen). - if (!inst->getAttrList().getAttrs().empty()) { - os << " {"; - bool first = true; - for (auto& named_attr : inst->getAttrList().getDictionary()) { - os << (!first ? ", " : ""); - first = false; - named_attr.first.print(os); - os << " = "; - if (auto element_attr = named_attr.second.dyn_cast()) { - if (element_attr.getNumElements() <= kLargeElementsAttr) { - element_attr.print(os); - } else { - os << ""; - } - } else { - named_attr.second.print(os); - } - } - os << "}"; - } - - // Insert failed op to `flex_ops` or `custom_ops`. - if (IsWhitelistedFlexOp(node_def->op())) { - failed_flex_ops_.insert(os.str()); - } else { - failed_custom_ops_.insert(os.str()); - } - return inst->emitOpError("is neither a custom op nor a flex op"), - llvm::None; - } - - uint32_t opcode_index = - GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - /*custom_options=*/custom_options, - tflite::CustomOptionsFormat_FLEXBUFFERS, - /*mutating_variable_inputs=*/0); - } - - return inst->emitOpError( - "is not any of a builtin TFLite op, a flex TensorFlow op or a " - "custom TensorFlow op"), - llvm::None; -} - -void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { - auto dict_attr = fn.getAttrOfType("tf.entry_function"); - if (!dict_attr) return; - - llvm::SmallVector input_names; - llvm::SmallVector output_names; - if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { - str.getValue().split(input_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - if (input_names.size() != fn.getNumArguments()) { - fn.emitWarning() << "invalid entry function specification"; - return; - } - for (auto it : llvm::enumerate(fn.getArguments())) { - name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); - } - *has_input_attr = true; - } - - if (auto str = - dict_attr.get("outputs").dyn_cast_or_null()) { - str.getValue().split(output_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - auto term = fn.getBlocks().back().getTerminator(); - if (output_names.size() != term->getNumOperands()) { - fn.emitWarning() << "output names (" << output_names.size() - << ") != terminator operands (" << term->getNumOperands() - << ")"; - return; - } - for (const auto& it : llvm::enumerate(term->getOperands())) { - name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); - } - } -} - -bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { - std::vector operand_indices; - if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; - return absl::c_find(operand_indices, operand_index) != operand_indices.end(); -} - -Optional> Translator::BuildSubGraph( - const std::string& name, Region* region) { - bool has_input_attr = false; - if (auto fn = dyn_cast(region->getParentOp())) { - InitializeNamesFromAttribute(fn, &has_input_attr); - 
} - std::vector> tensors; - llvm::DenseMap tensor_index_map; - - // Builds tensor and buffer for argument or operation result. Returns false - // on failure. - auto build_tensor_and_buffer = [&](Value value, const std::string& name) { - // NoneType represents optional and may be skipped here. - if (value.getType().isa()) { - return true; - } - - tensor_index_map.insert({value, tensors.size()}); - auto tensor_or = BuildTensor(value, name, buffers_.size()); - if (!tensor_or) return false; - tensors.push_back(*tensor_or); - - // TODO(ashwinm): Check if for stateful tensors, if it is also needed to - // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. - // This does not seem to affect runtime behavior for RNN/LSTM, but would be - // good for reducing memory footprint. - if (auto* inst = value.getDefiningOp()) { - auto buffer_or = BuildBuffer(inst); - if (!buffer_or) return false; - buffers_.push_back(*buffer_or); - } else { - buffers_.push_back(empty_buffer_); - } - return true; - }; - - std::vector> operators; - auto& bb = region->front(); - - // Main function's arguments are first passed to `input` op so they don't - // have associated tensor and buffer. Build FlatBuffer tensor and buffer for - // other functions. - for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { - mlir::BlockArgument arg = bb.getArgument(i); - std::string name; - if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg)); - if (name.empty()) name = absl::StrCat("arg", i); - if (!build_tensor_and_buffer(arg, name)) return llvm::None; - } - - bool failed_once = false; - for (auto& inst : bb) { - if (inst.isKnownTerminator()) break; - std::vector intermediates; - // Build intermediate tensors for tfl.lstm and insert these tensors into - // flatbuffer. - if (llvm::isa(inst)) { - std::vector intermediate_names = { - "input_to_input_intermediate", "input_to_forget_intermediate", - "input_to_cell_intermediate", "input_to_output_intermediate", - "effective_hidden_scale_intermediate"}; - for (const std::string& intermediate : intermediate_names) { - auto intermediate_attr = inst.getAttr(intermediate); - if (auto attr = intermediate_attr.dyn_cast_or_null()) { - Type qtype = attr.getValue(); - auto tensor_or = BuildTensorFromType( - qtype, name_mapper_.GetUniqueName(intermediate).str()); - if (!tensor_or.hasValue()) { - continue; - } else { - intermediates.push_back(tensors.size()); - tensors.push_back(tensor_or.getValue()); - } - } - } - } - - for (auto val : inst.getResults()) { - std::string name = UniqueName(val); - if (!build_tensor_and_buffer(val, name)) return llvm::None; - } - - // Skip constant ops as they don't represent a TFLite operator. - if (IsConst(&inst)) continue; - - // Fetch operand and result tensor indices. - std::vector operands; - operands.reserve(inst.getNumOperands()); - for (auto operand : inst.getOperands()) { - if (operand.getType().isa()) - operands.push_back(kTfLiteOptionalTensor); - else - operands.push_back(tensor_index_map.lookup(operand)); - } - std::vector results; - results.reserve(inst.getNumOperands()); - for (auto result : inst.getResults()) { - results.push_back(tensor_index_map.lookup(result)); - } - - if (auto tfl_operator = - BuildOperator(&inst, operands, results, intermediates)) - operators.push_back(*tfl_operator); - else - failed_once = true; - } - - if (failed_once) return llvm::None; - - // Get input and output tensor indices for the subgraph. 
- std::vector inputs, outputs; - for (auto arg : bb.getArguments()) { - inputs.push_back(tensor_index_map[arg]); - } - for (auto result : bb.getTerminator()->getOperands()) { - outputs.push_back(tensor_index_map[result]); - } - - return tflite::CreateSubGraph( - builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), - builder_.CreateVector(outputs), builder_.CreateVector(operators), - /*name=*/builder_.CreateString(name)); -} - -BufferOffset Translator::BuildMetadata(StringRef name, - StringRef content) { - auto buffer_index = buffers_.size(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(content.data()), content.size()); - buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); - return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); -} - -Optional>> -Translator::CreateMetadataVector() { - auto dict_attr = module_.getAttrOfType("tfl.metadata"); - std::vector> metadata; - if (dict_attr) { - for (const auto& named_attr : dict_attr) { - StringRef name = named_attr.first; - mlir::Attribute attr = named_attr.second; - if (auto content = attr.dyn_cast()) { - metadata.push_back(BuildMetadata(name, content.getValue())); - } else { - module_.emitError( - "all values in tfl.metadata's dictionary key-value pairs should be " - "string attributes"); - return llvm::None; - } - } - } - // Runtime version string is generated after we update the op - // versions. Here we put a 16-byte dummy string as a placeholder. We choose - // 16-byte because it's the alignment of buffers in flatbuffer, so it won't - // cause any waste of space if the actual string is shorter than 16 bytes. - metadata.push_back( - BuildMetadata("min_runtime_version", std::string(16, '\0'))); - return builder_.CreateVector(metadata); -} - -bool UpdateEntryFunction(ModuleOp module) { - if (module.lookupSymbol("main") != nullptr) { - // We already have an entry function. - return true; - } - - int entry_func_count = 0; - FuncOp entry_func = nullptr; - for (auto fn : module.getOps()) { - auto attrs = fn.getAttrOfType("tf.entry_function"); - if (attrs && !attrs.empty()) { - entry_func_count++; - entry_func = fn; - } - } - - // We should have one & only have one entry function. - if (entry_func_count != 1) return false; - - // Update the entry func to main. - entry_func.setName("main"); - return true; -} - -Optional Translator::Translate( - ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { - if (!UpdateEntryFunction(module)) return llvm::None; - if (!IsValidTFLiteMlirModule(module)) return llvm::None; - Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - return translator.TranslateInternal(); -} - -Optional Translator::TranslateInternal() { - // A list of named regions in the module with main function being the first in - // the list. The main function is required as the first subgraph in the model - // is entry point for the model. - std::vector> named_regions; - named_regions.reserve(std::distance(module_.begin(), module_.end())); - - int subgraph_idx = 0; - FuncOp main_fn = module_.lookupSymbol("main"); - subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back("main", &main_fn.getBody()); - // Walk over the module collection ops with functions and while ops. 
- module_.walk([&](FuncOp fn) { - if (fn != main_fn) { - subgraph_index_map_[fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back(fn.getName().str(), &fn.getBody()); - } - }); - - // Build subgraph for each of the named regions. - std::vector> subgraphs; - subgraphs.reserve(named_regions.size()); - int first_failed_func = -1; - for (auto it : llvm::enumerate(named_regions)) { - auto subgraph_or = BuildSubGraph(it.value().first, it.value().second); - if (!subgraph_or) { - if (first_failed_func == -1) - // Record the index of the first region that cannot be converted. - // Keep looping through all subgraphs in the module to make sure that - // we collect the list of missing ops from the entire module. - first_failed_func = it.index(); - } else { - subgraphs.push_back(*subgraph_or); - } - } - - if (first_failed_func != -1) { - std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t"); - std::string failed_custom_ops_list = - absl::StrJoin(failed_custom_ops_, "\n\t"); - std::string err; - if (!failed_flex_ops_list.empty()) - err += - "Ops that can be supported by the flex runtime (enabled via setting " - "the -emit-select-tf-ops flag):\n\t" + - failed_flex_ops_list; - if (!failed_custom_ops_list.empty()) - err += - "Ops that need custom implementation (enabled via setting the " - "-emit-custom-ops flag):\n\t" + - failed_custom_ops_list; - - auto& failed_region = named_regions[first_failed_func]; - return failed_region.second->getParentOp()->emitError() - << "failed while converting: '" << failed_region.first - << "': " << err, - llvm::None; - } - - std::string model_description; - if (auto attr = module_.getAttrOfType("tfl.description")) { - model_description = attr.getValue().str(); - } else { - model_description = "MLIR Converted."; - } - - // Build the model and finish the model building process. - auto description = builder_.CreateString(model_description.data()); - VectorBufferOffset metadata_buffer = 0; // Deprecated - auto metadata = CreateMetadataVector(); - if (!metadata) return llvm::None; - - auto model = tflite::CreateModel( - builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), - builder_.CreateVector(subgraphs), description, - builder_.CreateVector(buffers_), metadata_buffer, *metadata); - tflite::FinishModelBuffer(builder_, model); - tflite::UpdateOpVersion(builder_.GetBufferPointer()); - tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); - - // Return serialized string for the built FlatBuffer. - return std::string(reinterpret_cast(builder_.GetBufferPointer()), - builder_.GetSize()); -} - -} // namespace - -// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer -// format. Returns false on success. 
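(Editor's note, not part of the patch: the exported entry point defined just below uses an inverted return convention, returning false on success. A hedged usage sketch, assuming the five-argument overload shown in this file keeps the same signature after the move to flatbuffer_export.h carried out by this patch series; the wrapper name is hypothetical.)

#include <string>

#include "mlir/IR/Module.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"

// Sketch only: drive MlirToFlatBufferTranslateFunction and convert its
// false-on-success convention into the usual true-on-success bool.
bool SerializeToTfliteFlatbuffer(mlir::ModuleOp module, std::string* out) {
  const bool failed = tflite::MlirToFlatBufferTranslateFunction(
      module, out, /*emit_builtin_tflite_ops=*/true,
      /*emit_select_tf_ops=*/false, /*emit_custom_ops=*/false);
  return !failed;
}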
-// -// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting -// the following: -// -// * Quantization -// * Ops with variable tensors -// -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - OpOrArgNameMapper* op_or_arg_name_mapper) { - auto maybe_translated = - Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - if (!maybe_translated) return true; - *serialized_flatbuffer = std::move(*maybe_translated); - return false; -} - -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops) { - OpOrArgLocNameMapper op_or_arg_name_mapper; - return MlirToFlatBufferTranslateFunction( - module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); -} - -static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction( +static LogicalResult MlirToFlatBufferFileTranslateFunction( ModuleOp module, llvm::raw_ostream& output) { std::string serialized_flatbuffer; - std::unique_ptr op_or_arg_name_mapper; + std::unique_ptr op_or_arg_name_mapper; if (strip_debug_info) { op_or_arg_name_mapper = std::make_unique(); } else { - op_or_arg_name_mapper = std::make_unique(); + op_or_arg_name_mapper = + std::make_unique(); } if (tflite::MlirToFlatBufferTranslateFunction( module, &serialized_flatbuffer, emit_builtin_tflite_ops, @@ -1511,8 +162,18 @@ static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction( return mlir::failure(); output << serialized_flatbuffer; - return mlir::success(); + return success(); } +} // namespace + +static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( + "tflite-flatbuffer-to-mlir", + [](llvm::SourceMgr& source_mgr, MLIRContext* context) { + return FlatBufferFileToMlirTrans( + &source_mgr, context, use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); + }); static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate( "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction); +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index 6f8292308a4..d17215566a1 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -34,8 +34,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project #include "mlir/Parser.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/delegates/flex/delegate.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 2f677397109..7557ff5223c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index c05337918f2..f04dc9c2961 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -23,8 +23,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 74e48cd6d91..bb82988def1 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -28,8 +28,8 @@ limitations under the License. #include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index b05dcaadab2..1ba0c025613 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "mlir/Transforms/Passes.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 1cc26c9bb4d..782102510fa 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -811,7 +811,8 @@ cc_library( srcs = ["utils/error_util.cc"], hdrs = ["utils/error_util.h"], deps = [ - "//tensorflow/core:lib", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", ], diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc index 60646ae764e..5514a788996 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index 7eb30ee2c46..1bc0a23e359 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -21,7 +21,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" // Error utilities for MLIR when interacting with code using Status returns. namespace mlir { From 1409428ce2afc796dbaae252cdb476d22e08bfe4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Mar 2020 21:45:47 -0700 Subject: [PATCH 383/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302255756 Change-Id: I761ed621bf3f6c3a356dd1c50bc49985b04b5eb5 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 75d86f71b78..68bb1dc49f5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 40675b3af0f18a9cc5175fcabb2ee129909a3b63 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Mar 2020 23:45:45 -0700 Subject: [PATCH 384/492] Go: Update generated wrapper functions for TensorFlow ops. 
PiperOrigin-RevId: 302263587 Change-Id: I5acdf713ec8a038331b1086a06ee1d92861a62b6 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 68bb1dc49f5..75d86f71b78 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From b65dae5e5fdff80eaa21384d56c47dea0f21a05b Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Sun, 22 Mar 2020 00:01:18 -0700 Subject: [PATCH 385/492] Fix nightly smoke test after CL/301880779 PiperOrigin-RevId: 302264423 Change-Id: If874506900e29cac7cbc457e7cc5e6a1cc0f5852 --- .../tools/ci_build/builds/nightly_release_smoke_test.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh b/tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh index d696e08d790..93a1888571e 100644 --- a/tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh +++ b/tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh @@ -54,13 +54,15 @@ function test_tf_imports() { # test for basic import and perform tf.add operation. RET_VAL=$(python -c "import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)") if ! [[ ${RET_VAL} == *'(4,)'* ]]; then + echo "Unexpected return value: ${RET_VALUE}" echo "PIP test on virtualenv FAILED, will not upload ${WHL_NAME} package." return 1 fi # test basic keras is available RET_VAL=$(python -c "import tensorflow as tf; print(tf.keras.__name__)") - if ! [[ ${RET_VAL} == *'tensorflow.python.keras.api._v2.keras'* ]]; then + if ! [[ ${RET_VAL} == *'tensorflow.keras'* ]]; then + echo "Unexpected return value: ${RET_VALUE}" echo "PIP test on virtualenv FAILED, will not upload ${WHL_NAME} package." return 1 fi @@ -68,6 +70,7 @@ function test_tf_imports() { # similar test for estimator RET_VAL=$(python -c "import tensorflow as tf; print(tf.estimator.__name__)") if ! [[ ${RET_VAL} == *'tensorflow_estimator.python.estimator.api._v2.estimator'* ]]; then + echo "Unexpected return value: ${RET_VALUE}" echo "PIP test on virtualenv FAILED, will not upload ${WHL_NAME} package." return 1 fi From 1d961d8b58853480a30302af36fe17ee0b2a2a04 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 22 Mar 2020 02:02:35 -0700 Subject: [PATCH 386/492] compat: Update forward compatibility horizon to 2020-03-22 PiperOrigin-RevId: 302273488 Change-Id: I12659716ef138464820fa96b168353909ae07567 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 288c0670968..81a7d03f110 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 21) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 22) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 313afe61fa9d74bb1ab723f06af3c521dbeb5bb7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Mar 2020 09:45:53 -0700 Subject: [PATCH 387/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302304898 Change-Id: I298e29b9ee87670aebe0ea714059f26c90657a7e --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 75d86f71b78..68bb1dc49f5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From f7b6793c6611210405d066dde84cc67adca4097c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Mar 2020 11:46:05 -0700 Subject: [PATCH 388/492] Go: Update generated wrapper functions for TensorFlow ops. 
PiperOrigin-RevId: 302312995 Change-Id: Id324592b0835289d492feabb80d3bc71d8488e56 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 68bb1dc49f5..75d86f71b78 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 8a26aad292b4ad0cc8468358d2aff9f4e488008b Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 22 Mar 2020 14:40:27 -0500 Subject: [PATCH 389/492] fix TFE_OpReset doesn't clear inputs. --- tensorflow/c/eager/c_api_experimental.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index afa36fe1210..9d491c72f38 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -31,6 +31,7 @@ using tensorflow::string; void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { if (op_to_reset) { + op_to_reset->operation->Clear(); status->status = op_to_reset->operation->Reset(op_or_function_name, raw_device_name); } else { From 3ebd0c683efa922c3c3988c51856cfbf5b37c8b7 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Sun, 22 Mar 2020 13:51:11 -0700 Subject: [PATCH 390/492] [Executor] Restructure `ActivateNodes()` to avoid unnecessary loads and branches. The basic idea of the optimization is to avoid reading each destination `NodeItem` when propagating values/control signals from the outputs of a node. To do that, we make the following changes: 1. Avoid building executor structures for the sink node, which is a no-op anyway. This avoids the need to compare the endpoint of every edge to the sink node, and removes the `NodeItem::is_sink` bit. 2. Optimize the common case when all consumers of a particular op are neither merge nor control trigger nodes. We record this fact in the source `NodeItem`, using the bit we saved in (1). 3. Repurpose the `EdgeInput::input_slot` field so that it is an offset directly into the vector of input tensors, rather than a relative offset from the destination `NodeItem::input_start`. 4. Move `NodeItem::pending_id` into a dense vector of `PendingCounts::Handle` values owned by the `ExecutorImpl`. This is a more compact and cache-friendly structure for `ActivateNodes()` to access in the common case. 5. Modify `PendingCounts::adjust_for_activations()` to return its value in a register, instead of using output parameters. 
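To make the intent of changes 2-4 concrete before the diff, here is a minimal C++ sketch of the fast-path propagation idea. It is not the executor code itself: every type and member name below (OutEdge, FakeNodeItem, FrameState, PropagateFastPath) is invented for illustration. The point it shows is that once the producing node records that none of its consumers is a merge or control-trigger node, output propagation can write each value directly into a flat, frame-wide input vector (the absolute input_slot of change 3) and decrement a dense per-node pending counter (the dense handle vector of change 4) without ever loading the consumer's NodeItem.

#include <iostream>
#include <vector>

// Hypothetical, simplified structures; for illustration only.
struct OutEdge {
  int dst_id;      // index into the dense pending-count vector
  int input_slot;  // absolute offset into the frame-wide input vector
};

struct FakeNodeItem {
  bool any_consumer_needs_slow_path = false;  // e.g. merge or control trigger
  std::vector<OutEdge> out_edges;
};

// Dense, cache-friendly per-iteration state (mirrors change 4).
struct FrameState {
  std::vector<float> input_tensors;  // flat input storage for all nodes
  std::vector<int> pending_counts;   // one counter per node id
};

// Fast path: the consumers' metadata is never touched.
void PropagateFastPath(const FakeNodeItem& src, float value, FrameState* frame,
                       std::vector<int>* ready) {
  for (const OutEdge& e : src.out_edges) {
    frame->input_tensors[e.input_slot] = value;    // change 3: absolute slot
    if (--frame->pending_counts[e.dst_id] == 0) {  // change 4: dense counters
      ready->push_back(e.dst_id);                  // consumer became runnable
    }
  }
}

int main() {
  FrameState frame;
  frame.input_tensors.assign(4, 0.0f);
  frame.pending_counts = {0, 1, 2};  // node 1 waits on 1 input, node 2 on 2

  FakeNodeItem producer;
  producer.out_edges = {{/*dst_id=*/1, /*input_slot=*/0},
                        {/*dst_id=*/2, /*input_slot=*/2}};

  std::vector<int> ready;
  PropagateFastPath(producer, 3.14f, &frame, &ready);
  for (int id : ready) std::cout << "ready: " << id << "\n";  // prints "ready: 1"
  return 0;
}

When any consumer does need special handling, the executor still has to inspect it per edge, which is why the diff below splits the work into ActivateNodesFastPath and ActivateNodesSlowPath and dispatches on the is_any_consumer_merge_or_control_trigger bit recorded on the producer.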
PiperOrigin-RevId: 302321563 Change-Id: If1b7feee54705c403fe27d48552545c00cd98601 --- tensorflow/core/common_runtime/executor.cc | 212 +++++++++++++----- .../core/common_runtime/executor_test.cc | 3 + .../core/common_runtime/pending_counts.h | 33 ++- .../common_runtime/pending_counts_test.cc | 13 +- 4 files changed, 187 insertions(+), 74 deletions(-) diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index f3cf11b274f..13c07dc0ec1 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -162,7 +162,6 @@ struct NodeItem { bool is_exit : 1; // True iff IsExit(node) bool is_control_trigger : 1; // True iff IsControlTrigger(node) bool is_source : 1; // True iff IsSource(node) - bool is_sink : 1; // True iff IsSink(node) // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) bool is_enter_exit_or_next_iter : 1; bool is_transfer_node : 1; // True iff IsTransferNode(node) @@ -170,6 +169,11 @@ struct NodeItem { bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) bool is_next_iteration : 1; // True iff IsNextIteration(node) bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp") + bool + is_any_consumer_merge_or_control_trigger : 1; // True iff the destination + // of any output edge is a + // merge or control trigger + // node. // The kernel for this node. OpKernel* kernel = nullptr; @@ -186,8 +190,6 @@ struct NodeItem { // for this node. int input_start = 0; - PendingCounts::Handle pending_id; - // Number of output edges. size_t num_output_edges; @@ -203,6 +205,11 @@ struct NodeItem { DCHECK_LT(i, num_output_edges); return output_edge_base()[i]; } + EdgeInfo& output_edge(int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, num_output_edges); + return output_edge_base()[i]; + } DataType input_type(int i) const { DCHECK_LT(i, num_inputs); @@ -227,8 +234,6 @@ struct NodeItem { string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id); if (is_source) { strings::StrAppend(&ret, " source}"); - } else if (is_sink) { - strings::StrAppend(&ret, " sink}"); } else { strings::StrAppend(&ret, " def:{", SummarizeNodeDef(kernel->def()), "}}"); } @@ -357,7 +362,8 @@ class ExecutorImpl : public Executor { absl::make_unique(gview.num_nodes()); for (int32 i = 0; i < gview.num_nodes(); ++i) { if (gview.node(i)) { - is_expensive_[i] = gview.node(i)->kernel->IsExpensive(); + is_expensive_[i] = + gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive(); cost_estimates_[i] = kInitialCostEstimateCycles; } } @@ -458,6 +464,7 @@ class ExecutorImpl : public Executor { // Owned. 
LocalExecutorParams params_; GraphView gview_; + std::vector pending_ids_; mutable KernelStats kernel_stats_; // Root nodes (with no in edges) that should form the initial ready queue @@ -500,7 +507,10 @@ GraphView::~GraphView() { } size_t GraphView::NodeItemBytes(const Node* n) { - const size_t num_output_edges = n->out_edges().size(); + size_t num_output_edges = n->out_edges().size(); + for (auto e : n->out_edges()) { + if (IsSink(e->dst())) --num_output_edges; + } const int num_inputs = n->num_inputs(); const int num_outputs = n->num_outputs(); @@ -551,7 +561,10 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { node_offsets_[id] = offset; ptr += bytes; - const size_t num_output_edges = n->out_edges().size(); + size_t num_output_edges = n->out_edges().size(); + for (auto e : n->out_edges()) { + if (IsSink(e->dst())) --num_output_edges; + } const int num_inputs = n->num_inputs(); const int num_outputs = n->num_outputs(); @@ -568,6 +581,7 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { gtl::InlinedVector last_indices(num_outputs, nullptr); EdgeInfo* dst_edge = item->output_edge_base(); for (auto e : n->out_edges()) { + if (IsSink(e->dst())) continue; dst_edge->dst_id = e->dst()->id(); CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits dst_edge->output_slot = e->src_output(); @@ -576,6 +590,8 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { if (output_slot >= 0) { last_indices[output_slot] = dst_edge; } + // NOTE: The `input_slot` will be rewritten to the frame-wide offset later + // in `ExecutorImpl::Initialize()`. dst_edge->input_slot = e->dst_input(); dst_edge++; } @@ -697,9 +713,12 @@ Status ExecutorImpl::Initialize(const Graph& graph) { EnsureFrameInfo(it)->nodes = new std::vector; } + pending_ids_.resize(gview_.num_nodes()); + // Preprocess every node in the graph to create an instance of op // kernel for each node. for (const Node* n : graph.nodes()) { + if (IsSink(n)) continue; const int id = n->id(); const string& frame_name = cf_info.frame_names[id]; FrameInfo* frame_info = EnsureFrameInfo(frame_name); @@ -719,6 +738,13 @@ Status ExecutorImpl::Initialize(const Graph& graph) { CHECK(item->kernel); item->kernel_is_async = (item->kernel->AsAsync() != nullptr); item->is_merge = IsMerge(n); + item->is_any_consumer_merge_or_control_trigger = false; + for (const Node* consumer : n->out_nodes()) { + if (IsMerge(consumer) || IsControlTrigger(consumer)) { + item->is_any_consumer_merge_or_control_trigger = true; + break; + } + } const Tensor* const_tensor = item->kernel->const_tensor(); if (const_tensor) { // Hold onto a shallow copy of the constant tensor in `*this` so that the @@ -740,7 +766,6 @@ Status ExecutorImpl::Initialize(const Graph& graph) { item->is_exit = IsExit(n); item->is_control_trigger = IsControlTrigger(n); item->is_source = IsSource(n); - item->is_sink = IsSink(n); item->is_enter_exit_or_next_iter = (IsEnter(n) || IsExit(n) || IsNextIteration(n)); item->is_transfer_node = IsTransferNode(n); @@ -754,7 +779,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) { // space to store these maximal count values. size_t max_pending, max_dead; GetMaxPendingCounts(n, &max_pending, &max_dead); - item->pending_id = + pending_ids_[id] = frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead); // See if this node is a root node, and if so, add item to root_nodes_. 
@@ -775,6 +800,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) { std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false); size_t unused_outputs = n->num_outputs(); for (const Edge* e : n->out_edges()) { + if (IsSink(e->dst())) continue; if (e->src_output() >= 0) { if (!outputs_required[e->src_output()]) { --unused_outputs; @@ -792,8 +818,24 @@ Status ExecutorImpl::Initialize(const Graph& graph) { } } - // Initialize PendingCounts only after item->pending_id is initialized for - // all nodes. + // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input + // location. + for (const Node* n : graph.nodes()) { + if (IsSink(n)) continue; + const int id = n->id(); + NodeItem* item = gview_.node(id); + + for (size_t out_index = 0; out_index < item->num_output_edges; + out_index++) { + EdgeInfo& e = item->output_edge(out_index); + const int dst_id = e.dst_id; + NodeItem* dst_item = gview_.node(dst_id); + e.input_slot += dst_item->input_start; + } + } + + // Initialize PendingCounts only after pending_ids_[node.id] is initialized + // for all nodes. InitializePending(&graph, cf_info); kernel_stats_.Initialize(gview_); return gview_.SetAllocAttrs(&graph, params_.device); @@ -825,7 +867,7 @@ void GraphView::SetScopedAllocatorAttrs( // Control edges out of the ScopedAllocator should be use instances, but may // include a few other nodes. for (const auto& e : sa->out_edges()) { - if (!e->IsControlEdge()) { + if (IsSink(e->dst()) || !e->IsControlEdge()) { continue; } Node* use_node = e->dst(); @@ -842,7 +884,7 @@ void GraphView::SetScopedAllocatorAttrs( // There can be more than one output using ScopedAllocation, but this // analysis assumes they use the same ScopedAllocator. for (const auto& e : use_node->out_edges()) { - if (!e->IsControlEdge()) { + if (IsSink(e->dst()) || !e->IsControlEdge()) { AllocatorAttributes attr; if (ExtractScopedAllocatorAttr(scoped_allocator_attrs, e->src_output(), &attr)) { @@ -1136,10 +1178,9 @@ class ExecutorState { void increment_dead_count(PendingCounts::Handle h) { counts.increment_dead_count(h); } - void adjust_for_activation(PendingCounts::Handle h, bool increment_dead, - int* pending_result, int* dead_result) { - counts.adjust_for_activation(h, increment_dead, pending_result, - dead_result); + PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h, + bool increment_dead) { + return counts.adjust_for_activation(h, increment_dead); } ~IterationState() { delete[] input_tensors; } @@ -1354,6 +1395,18 @@ class ExecutorState { iterations[i] = nullptr; } } + + private: + // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. + void ActivateNodesFastPath(const NodeItem* item, const bool is_dead, + int64 iter, EntryVector* outputs, + TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead, + int64 iter, EntryVector* outputs, + TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); }; // A tagged node: . @@ -1533,7 +1586,7 @@ class ExecutorState { // For debugging/logging only. inline void MaybeMarkCompleted(FrameState* frame, int64 iter, - const NodeItem& item); + const int node_id); // Provide debugging output about an outstanding node in the executor. 
void DumpPendingNodeState(const int node_id, const Entry* input_vector, @@ -1660,6 +1713,7 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g, for (const Edge* out_edge : curr_node->out_edges()) { Node* out = out_edge->dst(); + if (IsSink(out)) continue; const int out_id = out->id(); // Add to ready queue if not visited. @@ -1688,13 +1742,13 @@ void ExecutorImpl::InitializePending(const Graph* graph, finfo->pending_counts = counts; } for (const Node* n : graph->nodes()) { + if (IsSink(n)) continue; const int id = n->id(); const string& name = cf_info.frame_names[id]; size_t max_pending, max_dead; GetMaxPendingCounts(n, &max_pending, &max_dead); - const NodeItem* item = gview_.node(id); PendingCounts* counts = EnsureFrameInfo(name)->pending_counts; - counts->set_initial_count(item->pending_id, max_pending); + counts->set_initial_count(pending_ids_[id], max_pending); } } @@ -1867,7 +1921,7 @@ void ExecutorState::ProcessAsync(const NodeItem& item, } FrameState* input_frame = state->tagged_node.input_frame; const int64 input_iter = state->tagged_node.input_iter; - MaybeMarkCompleted(input_frame, input_iter, *state->item); + MaybeMarkCompleted(input_frame, input_iter, state->item->node_id); TaggedNodeSeq ready; if (s.ok()) { PropagateOutputs(state->tagged_node, state->item, &outputs, &ready); @@ -1989,7 +2043,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { mutex_lock l(input_frame->mu); - input_frame->GetIteration(input_iter)->mark_started(item.pending_id); + input_frame->GetIteration(input_iter) + ->mark_started(impl_->pending_ids_[id]); } params.track_allocations = false; @@ -2035,7 +2090,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { for (int i = 0; i < num_inputs; ++i) { (first_input + i)->ClearVal(); } - MaybeMarkCompleted(input_frame, input_iter, item); + MaybeMarkCompleted(input_frame, input_iter, id); // Continue to process the nodes in 'inline_ready'. completed = NodeDone(s, &ready, stats, &inline_ready); continue; @@ -2070,7 +2125,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { for (int i = 0; i < num_inputs; ++i) { (first_input + i)->ClearVal(); } - MaybeMarkCompleted(input_frame, input_iter, item); + MaybeMarkCompleted(input_frame, input_iter, id); // Propagates outputs. if (s.ok()) { PropagateOutputs(tagged_node, &item, &outputs, &ready); @@ -2535,12 +2590,12 @@ void ExecutorState::ScheduleReady(TaggedNodeSeq* ready, } inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter, - const NodeItem& item) { + const int node_id) { // TODO(misard) Replace with a finer-grain enabling flag once we // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { mutex_lock l(frame->mu); - frame->GetIteration(iter)->mark_completed(item.pending_id); + frame->GetIteration(iter)->mark_completed(impl_->pending_ids_[node_id]); } } @@ -2615,7 +2670,7 @@ void ExecutorState::DumpIterationState(const FrameState* frame, const std::vector* nodes = frame->nodes; // Dump any waiting nodes that are holding on to tensors. 
for (const NodeItem* node : *nodes) { - PendingCounts::Handle pending_id = node->pending_id; + PendingCounts::Handle pending_id = impl_->pending_ids_[node->node_id]; if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { DumpPendingNodeState(node->node_id, iteration->input_tensors, false); @@ -2623,7 +2678,7 @@ void ExecutorState::DumpIterationState(const FrameState* frame, } // Then the active nodes. for (const NodeItem* node : *nodes) { - PendingCounts::Handle pending_id = node->pending_id; + PendingCounts::Handle pending_id = impl_->pending_ids_[node->node_id]; if (iteration->node_state(pending_id) == PendingCounts::STARTED) { DumpActiveNodeState(node->node_id, iteration->input_tensors); } @@ -2845,11 +2900,7 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { const EdgeInfo& e = item->output_edge(i); const NodeItem& dst_item = *impl_->gview_.node(e.dst_id); - const auto dst_pending_id = dst_item.pending_id; - - // TODO(yuanbyu): We don't need this if we require the subgraph - // given to an executor not to contain a sink node. - if (dst_item.is_sink) continue; + const auto dst_pending_id = impl_->pending_ids_[e.dst_id]; bool dst_dead = true; bool dst_ready = false; @@ -2912,26 +2963,77 @@ void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter, } } -void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, - const bool is_dead, int64 iter, - EntryVector* outputs, - TaggedNodeSeq* ready) { +void ExecutorState::FrameState::ActivateNodesFastPath(const NodeItem* item, + const bool is_dead, + int64 iter, + EntryVector* outputs, + TaggedNodeSeq* ready) { + // If we know that none of the item's edge destinations require special + // handling (i.e. none of the nodes is a merge or control trigger node), we + // can take a fast path that avoids accessing the destination NodeItem. const GraphView& gview = executor->gview_; IterationState* iter_state = GetIteration(iter); const size_t num_output_edges = item->num_output_edges; const EdgeInfo* edges = item->output_edge_list(); Entry* input_tensors = iter_state->input_tensors; + + for (size_t out_index = 0; out_index < num_output_edges; out_index++) { + const EdgeInfo& e = edges[out_index]; + const int dst_id = e.dst_id; + const PendingCounts::Handle dst_pending_id = executor->pending_ids_[dst_id]; + const int src_slot = e.output_slot; + + // True iff this input for dst is needed. We only set this input for + // dst if this flag is true. This is needed to make the thread safety + // analysis happy. 
+ const bool is_control_edge = (src_slot == Graph::kControlSlot); + const bool dst_need_input = !is_control_edge; + + const bool increment_dead = + (is_dead || (!is_control_edge && + (*outputs)[src_slot].state == Entry::State::NO_VALUE)); + const PendingCounts::AdjustResult adjust_result = + iter_state->adjust_for_activation(dst_pending_id, increment_dead); + + if (dst_need_input) { + const int dst_loc = e.input_slot; + if (e.is_last) { + input_tensors[dst_loc] = std::move((*outputs)[src_slot]); + } else { + input_tensors[dst_loc] = (*outputs)[src_slot]; + } + } + + // Add dst to the ready queue if it's ready + if (!adjust_result.any_pending) { + const NodeItem* dst_item = gview.node(dst_id); + ready->emplace_back(dst_item, this, iter, adjust_result.any_dead); + iter_state->outstanding_ops++; + } + } +} + +void ExecutorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, + const bool is_dead, + int64 iter, + EntryVector* outputs, + TaggedNodeSeq* ready) { + // If any of the edge destinations is a merge or a control trigger node, + // we need to read each destination NodeItem to determine what action + // to take. + const GraphView& gview = executor->gview_; + IterationState* iter_state = GetIteration(iter); + const size_t num_output_edges = item->num_output_edges; + const EdgeInfo* edges = item->output_edge_list(); + Entry* input_tensors = iter_state->input_tensors; + for (size_t out_index = 0; out_index < num_output_edges; out_index++) { const EdgeInfo& e = edges[out_index]; const int dst_id = e.dst_id; const NodeItem* dst_item = gview.node(dst_id); - const PendingCounts::Handle dst_pending_id = dst_item->pending_id; + const PendingCounts::Handle dst_pending_id = executor->pending_ids_[dst_id]; const int src_slot = e.output_slot; - // TODO(yuanbyu): We don't need this if we require the subgraph - // given to an executor not to contain a sink node. - if (dst_item->is_sink) continue; - bool dst_dead = false; bool dst_ready = false; // True iff this input for dst is needed. We only set this input for @@ -2977,19 +3079,18 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, } } } else { + // Handle all other (non-merge) nodes. 
const bool increment_dead = (is_dead || (!is_control_edge && (*outputs)[src_slot].state == Entry::State::NO_VALUE)); - int pending, dead; - iter_state->adjust_for_activation(dst_pending_id, increment_dead, - &pending, &dead); - dst_dead = (dead > 0); - dst_ready = (pending == 0); + const PendingCounts::AdjustResult adjust_result = + iter_state->adjust_for_activation(dst_pending_id, increment_dead); + dst_dead = adjust_result.any_dead; + dst_ready = !adjust_result.any_pending; } if (dst_need_input) { - const int dst_slot = e.input_slot; - const int dst_loc = dst_item->input_start + dst_slot; + const int dst_loc = e.input_slot; if (e.is_last) { input_tensors[dst_loc] = std::move((*outputs)[src_slot]); } else { @@ -3006,6 +3107,17 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, } } +void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, + const bool is_dead, int64 iter, + EntryVector* outputs, + TaggedNodeSeq* ready) { + if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) { + ActivateNodesSlowPath(item, is_dead, iter, outputs, ready); + } else { + ActivateNodesFastPath(item, is_dead, iter, outputs, ready); + } +} + void ExecutorState::FrameState::ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) { diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index 5f012b19bc2..74febf43287 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/simple_philox.h" @@ -452,6 +453,7 @@ static void BM_executor(int iters, int width, int depth) { SetBenchmarkLabel(strings::StrCat("Nodes = ", cur)); SetBenchmarkItemsProcessed(cur * static_cast(iters)); #endif // PLATFORM_GOOGLE + FixupSourceAndSinkEdges(g); testing::StartTiming(); test::Benchmark("cpu", g).Run(iters); } @@ -487,6 +489,7 @@ static void BM_FeedInputFetchOutput(int iters) { #ifdef PLATFORM_GOOGLE SetBenchmarkItemsProcessed(static_cast(iters)); #endif // PLATFORM_GOOGLE + FixupSourceAndSinkEdges(g); testing::StartTiming(); test::Benchmark("cpu", g).RunWithRendezvousArgs({{x_key, val}, {y_key, val}}, {z_key}, iters); diff --git a/tensorflow/core/common_runtime/pending_counts.h b/tensorflow/core/common_runtime/pending_counts.h index 5e1925c4016..b4338af6fde 100644 --- a/tensorflow/core/common_runtime/pending_counts.h +++ b/tensorflow/core/common_runtime/pending_counts.h @@ -208,21 +208,25 @@ class PendingCounts { } } + struct AdjustResult { + bool any_dead; + bool any_pending; + + AdjustResult(bool any_dead, bool any_pending) + : any_dead(any_dead), any_pending(any_pending) {} + }; + // A streamlined routine that does several pieces of bookkeeping at // once. 
Equivalent to: // if (increment_dead) increment_dead_count(h); // decrement_pending(h, 1); - // *pending_result = pending(h); - // *dead_result = dead_count(h); - void adjust_for_activation(Handle h, bool increment_dead, int* pending_result, - int* dead_result) { + // return {dead_count(h) > 0, pending(h) > 0}; + AdjustResult adjust_for_activation(Handle h, bool increment_dead) { DCHECK_GE(pending(h), 1); if (h.is_large_) { - adjust_for_activation_shared(Large(h), increment_dead, pending_result, - dead_result); + return adjust_for_activation_shared(Large(h), increment_dead); } else { - adjust_for_activation_shared(Packed(h), increment_dead, pending_result, - dead_result); + return adjust_for_activation_shared(Packed(h), increment_dead); } } @@ -238,17 +242,12 @@ class PendingCounts { private: template - inline void adjust_for_activation_shared(T* c, bool increment_dead, - int* pending_result, - int* dead_result) { - if (increment_dead) { - if (PENDING_NOTREADY == NodeStateForStruct(c)) { - c->dead_count++; - } + inline AdjustResult adjust_for_activation_shared(T* c, bool increment_dead) { + if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(c)) { + c->dead_count++; } c->pending -= 1; - *dead_result = c->dead_count; - *pending_result = c->pending; + return AdjustResult(c->dead_count, c->pending); } // We keep track of the pending count and dead input count for each diff --git a/tensorflow/core/common_runtime/pending_counts_test.cc b/tensorflow/core/common_runtime/pending_counts_test.cc index 8ab6383f20a..5d5e7367c86 100644 --- a/tensorflow/core/common_runtime/pending_counts_test.cc +++ b/tensorflow/core/common_runtime/pending_counts_test.cc @@ -144,25 +144,24 @@ TEST(PendingCounts, AdjustForActivation) { PendingCounts::Handle h = handles[id]; // Test for both packed and large. int count = (id == 0) ? 5 : 15; - int pending, dead; PendingCounts c(layout); c.set_initial_count(h, count); EXPECT_EQ(c.pending(h), count); // Don't increment the dead count this time - c.adjust_for_activation(h, false, &pending, &dead); + PendingCounts::AdjustResult result = c.adjust_for_activation(h, false); EXPECT_EQ(c.pending(h), count - 1); - EXPECT_EQ(c.pending(h), pending); + EXPECT_TRUE(result.any_pending); EXPECT_EQ(c.dead_count(h), 0); - EXPECT_EQ(c.dead_count(h), dead); + EXPECT_FALSE(result.any_dead); // Increment the dead count this time - c.adjust_for_activation(h, true, &pending, &dead); + result = c.adjust_for_activation(h, true); EXPECT_EQ(c.pending(h), count - 2); - EXPECT_EQ(c.pending(h), pending); - EXPECT_EQ(c.dead_count(h), dead); + EXPECT_TRUE(result.any_pending); EXPECT_EQ(c.dead_count(h), 1); + EXPECT_TRUE(result.any_dead); } } From cea988551e46d538b644992eeaccac9341620783 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Sun, 22 Mar 2020 14:52:03 -0700 Subject: [PATCH 391/492] Support host training loops in Model.fit Adds experimental_steps_per_execution to Model.compile to control this behavior. 
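A minimal usage sketch (illustrative only; the model, the data, and the value 10 below are placeholders, not part of this change):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
    # Run 10 batches per tf.function call. Callback.on_train_batch_* methods
    # then fire once per execution, i.e. every 10 batches.
    model.compile('sgd', 'mse', experimental_steps_per_execution=10)

    x, y = np.ones((100, 10)), np.ones((100, 1))
    model.fit(x, y, batch_size=2, epochs=1)  # 50 steps -> 5 executions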
PiperOrigin-RevId: 302325887 Change-Id: I224a6aad9da524ba1a73bbccbd51b3420514455d --- tensorflow/python/keras/callbacks.py | 16 +- .../distribute/distribute_strategy_test.py | 79 ++++++++ tensorflow/python/keras/engine/base_layer.py | 13 +- .../python/keras/engine/base_layer_test.py | 94 ++------- .../python/keras/engine/data_adapter.py | 69 ++++++- tensorflow/python/keras/engine/training.py | 178 ++++++++++-------- .../golden/v1/tensorflow.keras.-model.pbtxt | 2 +- .../v1/tensorflow.keras.-sequential.pbtxt | 2 +- ...low.keras.experimental.-linear-model.pbtxt | 2 +- ....keras.experimental.-wide-deep-model.pbtxt | 2 +- .../v1/tensorflow.keras.models.-model.pbtxt | 2 +- .../tensorflow.keras.models.-sequential.pbtxt | 2 +- .../golden/v2/tensorflow.keras.-model.pbtxt | 2 +- .../v2/tensorflow.keras.-sequential.pbtxt | 2 +- ...low.keras.experimental.-linear-model.pbtxt | 2 +- ....keras.experimental.-wide-deep-model.pbtxt | 2 +- .../v2/tensorflow.keras.models.-model.pbtxt | 2 +- .../tensorflow.keras.models.-sequential.pbtxt | 2 +- 18 files changed, 288 insertions(+), 185 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index c68f58c0747..e0b6ba52239 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -920,14 +920,15 @@ class ProgbarLogger(Callback): print('Epoch %d/%d' % (epoch + 1, self.epochs)) def on_train_batch_end(self, batch, logs=None): - self._batch_update_progbar(logs) + self._batch_update_progbar(batch, logs) def on_test_batch_end(self, batch, logs=None): if not self._called_in_fit: - self._batch_update_progbar(logs) + self._batch_update_progbar(batch, logs) def on_predict_batch_end(self, batch, logs=None): - self._batch_update_progbar(None) # Don't pass prediction results. + # Don't pass prediction results. + self._batch_update_progbar(batch, None) def on_epoch_end(self, epoch, logs=None): self._finalize_progbar(logs) @@ -943,7 +944,7 @@ class ProgbarLogger(Callback): self.seen = 0 self.progbar = None - def _batch_update_progbar(self, logs=None): + def _batch_update_progbar(self, batch, logs=None): """Updates the progbar.""" if self.stateful_metrics is None: if self.model: @@ -962,8 +963,11 @@ class ProgbarLogger(Callback): batch_size = logs.pop('size', 0) num_steps = logs.pop('num_steps', 1) # DistStrat can run >1 steps. logs.pop('batch', None) - add_seen = num_steps if self.use_steps else num_steps * batch_size - self.seen += add_seen + if self.use_steps: + self.seen = batch + 1 # One-indexed. 
+ else: + add_seen = num_steps * batch_size + self.seen += add_seen self.progbar.update(self.seen, list(logs.items()), finalize=False) def _finalize_progbar(self, logs): diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index f5c3dc9bcfe..920b180df0b 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -320,6 +320,28 @@ def strategy_and_optimizer_combinations(): return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph +class BatchCountingCB(keras.callbacks.Callback): + + def __init__(self): + super(BatchCountingCB, self).__init__() + self.train_begin_batches = [] + self.train_end_batches = [] + self.test_begin_batches = [] + self.test_end_batches = [] + + def on_train_batch_begin(self, batch, logs=None): + self.train_begin_batches.append(batch) + + def on_train_batch_end(self, batch, logs=None): + self.train_end_batches.append(batch) + + def on_test_batch_begin(self, batch, logs=None): + self.test_begin_batches.append(batch) + + def on_test_batch_end(self, batch, logs=None): + self.test_end_batches.append(batch) + + class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): @@ -1706,6 +1728,63 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, 'distributed dataset, you must specify'): model.fit(ds, epochs=2) + @combinations.generate( + combinations.combine(distribution=all_strategies, mode=['eager'])) + def test_host_training_loop(self, distribution): + with distribution.scope(): + inputs = keras.Input((10, 10, 3)) + x = keras.layers.Conv2D(3, kernel_size=3)(inputs) + x = keras.layers.Flatten()(x) + outputs = keras.layers.Dense(1)(x) + model = keras.Model(inputs, outputs) + + model.compile('sgd', 'mse', experimental_steps_per_execution=10) + + bc = BatchCountingCB() + x, y = np.ones((100, 10, 10, 3)), np.ones((100, 1)) + model.fit(x, y, batch_size=2, epochs=1, callbacks=[bc]) + self.assertEqual(bc.train_begin_batches, [0, 10, 20, 30, 40]) + self.assertEqual(bc.train_end_batches, [9, 19, 29, 39, 49]) + + @combinations.generate( + combinations.combine(distribution=all_strategies, mode=['eager'])) + def test_host_training_loop_last_partial_execution(self, distribution): + with distribution.scope(): + inputs = keras.Input(10) + outputs = keras.layers.Dense(1)(inputs) + model = keras.Model(inputs, outputs) + + model.compile('sgd', 'mse', experimental_steps_per_execution=20) + + bc = BatchCountingCB() + x, y = np.ones((100, 10)), np.ones((100, 1)) + model.fit(x, y, batch_size=2, epochs=1, callbacks=[bc]) + self.assertEqual(bc.train_begin_batches, [0, 20, 40]) + self.assertEqual(bc.train_end_batches, [19, 39, 49]) + + @combinations.generate( + combinations.combine(distribution=all_strategies, mode=['eager'])) + def test_host_training_loop_dataset_unknown_size(self, distribution): + with distribution.scope(): + inputs = keras.Input(10) + outputs = keras.layers.Dense(1)(inputs) + model = keras.Model(inputs, outputs) + + model.compile('sgd', 'mse', experimental_steps_per_execution=20) + + x, y = np.ones((100, 10)), np.ones((100, 1)) + ds = dataset_ops.DatasetV2.from_tensor_slices((x, y)).batch(2) + ds = ds.filter(lambda *args, **kwargs: True) # Makes the size UNKNOWN. 
+ bc = BatchCountingCB() + + with self.assertRaisesRegexp(ValueError, 'steps_per_execution'): + model.fit(ds, epochs=2, callbacks=[bc]) + + ds = ds.repeat(2) + model.fit(ds, steps_per_epoch=50, epochs=2, callbacks=[bc]) + self.assertEqual(bc.train_begin_batches, [0, 20, 40, 0, 20, 40]) + self.assertEqual(bc.train_end_batches, [19, 39, 49, 19, 39, 49]) + @combinations.generate( combinations.times( all_strategy_combinations_minus_default())) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 5a37826d761..38cddcfde6f 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -1129,13 +1129,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): continue for u in layer._updates: if callable(u): - try: - u = u() - except errors.InaccessibleTensorError: - base_layer_utils.check_graph_consistency( - method='add_update', force_raise=True) - raise # check_graph_consistency may not always raise. - base_layer_utils.check_graph_consistency(u, method='add_update') + u = u() collected_updates.append(u) return collected_updates @@ -1268,7 +1262,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): if (tf_utils.is_symbolic_tensor(loss) and not base_layer_utils.is_in_tf_function()): symbolic_losses.append(_tag_unconditional(loss)) - base_layer_utils.check_graph_consistency(loss, method='add_loss') elif tensor_util.is_tensor(loss): eager_losses.append(_tag_unconditional(loss)) @@ -1363,8 +1356,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # If a metric was added in a Layer's `call` or `build`. if in_call_context or not getattr(self, '_is_graph_network', False): # TF Function path should take the eager path. - if is_symbolic and not base_layer_utils.is_in_tf_function(): - base_layer_utils.check_graph_consistency(value, method='add_metric') self._add_metric(value, aggregation, name) else: if from_metric_obj: @@ -2172,8 +2163,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): array_ops.shape(output)[0], activity_loss.dtype) # Make activity regularization strength batch-agnostic. mean_activity_loss = activity_loss / batch_size - base_layer_utils.check_graph_consistency( - mean_activity_loss, method='activity_regularizer') self.add_loss(mean_activity_loss, inputs=inputs) def _set_mask_metadata(self, inputs, outputs, previous_mask): diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 1999f313d6b..c905c6118c3 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -29,6 +29,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec @@ -1106,46 +1107,6 @@ class AutographControlFlowTest(keras_parameterized.TestCase): model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) self.assertEqual(backend.get_value(layer.counter), 1.) 
- def test_conditional_updates_in_call(self): - - class MyLayer(base_layer.Layer): - - def __init__(self): - super(MyLayer, - self).__init__(dynamic=testing_utils.should_run_eagerly()) - - def build(self, input_shape): - self.counter = self.add_weight( - shape=(), trainable=False, initializer='zeros') - - def call(self, inputs, training=None): - if training: - z = math_ops.reduce_sum(inputs) - self.add_update(lambda: self.counter.assign_add(z)) - return inputs - - def compute_output_shape(self, input_shape): - return input_shape - - if testing_utils.should_run_eagerly(): - inputs = input_layer.Input((3,)) - layer = MyLayer() - outputs = layer(inputs) - model = training_lib.Model(inputs, outputs) - model.compile( - 'sgd', - 'mse', - run_eagerly=testing_utils.should_run_eagerly()) - model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) - self.assertEqual(backend.get_value(layer.counter), 6.) - else: - # TODO(fchollet): support the same workflow in graph mode. - with self.assertRaisesRegexp(RuntimeError, - '`add_update` in a control flow branch'): - layer = MyLayer() - layer(input_layer.Input((3,))) - _ = layer.updates - def test_conditional_losses_in_call(self): class MyLayer(base_layer.Layer): @@ -1162,21 +1123,13 @@ class AutographControlFlowTest(keras_parameterized.TestCase): def compute_output_shape(self, input_shape): return input_shape - if testing_utils.should_run_eagerly(): - inputs = input_layer.Input((3,)) - layer = MyLayer() - outputs = layer(inputs) - model = training_lib.Model(inputs, outputs) - model.compile( - 'sgd', - 'mse', - run_eagerly=testing_utils.should_run_eagerly()) - loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) - self.assertEqual(loss, 2 * 3) - else: - with self.assertRaisesRegexp(RuntimeError, - '`add_loss` in a control flow branch'): - layer = MyLayer()(input_layer.Input((3,))) + inputs = input_layer.Input((3,)) + layer = MyLayer() + outputs = layer(inputs) + model = training_lib.Model(inputs, outputs) + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(loss, 2 * 3) def test_conditional_callable_losses(self): model = sequential.Sequential([ @@ -1217,22 +1170,13 @@ class AutographControlFlowTest(keras_parameterized.TestCase): def compute_output_shape(self, input_shape): return input_shape - if testing_utils.should_run_eagerly(): - inputs = input_layer.Input((3,)) - layer = MyLayer() - outputs = layer(inputs) - model = training_lib.Model(inputs, outputs) - model.compile( - 'sgd', - 'mse', - run_eagerly=testing_utils.should_run_eagerly()) - history = model.fit(np.ones((2, 3)), np.ones((2, 3))) - self.assertEqual(history.history['sum'][-1], 2 * 3) - else: - # TODO(fchollet): support the same workflow in graph mode. 
- with self.assertRaisesRegexp(RuntimeError, - '`add_metric` in a control flow branch'): - layer = MyLayer()(input_layer.Input((3,))) + inputs = input_layer.Input((3,)) + layer = MyLayer() + outputs = layer(inputs) + model = training_lib.Model(inputs, outputs) + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + history = model.fit(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(history.history['sum'][-1], 2 * 3) def test_conditional_activity_regularizer_in_call(self): @@ -1261,8 +1205,8 @@ class AutographControlFlowTest(keras_parameterized.TestCase): if testing_utils.should_run_eagerly(): model.fit(x, y, epochs=2, batch_size=5) else: - with self.assertRaisesRegexp( - RuntimeError, '`activity_regularizer` in a control flow branch'): + with self.assertRaisesRegexp(errors_impl.InaccessibleTensorError, + 'ActivityRegularizer'): model.fit(x, y, epochs=2, batch_size=5) def test_conditional_activity_regularizer_with_wrappers_in_call(self): @@ -1293,8 +1237,8 @@ class AutographControlFlowTest(keras_parameterized.TestCase): if testing_utils.should_run_eagerly(): model.fit(x, y, epochs=2, batch_size=5) else: - with self.assertRaisesRegexp( - RuntimeError, '`activity_regularizer` in a control flow branch'): + with self.assertRaisesRegexp(errors_impl.InaccessibleTensorError, + 'ActivityRegularizer'): model.fit(x, y, epochs=2, batch_size=5) diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 43eefd75c36..b0741acfe30 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -1102,11 +1102,22 @@ class DataHandler(object): max_queue_size=10, workers=1, use_multiprocessing=False, - model=None): + model=None, + steps_per_execution=1): self._initial_epoch = initial_epoch self._epochs = epochs self._insufficient_data = False + self._model = model + self._steps_per_execution = steps_per_execution + + # This `Variable` is assigned to by `DataHandler` to allow partial + # executions. Save its original value here to reset after a partial + # execution. + if isinstance(steps_per_execution, int): + self._steps_per_execution_value = steps_per_execution + else: + self._steps_per_execution_value = steps_per_execution.numpy().item() adapter_cls = select_data_adapter(x, y) self._adapter = adapter_cls( @@ -1133,6 +1144,12 @@ class DataHandler(object): dataset = strategy.experimental_distribute_dataset(dataset) self._dataset = dataset + self._current_step = 0 + self._step_increment = self._steps_per_execution_value - 1 + self._insufficient_data = False + + self._validate_data_handler() + def enumerate_epochs(self): """Yields `(epoch, tf.data.Iterator)`.""" data_iterator = iter(self._dataset) @@ -1156,11 +1173,14 @@ class DataHandler(object): yield context.async_wait() except (StopIteration, errors.OutOfRangeError): - if (self._adapter.get_size() is None and self._inferred_steps is None and - self._current_step > 0): + context.async_clear_error() + if self._inferred_steps is None: # The input passed by the user ran out of batches. # Now we know the cardinality of the input(dataset or generator). 
- self._inferred_steps = self._current_step + if self._model is not None: + self._inferred_steps = self._model._train_counter.numpy().item() # pylint: disable=protected-access + else: + self._inferred_steps = self._current_step else: self._insufficient_data = True total_epochs = self._epochs - self._initial_epoch @@ -1180,8 +1200,30 @@ class DataHandler(object): self._current_step < self._inferred_steps): if self._insufficient_data: # Set by `catch_stop_iteration`. break - yield self._current_step - self._current_step += 1 + + can_run_full_execution = ( + self._steps_per_execution_value == 1 or + self._inferred_steps is None or + self._inferred_steps - self._current_step >= + self._steps_per_execution_value) + + if can_run_full_execution: + self._step_increment = self._steps_per_execution_value - 1 + yield self._current_step + self._current_step += self._steps_per_execution_value + else: + # Last partial execution. + steps_remaining = self._inferred_steps - self._current_step + self._steps_per_execution.assign(steps_remaining) + self._step_increment = steps_remaining - 1 + yield self._current_step + self._current_step += steps_remaining + self._steps_per_execution.assign(self._steps_per_execution_value) + + @property + def step_increment(self): + """The number to increment the step for `on_batch_end` methods.""" + return self._step_increment @property def inferred_steps(self): @@ -1198,6 +1240,13 @@ class DataHandler(object): """ return self._inferred_steps + @property + def should_sync(self): + # Catch OutOfRangeError for Datasets of unknown size. + # This blocks until the batch has finished executing. + # TODO(b/150292341): Allow multiple async steps here. + return self._inferred_steps is None + def _infer_steps(self, steps, dataset): """Infers steps_per_epoch needed to loop through a dataset.""" if steps is not None: @@ -1231,6 +1280,14 @@ class DataHandler(object): def _samples(self): return self._adapter.get_samples() + def _validate_data_handler(self): + # TODO(b/152094471): Support this with DistIter.get_next_as_optional. + if self._steps_per_execution_value > 1 and self._inferred_steps is None: + raise ValueError( + "Could not infer the size of the data. With " + "`steps_per_execution > 1`, you must specify the number of steps " + "to run.") + def _make_class_weight_map_fn(class_weight): """Applies class weighting to a `Dataset`. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 21361f680da..168647d5342 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -44,6 +44,7 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables @@ -166,7 +167,8 @@ class Model(network.Network, version_utils.ModelVersionSelector): additional details. 
""" _TF_MODULE_IGNORED_PROPERTIES = frozenset( - itertools.chain(('_train_counter', '_test_counter', '_predict_counter'), + itertools.chain(('_train_counter', '_test_counter', '_predict_counter', + '_steps_per_execution'), network.Network._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access def __init__(self, *args, **kwargs): @@ -274,65 +276,71 @@ class Model(network.Network, version_utils.ModelVersionSelector): loss_weights=None, sample_weight_mode=None, weighted_metrics=None, + run_eagerly=None, **kwargs): """Configures the model for training. Arguments: - optimizer: String (name of optimizer) or optimizer instance. - See `tf.keras.optimizers`. + optimizer: String (name of optimizer) or optimizer instance. See + `tf.keras.optimizers`. loss: String (name of objective function), objective function or - `tf.keras.losses.Loss` instance. See `tf.keras.losses`. - An objective function is any callable with the signature - `loss = fn(y_true, y_pred)`, where - y_true = ground truth values with shape = `[batch_size, d0, .. dN]`, - except sparse loss functions such as sparse categorical crossentropy - where shape = `[batch_size, d0, .. dN-1]`. - y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. - It returns a weighted loss float tensor. - If a custom `Loss` instance is used and reduction is set to NONE, - return value has the shape [batch_size, d0, .. dN-1] ie. per-sample - or per-timestep loss values; otherwise, it is a scalar. - If the model has multiple outputs, you can use a different loss on - each output by passing a dictionary or a list of losses. The loss - value that will be minimized by the model will then be the sum of - all individual losses. + `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective + function is any callable with the signature `loss = fn(y_true, + y_pred)`, where y_true = ground truth values with shape = + `[batch_size, d0, .. dN]`, except sparse loss functions such as sparse + categorical crossentropy where shape = `[batch_size, d0, .. dN-1]`. + y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. It + returns a weighted loss float tensor. If a custom `Loss` instance is + used and reduction is set to NONE, return value has the shape + [batch_size, d0, .. dN-1] ie. per-sample or per-timestep loss values; + otherwise, it is a scalar. If the model has multiple outputs, you can + use a different loss on each output by passing a dictionary or a list + of losses. The loss value that will be minimized by the model will + then be the sum of all individual losses. metrics: List of metrics to be evaluated by the model during training - and testing. - Each of this can be a string (name of a built-in function), function - or a `tf.keras.metrics.Metric` instance. See `tf.keras.metrics`. - Typically you will use `metrics=['accuracy']`. A function is any - callable with the signature `result = fn(y_true, y_pred)`. - To specify different metrics for different outputs of a - multi-output model, you could also pass a dictionary, such as + and testing. Each of this can be a string (name of a built-in + function), function or a `tf.keras.metrics.Metric` instance. See + `tf.keras.metrics`. Typically you will use `metrics=['accuracy']`. A + function is any callable with the signature `result = fn(y_true, + y_pred)`. To specify different metrics for different outputs of a + multi-output model, you could also pass a dictionary, such as `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`. 
- You can also pass a list (len = len(outputs)) of lists of metrics - such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or - `metrics=['accuracy', ['accuracy', 'mse']]`. - When you pass the strings 'accuracy' or 'acc', we convert this to - one of `tf.keras.metrics.BinaryAccuracy`, - `tf.keras.metrics.CategoricalAccuracy`, - `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss - function used and the model output shape. We do a similar conversion - for the strings 'crossentropy' and 'ce' as well. - loss_weights: Optional list or dictionary specifying scalar - coefficients (Python floats) to weight the loss contributions - of different model outputs. - The loss value that will be minimized by the model - will then be the *weighted sum* of all individual losses, - weighted by the `loss_weights` coefficients. - If a list, it is expected to have a 1:1 mapping - to the model's outputs. If a dict, it is expected to map - output names (strings) to scalar coefficients. - sample_weight_mode: If you need to do timestep-wise - sample weighting (2D weights), set this to `"temporal"`. - `None` defaults to sample-wise weights (1D). - If the model has multiple outputs, you can use a different - `sample_weight_mode` on each output by passing a - dictionary or a list of modes. - weighted_metrics: List of metrics to be evaluated and weighted - by sample_weight or class_weight during training and testing. - **kwargs: Any additional arguments. For eager execution, pass - `run_eagerly=True`. + You can also pass a list (len = len(outputs)) of lists of metrics + such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or + `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass the + strings 'accuracy' or 'acc', we convert this to one of + `tf.keras.metrics.BinaryAccuracy`, + `tf.keras.metrics.CategoricalAccuracy`, + `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss + function used and the model output shape. We do a similar + conversion for the strings 'crossentropy' and 'ce' as well. + loss_weights: Optional list or dictionary specifying scalar coefficients + (Python floats) to weight the loss contributions of different model + outputs. The loss value that will be minimized by the model will then + be the *weighted sum* of all individual losses, weighted by the + `loss_weights` coefficients. + If a list, it is expected to have a 1:1 mapping to the model's + outputs. If a dict, it is expected to map output names (strings) + to scalar coefficients. + sample_weight_mode: If you need to do timestep-wise sample weighting (2D + weights), set this to `"temporal"`. `None` defaults to sample-wise + weights (1D). If the model has multiple outputs, you can use a + different `sample_weight_mode` on each output by passing a dictionary + or a list of modes. + weighted_metrics: List of metrics to be evaluated and weighted by + sample_weight or class_weight during training and testing. + run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s + logic will not be wrapped in a `tf.function`. Recommended to leave + this as `None` unless your `Model` cannot be run inside a + `tf.function`. + **kwargs: Any additional arguments. Supported arguments: + `experimental_steps_per_execution`: Int. The number of batches to + run during each `tf.function` call. Running multiple batches + inside a single `tf.function` call can greatly improve performance + on TPUs or small models with a large Python overhead. 
Note that if + this value is set to `N`, `Callback.on_batch` methods will only be + called every `N` batches. This currently defaults to `1`. At most, + one full epoch can be run each execution. Raises: ValueError: In case of invalid arguments for @@ -349,6 +357,10 @@ class Model(network.Network, version_utils.ModelVersionSelector): self.compiled_metrics = compile_utils.MetricsContainer( metrics, weighted_metrics, output_names=self.output_names) + experimental_steps_per_execution = kwargs.pop( + 'experimental_steps_per_execution', 1) + self._configure_steps_per_execution(experimental_steps_per_execution) + # Initializes attrs that are reset each time `compile` is called. self._reset_compile_cache() self._is_compiled = True @@ -376,6 +388,13 @@ class Model(network.Network, version_utils.ModelVersionSelector): # Used to cache `trainable` attr of `Layer`s for `fit`. self._compiled_trainable_state = self._get_trainable_state() + @trackable.no_automatic_dependency_tracking + def _configure_steps_per_execution(self, steps_per_execution): + self._steps_per_execution = variables.Variable( + steps_per_execution, + dtype='int64', + aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA) + @property def metrics(self): """Returns the model's metrics added using `compile`, `add_metric` APIs.""" @@ -517,21 +536,39 @@ class Model(network.Network, version_utils.ModelVersionSelector): if self.train_function is not None: return self.train_function - def train_function(iterator): - """Runs one call to `self.train_function`.""" + def step_function(model, iterator): + """Runs a single training step.""" def run_step(data): - outputs = self.train_step(data) - self._train_counter.assign_add(1) + outputs = model.train_step(data) + # Ensure counter is updated only if `train_step` succeeds. + control_deps = [nest.flatten(outputs)[0]] + with ops.control_dependencies(control_deps): + model._train_counter.assign_add(1) # pylint: disable=protected-access return outputs data = next(iterator) - outputs = self.distribute_strategy.run(run_step, args=(data,)) + outputs = model.distribute_strategy.run(run_step, args=(data,)) outputs = reduce_per_replica( outputs, self.distribute_strategy, reduction='first') - write_scalar_summaries(outputs, step=self._train_counter) + write_scalar_summaries(outputs, step=model._train_counter) # pylint: disable=protected-access return outputs + if self._steps_per_execution.numpy().item() == 1: + + def train_function(iterator): + """Runs a training execution with one step.""" + return step_function(self, iterator) + + else: + + def train_function(iterator): + """Runs a training execution with multiple steps.""" + outputs = step_function(self, iterator) + for _ in math_ops.range(self._steps_per_execution - 1): + outputs = step_function(self, iterator) + return outputs + if not self.run_eagerly: train_function = def_function.function( train_function, experimental_relax_shapes=True) @@ -773,7 +810,8 @@ class Model(network.Network, version_utils.ModelVersionSelector): max_queue_size=max_queue_size, workers=workers, use_multiprocessing=use_multiprocessing, - model=self) + model=self, + steps_per_execution=self._steps_per_execution) # Container that configures and calls `tf.keras.Callback`s. if not isinstance(callbacks, callbacks_module.CallbackList): @@ -808,13 +846,11 @@ class Model(network.Network, version_utils.ModelVersionSelector): batch_size=batch_size): callbacks.on_train_batch_begin(step) tmp_logs = train_function(iterator) - # Catch OutOfRangeError for Datasets of unknown size. 
- # This blocks until the batch has finished executing. - # TODO(b/150292341): Allow multiple async steps here. - if not data_handler.inferred_steps: + if data_handler.should_sync: context.async_wait() logs = tmp_logs # No error, now safe to assign to logs. - callbacks.on_train_batch_end(step, logs) + end_step = step + data_handler.step_increment + callbacks.on_train_batch_end(end_step, logs) epoch_logs = copy.copy(logs) # Run validation. @@ -1045,10 +1081,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): with trace.Trace('TraceContext', graph_type='test', step_num=step): callbacks.on_test_batch_begin(step) tmp_logs = test_function(iterator) - # Catch OutOfRangeError for Datasets of unknown size. - # This blocks until the batch has finished executing. - # TODO(b/150292341): Allow multiple async steps here. - if not data_handler.inferred_steps: + if data_handler.should_sync: context.async_wait() logs = tmp_logs # No error, now safe to assign to logs. callbacks.on_test_batch_end(step, logs) @@ -1239,10 +1272,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): for step in data_handler.steps(): callbacks.on_predict_batch_begin(step) tmp_batch_outputs = predict_function(iterator) - # Catch OutOfRangeError for Datasets of unknown size. - # This blocks until the batch has finished executing. - # TODO(b/150292341): Allow multiple async steps here. - if not data_handler.inferred_steps: + if data_handler.should_sync: context.async_wait() batch_outputs = tmp_batch_outputs # No error, now safe to assign. if outputs is None: @@ -1544,7 +1574,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): if kwargs.pop('target_tensors', None) is not None: raise ValueError( 'target_tensors argument is not supported when executing eagerly.') - invalid_kwargs = set(kwargs) - {'run_eagerly'} + invalid_kwargs = set(kwargs) - {'experimental_steps_per_execution'} if invalid_kwargs: raise TypeError('Invalid keyword argument(s) in `compile`: %s' % (invalid_kwargs,)) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt index 1f81e67c055..414f682473c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt @@ -171,7 +171,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt index d07869c782e..fb929010980 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt @@ -176,7 +176,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\',
\'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt index be83cf67dc5..2d64a7bb9e0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt @@ -172,7 +172,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt index 1657ea681e7..4f17a33773c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -172,7 +172,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt index ebf8a930005..b77893eaeda 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt @@ -171,7 +171,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt index 980301e8f14..939657fd748 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt 
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt @@ -176,7 +176,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt index 1f81e67c055..414f682473c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt @@ -171,7 +171,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt index d07869c782e..fb929010980 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt @@ -176,7 +176,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt index be83cf67dc5..2d64a7bb9e0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt @@ -172,7 +172,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt index 1657ea681e7..4f17a33773c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -172,7 +172,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt index ebf8a930005..b77893eaeda 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt @@ -171,7 +171,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt index 980301e8f14..939657fd748 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt @@ -176,7 +176,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" From ce82bfc4545149605950b442a6ff53ffc3e992d2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 22 Mar 2020 17:06:01 -0500 Subject: [PATCH 392/492] rollback unrelated change for TFE_OpReset --- tensorflow/c/eager/c_api_experimental.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 9d491c72f38..afa36fe1210 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -31,7 +31,6 @@ using tensorflow::string; void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, const char* raw_device_name, 
TF_Status* status) { if (op_to_reset) { - op_to_reset->operation->Clear(); status->status = op_to_reset->operation->Reset(op_or_function_name, raw_device_name); } else { From 15f9f08467758d7bc39b50c2782e28311ca9d5d3 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Sun, 22 Mar 2020 17:01:40 -0700 Subject: [PATCH 393/492] Replace ASSERT_TRUE with TF_ASSERT_OK in compile_mlir_util_test for better error messages Example error below: Expected equality of these values: ::tensorflow::Status::OK() Which is: OK (s) Which is: Invalid argument: could not parse MLIR module: error: custom op 'return' is unknown PiperOrigin-RevId: 302335189 Change-Id: I4a94a799d0a7e79af6df0b4aa73545590c284f30 --- .../tensorflow/utils/compile_mlir_util_test.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 0caf1752cfb..7db3d34a4ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -66,13 +66,13 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) { Status s = CompileSerializedMlirToXlaHlo( kBinaryAddModule, arg_shapes, /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); const xla::HloModuleConfig module_config( compilation_result.computation->GetProgramShape().ValueOrDie()); auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); + TF_ASSERT_OK(status_or_hlo_module.status()); string expected_hlo_module_string = R"(HloModule main.6 ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { @@ -124,13 +124,13 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { Status s = CompileSerializedMlirToXlaHlo( kBinaryAddModule, arg_shapes, /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); const xla::HloModuleConfig module_config( compilation_result.computation->GetProgramShape().ValueOrDie()); auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); + TF_ASSERT_OK(status_or_hlo_module.status()); string expected_hlo_module_string = R"(HloModule main.5 ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { @@ -195,13 +195,13 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { Status s = CompileSerializedMlirToXlaHlo( mlir_module, arg_shapes, /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); const xla::HloModuleConfig module_config( compilation_result.computation->GetProgramShape().ValueOrDie()); auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); + TF_ASSERT_OK(status_or_hlo_module.status()); string expected_hlo_module_string = R"(HloModule main.6 ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { @@ -240,7 +240,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { compilation_result.computation->GetProgramShape().ValueOrDie()); auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); 
+ TF_ASSERT_OK(status_or_hlo_module.status()); string expected_signature = R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))"; @@ -265,13 +265,13 @@ TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { Status s = CompileSerializedMlirToXlaHlo( kBroadcastGradientArgsModule, arg_shapes, /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); const xla::HloModuleConfig module_config( compilation_result.computation->GetProgramShape().ValueOrDie()); auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); + TF_ASSERT_OK(status_or_hlo_module.status()); string expected_hlo_module_string = R"(HloModule main.4 ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { From fbaaf16478f18b7f13d4f00f0233eaa454307544 Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Sun, 22 Mar 2020 18:16:30 -0700 Subject: [PATCH 394/492] Include used header in verifier.cc PiperOrigin-RevId: 302340637 Change-Id: I4e58632aa07a62b485c03990cd57018376ae28f5 --- tensorflow/lite/tools/verifier.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc index cc2ac8e5ed0..b0025015743 100644 --- a/tensorflow/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/tools/verifier.h" #include +#include #include #include "absl/container/flat_hash_set.h" From fe998dbcf4d6c4f91e9abef33dc1ad8406b56171 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Sun, 22 Mar 2020 19:01:09 -0700 Subject: [PATCH 395/492] Explicitly register dialects required in compilation pipelines Currently, these are relying on AllPassesAndDialects target to link and register all dialects and passes. Linking only the required dialects reduces the size of linker input and binary. This change is not re-using a static MLIR context among multiple invocations to avoid any change in behavior there and can be revisited separately. 
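For readers skimming the change, the mechanism is a small once-per-process registration helper that each pipeline calls before it builds its MLIRContext, rather than depending on the AllPassesAndDialects link target. The following is only an illustrative sketch: the dialect list is inferred from the headers this patch adds (StandardOps, TF, tf_executor, xla_hlo), and the exact set registered differs per pipeline.

    #include "mlir/Dialect/StandardOps/IR/Ops.h"
    #include "mlir/IR/Dialect.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
    #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

    static void RegisterDialects() {
      static bool init_once = []() {
        // Register only the dialects this pipeline actually needs.
        mlir::registerDialect<mlir::StandardOpsDialect>();
        mlir::registerDialect<mlir::TF::TensorFlowDialect>();
        mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
        mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
        return true;
      }();
      (void)init_once;  // The lambda runs exactly once per process.
    }

Each entry point (for example CompileSerializedMlirToXlaHlo below) then calls RegisterDialects() immediately before constructing its mlir::MLIRContext.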
PiperOrigin-RevId: 302344149 Change-Id: I1537ae6d1fd332fb8c9d47e5d35559c3e64ba45c --- .../compiler/mlir/lite/quantization/xla/BUILD | 2 +- .../mlir/lite/quantization/xla/quantize.cc | 12 ++++++++++++ tensorflow/compiler/mlir/tensorflow/BUILD | 2 +- .../tensorflow/utils/compile_mlir_util.cc | 19 ++++++++++++++++++- tensorflow/compiler/tf2xla/BUILD | 2 ++ tensorflow/compiler/tf2xla/mlir_tf2xla.cc | 18 +++++++++++++++++- 6 files changed, 51 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD index 36c897d5fec..2ce36709e9c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD @@ -61,7 +61,6 @@ cc_library( "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo", "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/tf2xla:mlir_tf2xla", "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -71,6 +70,7 @@ cc_library( "//tensorflow/core/platform:status", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc index 4640284fa5c..a5ac34e1cc0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project @@ -22,18 +23,29 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" namespace mlir { namespace xla_hlo { +static void RegisterDialects() { + static bool init_once = []() { + mlir::registerDialect(); + mlir::registerDialect(); + return true; + }(); + (void)init_once; +} + // Quantizes the model in the computation. 
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config, xla::XlaComputation* computation) { TF_ASSIGN_OR_RETURN(std::unique_ptr snapshot, computation->Snapshot()); + RegisterDialects(); MLIRContext context; OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context)); auto status = xla::ConvertHloToMlirHlo( diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 782102510fa..6cd058a15d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1062,6 +1062,7 @@ cc_library( ":convert_type", ":dump_mlir_util", ":error_util", + ":tensorflow", ":tensorflow_dialect_registration", ":tensorflow_passes", ":tf_dialect_passes", @@ -1076,7 +1077,6 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 10aad0a03ff..b13eec71de9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/OpDefinition.h" // TF:llvm-project @@ -26,6 +27,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" @@ -33,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" @@ -208,7 +212,19 @@ Status RefineShapes(llvm::ArrayRef arg_shapes, return Status::OK(); } -} // namespace +static void RegisterDialects() { + static bool init_once = []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + return true; + }(); + (void)init_once; +} + +} // namespace +// namespace Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, @@ -265,6 +281,7 @@ Status CompileSerializedMlirToXlaHlo( bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { + RegisterDialects(); mlir::MLIRContext mlir_context; mlir::OwningModuleRef mlir_module; diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index a6f88df7e40..10a67e835b1 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -180,6 +180,8 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 9303e2e9330..01b1fed9cac 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -20,6 +20,10 @@ limitations under the License. #include #include +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" @@ -28,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -84,6 +89,17 @@ Status ConvertOutputInfo(const tf2xla::Config& config, return ParseOutputArrayInfo(array_names, &specs->outputs); } +static void RegisterDialects() { + static bool init_once = []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + return true; + }(); + (void)init_once; +} + } // namespace Status ConvertGraphDefToXlaViaMlir( @@ -132,8 +148,8 @@ Status ConvertGraphDefToXlaViaMlir( } } + RegisterDialects(); mlir::MLIRContext context; - TF_ASSIGN_OR_RETURN( mlir::OwningModuleRef module, ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context)); From ed1484362608ca7f53ce9b36392cbc8411ced785 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 22 Mar 2020 19:46:11 -0700 Subject: [PATCH 396/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302347974 Change-Id: I50057fa74dab3e2ae04f829c6fc96da66769ce0a --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 75d86f71b78..68bb1dc49f5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. 
Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From f810279886b9b2864bf8c0a49a793b7cd211ee08 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Sun, 22 Mar 2020 20:14:30 -0700 Subject: [PATCH 397/492] Change promote-resources-to-args pass to emit all the unexpected users for VarHandleOp. Modify an existing test to test the output. PiperOrigin-RevId: 302350400 Change-Id: I0cb30a477c13629b20172b77e1e6340392f8fd5b --- .../tests/promote_resources_to_args.mlir | 3 ++- .../transforms/promote_resources_to_args.cc | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir index dcf21275857..db15aeb9107 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir @@ -297,8 +297,9 @@ func @main(%arg0: tensor>>) -> tensor { // Tests VarHandleOp has users that are not removed. func @main() -> tensor { - // expected-error@+1 {{expects no uses}} + // expected-error@+1 {{expects no uses but used by operations: tf.UnknownOp, tf.VarIsInitializedOp}} %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> %1 = "tf.VarIsInitializedOp"(%0) : (tensor>>) -> tensor + %2 = "tf.UnknownOp"(%0) : (tensor>>) -> tensor return %1 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index e69ac6ab146..a0a521745cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project @@ -210,8 +211,21 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { for (Operation& op : llvm::make_early_inc_range(function.front())) { auto var_handle_op = llvm::dyn_cast(op); if (!var_handle_op) continue; - if (!var_handle_op.use_empty()) - return var_handle_op.emitOpError() << "expects no uses"; + if (!var_handle_op.use_empty()) { + // SmallSet will use a vector when there is only one element and use + // std::set when there are more than one elements. This ensures that + // the operations in the error message are ordered. + llvm::SmallSet unique_operations; + llvm::for_each( + var_handle_op.getOperation()->getUsers(), [&](Operation* user) { + unique_operations.insert(user->getName().getStringRef().str()); + }); + + return var_handle_op.emitOpError( + "expects no uses but used by operations: ") + << llvm::join(unique_operations.begin(), unique_operations.end(), + ", "); + } op.erase(); } From 85b36ef350bf3d4318d1667f8f50692e5b9ee342 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Mar 2020 21:45:43 -0700 Subject: [PATCH 398/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302358447 Change-Id: Ib722ffc23f0f43a358d161ca3a4f3245689e7d4c --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 68bb1dc49f5..75d86f71b78 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 0c487d64172c64d60a93bc98cf5ea07f1a8e95ba Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Sun, 22 Mar 2020 21:59:17 -0700 Subject: [PATCH 399/492] Fixed the delegate memory ownership issue w/ the benchmark model tool. 
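The gist of the fix below: each delegate that is successfully applied to the interpreter is now kept alive for the whole benchmark run by moving its owning smart pointer into a member container, instead of letting it go out of scope at the end of Init(). A rough sketch of the idea; the element type here assumes the interpreter's owning delegate pointer (e.g. tflite::Interpreter::TfLiteDelegatePtr), and the actual member is shown in the diff.

    // Member of BenchmarkTfLiteModel: owns every applied delegate so that it
    // outlives all subsequent invocations of the interpreter.
    std::vector<tflite::Interpreter::TfLiteDelegatePtr> owned_delegates_;

    // In Init(), after the delegate has been applied to the interpreter:
    owned_delegates_.emplace_back(std::move(delegate));  // transfer ownership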
PiperOrigin-RevId: 302360006 Change-Id: I5cb9bd50463448d96076f27a864b939ca30e72f8 --- tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc | 1 + tensorflow/lite/tools/benchmark/benchmark_tflite_model.h | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index cd00a196337..d6965619689 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -646,6 +646,7 @@ TfLiteStatus BenchmarkTfLiteModel::Init() { << " delegate, and the model graph will be " << delegate_status << " executed w/ the delegate."; } + owned_delegates_.emplace_back(std::move(delegate)); } auto interpreter_inputs = interpreter_->inputs(); diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h index 16d5c08ac44..8e9bad2269a 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h @@ -119,6 +119,8 @@ class BenchmarkTfLiteModel : public BenchmarkModel { std::unique_ptr profiling_listener_ = nullptr; std::unique_ptr ruy_profiling_listener_ = nullptr; std::mt19937 random_engine_; + + std::vector owned_delegates_; }; } // namespace benchmark From fbc353501cc48af1db3ada456598cb2254821a0c Mon Sep 17 00:00:00 2001 From: Ashutosh Hathidara Date: Mon, 23 Mar 2020 10:39:57 +0530 Subject: [PATCH 400/492] Added description of effect of regularization on validation and test --- tensorflow/python/keras/engine/training.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 168647d5342..e74a6caf5ac 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -655,7 +655,10 @@ class Model(network.Network, version_utils.ModelVersionSelector): `keras.utils.Sequence` instance. validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. - The model will not be trained on this data. + The model will not be trained on this data. Thus, note the fact + that the validation loss of data provided using `validation_split` + or `validation_data` is not affected by regularization layers like + noise and dropout. `validation_data` will override `validation_split`. `validation_data` could be: - tuple `(x_val, y_val)` of Numpy arrays or tensors @@ -1179,7 +1182,8 @@ class Model(network.Network, version_utils.ModelVersionSelector): directly using `__call__` is recommended for faster execution, e.g., `model(x)`, or `model(x, training=False)` if you have layers such as `tf.keras.layers.BatchNormalization` that behaves differently during - inference. + inference. Also, note the fact that test loss is not affected by + regularization layers like noise and dropout. Arguments: x: Input samples. It could be: From 7116a21f176031da4b818234c70b4616ff484cf7 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Sun, 22 Mar 2020 22:45:40 -0700 Subject: [PATCH 401/492] Make flatbuffer_translate_lib dynamically linked To do this, some statically registered translation functions are moved to a separate C++ file and target. Only the binaries that require these translation functions need to link them statically. This CL also removes the tensorflow/core:lib dependency from the quantize_model target.
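For context on "statically registered translation functions": MLIR translations are wired up through file-scope registration objects whose constructors run at program start, so whatever target contains them must be linked whole (alwayslink) into the final binary; keeping them out of the reusable library is what allows that library to be linked dynamically. A purely illustrative sketch of such a registration follows; the flag name and body are placeholders, not the code being moved.

    #include "llvm/Support/raw_ostream.h"
    #include "mlir/IR/Module.h"
    #include "mlir/Support/LogicalResult.h"
    #include "mlir/Translation.h"

    // File-scope object: its constructor registers the translation under the
    // given flag. Binaries that do not link this file never register it.
    static mlir::TranslateFromMLIRRegistration tflite_flatbuffer_registration(
        "mlir-to-tflite-flatbuffer",
        [](mlir::ModuleOp module, llvm::raw_ostream& output) {
          // Placeholder body; the real hook serializes the module to a
          // FlatBuffer and writes it to `output`.
          return mlir::success();
        });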
PiperOrigin-RevId: 302364991 Change-Id: I89c7898fd320d84d340810c690098cc69a21c471 --- tensorflow/compiler/mlir/lite/BUILD | 46 +- .../compiler/mlir/lite/flatbuffer_export.cc | 1455 ---------------- .../compiler/mlir/lite/flatbuffer_import.cc | 84 +- .../mlir/lite/flatbuffer_translate.cc | 1495 ++++++++++++++++- ...buffer_export.h => flatbuffer_translate.h} | 6 +- ...t_flags.h => flatbuffer_translate_flags.h} | 6 +- .../compiler/mlir/lite/mlir_tflite_runner.cc | 4 +- .../lite/quantization/lite/quantize_model.cc | 2 +- .../mlir/lite/sparsity/sparsify_model.cc | 2 +- .../compiler/mlir/lite/tf_tfl_translate.cc | 4 +- .../mlir/lite/tf_to_tfl_flatbuffer.cc | 2 +- tensorflow/compiler/mlir/tensorflow/BUILD | 3 +- .../mlir/tensorflow/utils/error_util.cc | 2 +- .../mlir/tensorflow/utils/error_util.h | 2 +- 14 files changed, 1527 insertions(+), 1586 deletions(-) delete mode 100644 tensorflow/compiler/mlir/lite/flatbuffer_export.cc rename tensorflow/compiler/mlir/lite/{flatbuffer_export.h => flatbuffer_translate.h} (90%) rename tensorflow/compiler/mlir/lite/{flatbuffer_export_flags.h => flatbuffer_translate_flags.h} (84%) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 446ba89a3f1..03cf9265f3b 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -224,6 +224,7 @@ cc_library( deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/lite/schema:schema_fbs", "@llvm-project//llvm:support", @@ -553,14 +554,14 @@ cc_library( cc_library( name = "flatbuffer_translate_lib", srcs = [ - "flatbuffer_export.cc", "flatbuffer_import.cc", + "flatbuffer_translate.cc", "utils/convert_type.cc", ], hdrs = [ - "flatbuffer_export.h", - "flatbuffer_export_flags.h", "flatbuffer_import.h", + "flatbuffer_translate.h", + "flatbuffer_translate_flags.h", "utils/convert_type.h", ], deps = [ @@ -578,10 +579,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:status", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string_util", @@ -602,32 +601,15 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", ], -) - -cc_library( - name = "flatbuffer_translate_registeration", - srcs = [ - "flatbuffer_translate.cc", - ], - deps = [ - ":flatbuffer_translate_lib", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopOpsTransforms", - "@llvm-project//mlir:MlirTranslateMain", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Translation", - ], alwayslink = 1, ) tf_cc_binary( name = "flatbuffer_translate", deps = [ - ":flatbuffer_translate_registeration", + ":flatbuffer_translate_lib", + "@llvm-project//mlir:LoopOpsTransforms", + "@llvm-project//mlir:MlirTranslateMain", ], ) @@ -665,13 +647,10 @@ filegroup( tf_cc_binary( name = "tf_tfl_translate", - srcs = [ - ":tf_tfl_translate_main", - ], + srcs = [":tf_tfl_translate_main"], deps = [ ":common", ":flatbuffer_translate_lib", - ":flatbuffer_translate_registeration", ":tensorflow_lite", ":tf_tfl_passes", 
":tf_tfl_translate_cl_options", @@ -693,18 +672,15 @@ tf_cc_binary( tf_cc_binary( name = "mlir-tflite-runner", - srcs = [ - "mlir_tflite_runner.cc", - ], + srcs = ["mlir_tflite_runner.cc"], deps = [ ":flatbuffer_translate_lib", - ":flatbuffer_translate_registeration", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc deleted file mode 100644 index 72e9b8c742a..00000000000 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ /dev/null @@ -1,1455 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/ToolOutputFile.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Translation.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" -#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include 
"tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/string_util.h" -#include "tensorflow/lite/tools/versioning/op_version.h" -#include "tensorflow/lite/tools/versioning/runtime_version.h" -#include "tensorflow/lite/version.h" - -using llvm::dyn_cast; -using llvm::formatv; -using llvm::isa; -using llvm::Optional; -using llvm::StringRef; -using llvm::Twine; -using mlir::Dialect; -using mlir::ElementsAttr; -using mlir::FuncOp; -using mlir::MLIRContext; -using mlir::ModuleOp; -using mlir::NoneType; -using mlir::Operation; -using mlir::Region; -using mlir::StringAttr; -using mlir::TensorType; -using mlir::Type; -using mlir::UnknownLoc; -using mlir::Value; -using tensorflow::OpOrArgLocNameMapper; -using tensorflow::OpOrArgNameMapper; -using tensorflow::Status; -using tflite::flex::IsWhitelistedFlexOp; -using xla::StatusOr; - -template -using BufferOffset = flatbuffers::Offset; - -template -using VectorBufferOffset = flatbuffers::Offset>; - -using CustomOptionsOffset = VectorBufferOffset; - -namespace error = tensorflow::error; -namespace tfl = mlir::TFL; - -ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; - -// Use initial buffer size in flatbuffer builder to be same as the initial size -// used by the TOCO export. (It does not explain rationale for this choice.) -constexpr size_t kInitialBufferSize = 10240; - -// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. -// Since tflite doesn't support unsigned for other types, returns error if -// `isSigned` is set to false for other types. -static StatusOr GetTFLiteType(Type type, - bool is_signed = true) { - if (!is_signed && type.isSignlessInteger(8)) { - return tflite::TensorType_UINT8; - } - if (!is_signed) { - return Status(error::INVALID_ARGUMENT, - "'isSigned' can only be set for 8-bits integer type"); - } - switch (type.getKind()) { - case mlir::StandardTypes::F32: - return tflite::TensorType_FLOAT32; - case mlir::StandardTypes::F16: - return tflite::TensorType_FLOAT16; - case mlir::TF::TensorFlowTypes::STRING: - return tflite::TensorType_STRING; - case mlir::TF::TensorFlowTypes::QUINT8: - return tflite::TensorType_UINT8; - case mlir::StandardTypes::Complex: { - auto ftype = type.cast().getElementType(); - if (ftype && ftype.isF32()) { - return tflite::TensorType_COMPLEX64; - } - return Status(error::INVALID_ARGUMENT, "Unsupported type"); - } - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return tflite::TensorType_BOOL; - case 8: - return itype.isUnsigned() ? 
tflite::TensorType_UINT8 - : tflite::TensorType_INT8; - case 16: - return tflite::TensorType_INT16; - case 32: - return tflite::TensorType_INT32; - case 64: - return tflite::TensorType_INT64; - } - } - case mlir::quant::QuantizationTypes::UniformQuantized: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::TF::TensorFlowTypes::RESOURCE: { - // Treat tf.resource values as integer values in flatbuffer. - // TODO(b/146131919): Maybe need to have a detailed design for supporting - // other resource types beyonds hash table resources and resource - // variables. - return tflite::TensorType_INT32; - } - default: - // TFLite export fills FLOAT32 for unknown data types. Returning an error - // for now for safety and this could be revisited when required. - return Status(error::INVALID_ARGUMENT, "Unsupported type"); - } -} - -static bool IsConst(Operation* op) { - return isa(op) || isa(op) || - isa(op) || isa(op); -} - -template -static bool HasValidTFLiteType(Value value, T& error_handler) { - // None type is allowed to represent unspecified operands. - if (value.getType().isa()) return true; - - auto type = value.getType().dyn_cast(); - if (!type) { - if (auto op = value.getDefiningOp()) { - error_handler.emitError() - << '\'' << op << "' should produce value of tensor type instead of " - << value.getType(); - return false; - } - error_handler.emitError("expected tensor type, got ") << value.getType(); - return false; - } - - Type element_type = type.getElementType(); - auto status = GetTFLiteType(element_type); - if (!status.ok()) { - return error_handler.emitError( - formatv("Failed to convert element type '{0}': {1}", - element_type, status.status().error_message())), - false; - } - return true; -} - -// Returns true if the module holds all the invariants expected by the -// Translator class. -// TODO(hinsu): Now that translation is done by making a single pass over the -// MLIR module, consider inlining these validation checks at the place where -// these invariants are assumed instead of checking upfront. -static bool IsValidTFLiteMlirModule(ModuleOp module) { - MLIRContext* context = module.getContext(); - - // Verify that module has a function named main. - FuncOp main_fn = module.lookupSymbol("main"); - if (!main_fn) { - return emitError(UnknownLoc::get(context), - "should have a function named 'main'"), - false; - } - - for (auto fn : module.getOps()) { - if (fn.getBlocks().size() != 1) { - return fn.emitError("should have exactly one basic block"), false; - } - auto& bb = fn.getBlocks().front(); - - for (auto arg : bb.getArguments()) { - if (!HasValidTFLiteType(arg, fn)) - return fn.emitError("invalid TFLite type: ") << arg.getType(), false; - } - - // Verify that all operations except the terminator have exactly one - // result of type supported by TFLite. - for (auto& inst : bb) { - if (inst.isKnownTerminator()) break; - - for (auto result : inst.getResults()) { - if (!HasValidTFLiteType(result, inst)) - return fn.emitError("invalid TFLite type: ") << result.getType(), - false; - } - } - } - - return true; -} - -static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( - ::mlir::Operation* inst) { - // We pass empty string for the original node_def name since Flex runtime - // does not care about this being set correctly on node_def. 
There is no - // "easy" (see b/120948529) way yet to get this from MLIR inst. - auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( - inst, /*name=*/"", /*ignore_unregistered_attrs=*/true); - if (!status_or_node_def.ok()) { - inst->emitOpError( - Twine("failed to obtain TensorFlow nodedef with status: " + - status_or_node_def.status().ToString())); - return {}; - } - return std::move(status_or_node_def.ValueOrDie()); -} - -// Converts a mlir padding StringRef to TfLitePadding. -// Returns llvm::None if conversion fails. -static Optional GetTflitePadding(Operation* inst, - llvm::StringRef padding) { - const tflite::Padding padding_attr = - std::move(llvm::StringSwitch(padding) - .Case("SAME", tflite::Padding_SAME) - .Case("VALID", tflite::Padding_VALID)); - if (padding_attr == tflite::Padding_SAME) { - return kTfLitePaddingSame; - } - if (padding_attr == tflite::Padding_VALID) { - return kTfLitePaddingValid; - } - - return inst->emitOpError() << "Invalid padding attribute: " << padding, - llvm::None; -} - -// Extracts TfLitePoolParams from a TFL custom op. -// Template parameter, TFLOp, should be a TFL custom op containing attributes -// generated from TfLitePoolParams. -// Returns llvm::None if conversion fails. -template -static Optional GetTflitePoolParams(Operation* inst, - TFLOp op) { - TfLitePoolParams pool_params; - pool_params.stride_height = op.stride_h().getSExtValue(); - pool_params.stride_width = op.stride_w().getSExtValue(); - pool_params.filter_height = op.filter_h().getSExtValue(); - pool_params.filter_width = op.filter_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - pool_params.padding = *padding; - pool_params.activation = kTfLiteActNone; - pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; - return pool_params; - } - - return llvm::None; -} - -namespace { - -// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. -class Translator { - public: - // Translates the given MLIR module into TFLite FlatBuffer format and returns - // the serialized output. Returns llvm::None on unsupported, invalid inputs or - // internal error. - static Optional Translate( - ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); - - private: - enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; - explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, - bool emit_select_tf_ops, bool emit_custom_ops, - OpOrArgNameMapper* op_or_arg_name_mapper) - : module_(module), - name_mapper_(*op_or_arg_name_mapper), - builder_(kInitialBufferSize) { - // The first buffer must be empty according to the schema definition. - empty_buffer_ = tflite::CreateBuffer(builder_); - buffers_.push_back(empty_buffer_); - if (emit_builtin_tflite_ops) { - enabled_op_types_.emplace(OpType::kTfliteBuiltin); - } - if (emit_select_tf_ops) { - enabled_op_types_.emplace(OpType::kSelectTf); - } - if (emit_custom_ops) { - enabled_op_types_.emplace(OpType::kCustomOp); - } - tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); - tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); - } - - Optional TranslateInternal(); - - // Returns TFLite buffer populated with constant value if the operation is - // TFLite constant operation. Otherwise, returns an empty buffer. Emits error - // and returns llvm::None on failure. - Optional> BuildBuffer(Operation* inst); - - // Build TFLite tensor from the given type. 
This function is for tfl.lstm - // intermediates, which should have UniformQuantizedType. - Optional> BuildTensorFromType( - mlir::Type type, const std::string& name); - - // Builds TFLite tensor from the given value. `buffer_idx` is index of the - // corresponding buffer. Emits error and returns llvm::None on failure. - Optional> BuildTensor(Value value, - const std::string& name, - unsigned buffer_idx); - - // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove - // these 2 functions here. - BufferOffset BuildIfOperator( - mlir::TF::IfOp op, const std::vector& operands, - const std::vector& results); - BufferOffset BuildWhileOperator( - mlir::TF::WhileOp op, const std::vector& operands, - const std::vector& results); - - // Build while operator where cond & body are regions. - Optional> BuildWhileOperator( - mlir::TFL::WhileOp op, const std::vector& operands, - const std::vector& results); - - // Builds custom operators. - // Templated on a) data type of custom_option to be stored into flatbuffer, - // and b) TFL custom op type. - template - BufferOffset BuildCustomOperator( - const CustomOptionType& custom_option, const std::string& opcode_name, - TFLOp op, const std::vector& operands, - const std::vector& results); - - BufferOffset BuildNumericVerifyOperator( - mlir::TFL::NumericVerifyOp op, const std::vector& operands, - const std::vector& results); - Optional> - BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxUnpooling2DOperator( - Operation* inst, mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results); - - Optional CreateFlexOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - Optional CreateCustomOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - std::unique_ptr CreateFlexBuilderWithNodeAttrs( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - // Returns opcode index for op identified by the op_name, if already - // available. Otherwise, creates a new OperatorCode using the given `builtin` - // operator and associates it with `op_name`. - uint32_t GetOpcodeIndex(const std::string& op_name, - tflite::BuiltinOperator builtin); - - // Builds operator for the given operation with specified operand and result - // tensor indices. Emits an error and returns llvm::None on failure. - Optional> BuildOperator( - Operation* inst, const std::vector& operands, - const std::vector& results, - const std::vector& intermediates); - - // Build a subgraph with a given name out of the region either corresponding - // to a function's body or while op. - Optional> BuildSubGraph( - const std::string& name, Region* region); - - // Builds Metadata with the given `name` and buffer `content`. - BufferOffset BuildMetadata(StringRef name, - StringRef content); - - // Encodes the `tfl.metadata` dictionary attribute of the module to the - // metadata section in the final model. - Optional>> - CreateMetadataVector(); - - // Uses the tf.entry_function attribute (if set) to initialize the op to name - // mapping. 
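  // The attribute has roughly this shape (tensor and argument names here are
  // hypothetical, shown only to illustrate the mapping source):
  //   func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32>
  //     attributes {tf.entry_function = {inputs = "input0", outputs = "output0"}}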
- void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); - - // Determines if the specified operation op's operand at operand_index - // is marked as a stateful operand. - bool IsStatefulOperand(mlir::Operation* op, int operand_index); - - // Returns a unique name for `val`. - std::string UniqueName(mlir::Value val); - - ModuleOp module_; - - tensorflow::OpOrArgNameMapper& name_mapper_; - - flatbuffers::FlatBufferBuilder builder_; - BufferOffset empty_buffer_; - - std::vector> buffers_; - - // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. - absl::flat_hash_map opcode_index_map_; - std::vector> opcodes_; - - // Maps function name to index of the corresponding subgraph in the FlatBuffer - // model. - absl::flat_hash_map subgraph_index_map_; - absl::flat_hash_set enabled_op_types_; - - // Points to TensorFlow and TFLite dialects, respectively. nullptr if the - // dialect is not registered. - const Dialect* tf_dialect_; - const Dialect* tfl_dialect_; - - // The failed ops during legalization. - std::set failed_flex_ops_; - std::set failed_custom_ops_; -}; - -std::string Translator::UniqueName(mlir::Value val) { - return std::string(name_mapper_.GetUniqueName(val)); -} - -Optional> Translator::BuildBuffer( - Operation* inst) { - ElementsAttr attr; - if (auto cst = dyn_cast(inst)) { - // ConstantOp have ElementAttr at this point due to validation of the TFLite - // module. - attr = cst.getValue().cast(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else { - return empty_buffer_; - } - - tensorflow::Tensor tensor; - auto status = tensorflow::ConvertToTensor(attr, &tensor); - if (!status.ok()) { - inst->emitError( - Twine("failed to convert value attribute to tensor with error: " + - status.ToString())); - return llvm::None; - } - - // TensorFlow and TensorFlow Lite use different string encoding formats. - // Convert to TensorFlow Lite format is it's a constant string tensor. 
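  // As a sketch of the consumer side (the tensor variable is an assumption,
  // not part of this code), strings packed this way are read back with the
  // helpers from tensorflow/lite/string_util.h:
  //   int n = tflite::GetStringCount(&tflite_tensor);
  //   for (int i = 0; i < n; ++i) {
  //     tflite::StringRef s = tflite::GetString(&tflite_tensor, i);  // s.str, s.len
  //   }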
- if (tensor.dtype() == tensorflow::DT_STRING) { - ::tflite::DynamicBuffer dynamic_buffer; - auto flat = tensor.flat<::tensorflow::tstring>(); - for (int i = 0; i < flat.size(); ++i) { - const auto& str = flat(i); - dynamic_buffer.AddString(str.c_str(), str.length()); - } - char* tensor_buffer; - int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer); - auto buffer_data = - builder_.CreateVector(reinterpret_cast(tensor_buffer), bytes); - free(tensor_buffer); - return tflite::CreateBuffer(builder_, buffer_data); - } - - absl::string_view tensor_data = tensor.tensor_data(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(tensor_data.data()), tensor_data.size()); - return tflite::CreateBuffer(builder_, buffer_data); -} - -Optional> Translator::BuildTensorFromType( - mlir::Type type, const std::string& name) { - auto tensor_type = type.cast(); - - if (!tensor_type.hasStaticShape()) { - return llvm::None; - } - llvm::ArrayRef shape_ref = tensor_type.getShape(); - std::vector shape(shape_ref.begin(), shape_ref.end()); - - auto element_type = tensor_type.getElementType(); - tflite::TensorType tflite_element_type = - GetTFLiteType(tensor_type.getElementType()).ValueOrDie(); - BufferOffset q_params; - auto qtype = element_type.dyn_cast(); - if (!qtype) { - return llvm::None; - } - q_params = tflite::CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, - builder_.CreateVector({static_cast(qtype.getScale())}), - builder_.CreateVector({qtype.getZeroPoint()})); - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - /*buffer=*/0, builder_.CreateString(name), q_params, - /*is_variable=*/false); -} - -Optional> Translator::BuildTensor( - Value value, const std::string& name, unsigned buffer_idx) { - auto type = value.getType().cast(); - - // TFLite requires tensor shape only for the inputs and constants. - // However, we output all known shapes for better round-tripping - auto check_shape = - [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { - auto is_out_of_range = [](int64_t dim) { - return dim > std::numeric_limits::max(); - }; - - if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) - return mlir::emitError( - value.getLoc(), - "result shape dimensions out of 32 bit int type range"); - - return mlir::success(); - }; - - std::vector shape; - std::vector shape_signature; - if (type.hasStaticShape()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (auto* inst = value.getDefiningOp()) { - if (IsConst(inst)) { - // Const op can have a result of dynamic shaped type (e.g. due to constant - // folding), but we can still derive the shape of a constant tensor for - // its attribute type. - mlir::Attribute tensor_attr = inst->getAttr("value"); - llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } - } else if (type.hasRank()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape.reserve(shape_ref.size()); - for (auto& dim : shape_ref) { - shape.push_back(dim == -1 ? 
1 : dim); - } - shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); - } - - if (auto* inst = value.getDefiningOp()) { - if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); - } else if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); - } - } - - Type element_type = type.getElementType(); - tflite::TensorType tflite_element_type = - GetTFLiteType(type.getElementType()).ValueOrDie(); - - BufferOffset q_params; - if (auto qtype = element_type.dyn_cast()) { - q_params = tflite::CreateQuantizationParameters( - // TODO(fengliuai): min and max values are not stored in the - // quantized type, so both are set to 0. The model couldn't be imported - // to TensorFlow because of this. - builder_, /*min=*/0, /*max=*/0, - builder_.CreateVector({static_cast(qtype.getScale())}), - builder_.CreateVector({qtype.getZeroPoint()})); - } else if (auto qtype = - element_type - .dyn_cast()) { - std::vector scales(qtype.getScales().begin(), - qtype.getScales().end()); - q_params = tflite::CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), - builder_.CreateVector(qtype.getZeroPoints()), - tflite::QuantizationDetails_NONE, /*details=*/0, - qtype.getQuantizedDimension()); - } else { - q_params = tflite::CreateQuantizationParameters(builder_); - } - // Check if the value's uses includes an op and usage at an operand index - // marked as a stateful. If so, set the tensor's is_variable as true - // This is v1 ref variable semantics in the TFLite runtime. - bool is_variable = false; - for (auto& use : value.getUses()) { - is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); - if (is_variable) { - break; - } - } - - if (shape_signature.empty()) { - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable); - } else { - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 
0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable, /*sparsity=*/0, - /*shape_signature=*/builder_.CreateVector(shape_signature)); - } -} - -BufferOffset Translator::BuildIfOperator( - mlir::TF::IfOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); - int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); - int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); - auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, - else_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_IfOptions, - builtin_options); -} - -BufferOffset Translator::BuildWhileOperator( - mlir::TF::WhileOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); - int body_subgraph_index = subgraph_index_map_.at(op.body().str()); - auto builtin_options = tflite::CreateWhileOptions( - builder_, cond_subgraph_index, body_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_WhileOptions, - builtin_options); -} - -Optional> Translator::BuildWhileOperator( - mlir::TFL::WhileOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - auto get_call_index = [&](mlir::Block& b) -> Optional { - if (b.getOperations().size() != 2) return llvm::None; - if (auto call_op = dyn_cast(b.front())) - return subgraph_index_map_.at(call_op.callee().str()); - return llvm::None; - }; - auto body_subgraph_index = get_call_index(op.body().front()); - auto cond_subgraph_index = get_call_index(op.cond().front()); - if (!body_subgraph_index || !cond_subgraph_index) - return op.emitOpError("only single call cond/body while export supported"), - llvm::None; - auto builtin_options = - tflite::CreateWhileOptions(builder_, *cond_subgraph_index, - *body_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_WhileOptions, - builtin_options); -} - -template -BufferOffset Translator::BuildCustomOperator( - const CustomOptionType& custom_option, const std::string& opcode_name, - TFLOp op, const std::vector& operands, - const std::vector& results) { - std::vector custom_option_vector(sizeof(CustomOptionType)); - memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType)); - auto opcode_index = - GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM); - return tflite::CreateOperator( - builder_, opcode_index, builder_.CreateVector(operands), - builder_.CreateVector(results), tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - builder_.CreateVector(custom_option_vector), - tflite::CustomOptionsFormat_FLEXBUFFERS); -} - -BufferOffset Translator::BuildNumericVerifyOperator( - mlir::TFL::NumericVerifyOp op, const std::vector& operands, - const std::vector& results) { - float tolerance = op.tolerance().convertToFloat(); 
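  // BuildCustomOperator memcpy's this raw float (sizeof(float) bytes) straight
  // into the operator's custom_options, rather than encoding it as a
  // flexbuffer map.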
- return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); -} - -Optional> -Translator::BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, const std::vector& results) { - TfLiteTransposeConvParams conv_params; - conv_params.stride_height = op.stride_h().getSExtValue(); - conv_params.stride_width = op.stride_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - conv_params.padding = *padding; - return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxUnpooling2DOperator(Operation* inst, - mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, - results); - } - - return llvm::None; -} - -Optional Translator::CreateFlexOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } - - auto flex_builder = absl::make_unique(); - flex_builder->Vector([&]() { - flex_builder->String(node_def.op()); - flex_builder->String(node_def_str); - }); - flex_builder->Finish(); - return builder_.CreateVector(flex_builder->GetBuffer()); -} - -Optional Translator::CreateCustomOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } - auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); - return builder_.CreateVector(flex_builder->GetBuffer()); -} - -std::unique_ptr -Translator::CreateFlexBuilderWithNodeAttrs( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - auto flex_builder = absl::make_unique(); - size_t map_start = flex_builder->StartMap(); - for (const auto& pair : node_def.attr()) { - const char* key = pair.first.c_str(); - const auto& attr = pair.second; - switch (attr.value_case()) { - case ::tensorflow::AttrValue::kS: - flex_builder->String(key, attr.s()); - break; - case ::tensorflow::AttrValue::kType: { - auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); - if (status_or_tfl_type.ok()) { - flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); - } else { - emitWarning(loc, "ignoring unsupported tensorflow type: ") - << std::to_string(attr.type()); - } - break; - } - case ::tensorflow::AttrValue::kI: - flex_builder->Int(key, attr.i()); - break; - case ::tensorflow::AttrValue::kF: - flex_builder->Float(key, attr.f()); - break; - case ::tensorflow::AttrValue::kB: - flex_builder->Bool(key, attr.b()); - break; - case tensorflow::AttrValue::kList: - if (attr.list().s_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const std::string& v : 
attr.list().s()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else if (attr.list().i_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const int64_t v : attr.list().i()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else if (attr.list().f_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const float v : attr.list().f()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else { - emitWarning(loc, - "ignoring unsupported type in list attribute with key: ") - << key; - } - break; - default: - emitWarning(loc, "ignoring unsupported attribute type with key: ") - << key; - break; - } - } - flex_builder->EndMap(map_start); - flex_builder->Finish(); - return flex_builder; -} - -uint32_t Translator::GetOpcodeIndex(const std::string& op_name, - tflite::BuiltinOperator builtin) { - auto it = opcode_index_map_.insert({op_name, 0}); - - // If the insert succeeded, the opcode has not been created already. Create a - // new operator code and update its index value in the map. - if (it.second) { - it.first->second = opcodes_.size(); - auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM - ? builder_.CreateString(op_name) - : BufferOffset(); - // Use version 0 for builtin op. This is a way to serialize version field to - // flatbuffer (since 0 is non default) and it will be corrected later. - int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; - opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, - custom_code, op_version)); - } - return it.first->second; -} - -Optional> Translator::BuildOperator( - Operation* inst, const std::vector& operands, - const std::vector& results, - const std::vector& intermediates) { - const auto* dialect = inst->getDialect(); - if (!dialect) { - inst->emitOpError("dialect is not registered"); - return llvm::None; - } - - // If TFLite built in op, create operator as a builtin op. - if (dialect == tfl_dialect_) { - // Only if built-in TFLite op emission is enabled, would legalization have - // converted any TF->TFL. 
- if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { - return inst->emitOpError( - "is a TFLite builtin op but builtin emission is not enabled"), - llvm::None; - } - - auto builtin_code = GetBuiltinOpCode(inst); - if (!builtin_code) { - if (auto verify_op = dyn_cast(inst)) { - return BuildNumericVerifyOperator(verify_op, operands, results); - } - if (auto conv_transpose_bias_op = - dyn_cast(inst)) { - return BuildConvolution2DTransposeBiasOperator( - inst, conv_transpose_bias_op, operands, results); - } - if (auto max_pooling_with_arg_max_op = - dyn_cast(inst)) { - return BuildMaxPoolingWithArgMax2DOperator( - inst, max_pooling_with_arg_max_op, operands, results); - } - if (auto max_unpooling_op = dyn_cast(inst)) { - return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, - results); - } - if (auto whileOp = dyn_cast(inst)) { - if (inst->getNumOperands() != inst->getNumResults()) { - inst->emitOpError( - "number of operands and results don't match, only canonical " - "TFL While supported"); - return llvm::None; - } - return BuildWhileOperator(whileOp, operands, results); - } - - inst->emitOpError("is not a supported TFLite op"); - return llvm::None; - } - - std::string op_name = inst->getName().getStringRef().str(); - uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); - auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, - results, intermediates, &builder_); - if (!offset) { - inst->emitOpError("is not a supported TFLite op"); - } - return offset; - } - - if (dialect == tf_dialect_) { - std::string op_name; - if (auto ifOp = dyn_cast(inst)) { - return BuildIfOperator(ifOp, operands, results); - } else if (auto whileOp = dyn_cast(inst)) { - return BuildWhileOperator(whileOp, operands, results); - } - - CustomOptionsOffset custom_options; - - // Ops in TF dialect can either be custom ops or flex ops. - // The reason we go directly from TensorFlow dialect MLIR to tensorflow - // node instead of going to TF table gen'd ops via generated code is that - // we do not want to restrict custom and flex op conversion support to - // only those TF ops that are currently registered in MLIR. The current - // model is of an open op system. - // - // The following algorithm is followed: - // if flex is enabled and the op is whitelisted as flex - // we emit op as flex. - // if custom is enabled - // we emit the op as custom. - auto node_def = GetTensorFlowNodeDef(inst); - if (!node_def) { - return llvm::None; - } - - // Flex op case - // Eventually, the whitelist will go away and we will rely on some TF op - // trait (e.g. No side effect) to determine if it is a supported "Flex" - // op or not. - if (enabled_op_types_.contains(OpType::kSelectTf) && - IsWhitelistedFlexOp(node_def->op())) { - // Construct ops as flex op encoding TensorFlow node definition - // as custom options. - // Flex ops are named with the kFlexOpNamePrefix prefix to the actual - // TF op name. - op_name = std::string(kFlexOpNamePrefix) + node_def->op(); - if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else if (enabled_op_types_.contains(OpType::kCustomOp)) { - // Generic case of custom ops - write using flex buffers since that - // is the only custom options supported by TFLite today. 
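      // Decoding sketch for a consumer of these flexbuffer-encoded options
      // (buffer variables and the attribute key are hypothetical):
      //   auto map = flexbuffers::GetRoot(option_bytes, option_size).AsMap();
      //   int64_t some_attr = map["some_int_attr"].AsInt64();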
- op_name = node_def->op(); - if (auto options = - CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else { - // Create description of operation that could not be converted. - const int kLargeElementsAttr = 16; - std::string op_str; - llvm::raw_string_ostream os(op_str); - inst->getName().print(os); - // Print out attributes except for large elementsattributes (which should - // rarely be the cause why the legalization didn't happen). - if (!inst->getAttrList().getAttrs().empty()) { - os << " {"; - bool first = true; - for (auto& named_attr : inst->getAttrList().getDictionary()) { - os << (!first ? ", " : ""); - first = false; - named_attr.first.print(os); - os << " = "; - if (auto element_attr = named_attr.second.dyn_cast()) { - if (element_attr.getNumElements() <= kLargeElementsAttr) { - element_attr.print(os); - } else { - os << ""; - } - } else { - named_attr.second.print(os); - } - } - os << "}"; - } - - // Insert failed op to `flex_ops` or `custom_ops`. - if (IsWhitelistedFlexOp(node_def->op())) { - failed_flex_ops_.insert(os.str()); - } else { - failed_custom_ops_.insert(os.str()); - } - return inst->emitOpError("is neither a custom op nor a flex op"), - llvm::None; - } - - uint32_t opcode_index = - GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - /*custom_options=*/custom_options, - tflite::CustomOptionsFormat_FLEXBUFFERS, - /*mutating_variable_inputs=*/0); - } - - return inst->emitOpError( - "is not any of a builtin TFLite op, a flex TensorFlow op or a " - "custom TensorFlow op"), - llvm::None; -} - -void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { - auto dict_attr = fn.getAttrOfType("tf.entry_function"); - if (!dict_attr) return; - - llvm::SmallVector input_names; - llvm::SmallVector output_names; - if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { - str.getValue().split(input_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - if (input_names.size() != fn.getNumArguments()) { - fn.emitWarning() << "invalid entry function specification"; - return; - } - for (auto it : llvm::enumerate(fn.getArguments())) { - name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); - } - *has_input_attr = true; - } - - if (auto str = - dict_attr.get("outputs").dyn_cast_or_null()) { - str.getValue().split(output_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - auto term = fn.getBlocks().back().getTerminator(); - if (output_names.size() != term->getNumOperands()) { - fn.emitWarning() << "output names (" << output_names.size() - << ") != terminator operands (" << term->getNumOperands() - << ")"; - return; - } - for (const auto& it : llvm::enumerate(term->getOperands())) { - name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); - } - } -} - -bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { - std::vector operand_indices; - if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; - return absl::c_find(operand_indices, operand_index) != operand_indices.end(); -} - -Optional> Translator::BuildSubGraph( - const std::string& name, Region* region) { - bool has_input_attr = false; - if (auto fn = dyn_cast(region->getParentOp())) { - InitializeNamesFromAttribute(fn, &has_input_attr); - 
} - std::vector> tensors; - llvm::DenseMap tensor_index_map; - - // Builds tensor and buffer for argument or operation result. Returns false - // on failure. - auto build_tensor_and_buffer = [&](Value value, const std::string& name) { - // NoneType represents optional and may be skipped here. - if (value.getType().isa()) { - return true; - } - - tensor_index_map.insert({value, tensors.size()}); - auto tensor_or = BuildTensor(value, name, buffers_.size()); - if (!tensor_or) return false; - tensors.push_back(*tensor_or); - - // TODO(ashwinm): Check if for stateful tensors, if it is also needed to - // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. - // This does not seem to affect runtime behavior for RNN/LSTM, but would be - // good for reducing memory footprint. - if (auto* inst = value.getDefiningOp()) { - auto buffer_or = BuildBuffer(inst); - if (!buffer_or) return false; - buffers_.push_back(*buffer_or); - } else { - buffers_.push_back(empty_buffer_); - } - return true; - }; - - std::vector> operators; - auto& bb = region->front(); - - // Main function's arguments are first passed to `input` op so they don't - // have associated tensor and buffer. Build FlatBuffer tensor and buffer for - // other functions. - for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { - mlir::BlockArgument arg = bb.getArgument(i); - std::string name; - if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg)); - if (name.empty()) name = absl::StrCat("arg", i); - if (!build_tensor_and_buffer(arg, name)) return llvm::None; - } - - bool failed_once = false; - for (auto& inst : bb) { - if (inst.isKnownTerminator()) break; - std::vector intermediates; - // Build intermediate tensors for tfl.lstm and insert these tensors into - // flatbuffer. - if (llvm::isa(inst)) { - std::vector intermediate_names = { - "input_to_input_intermediate", "input_to_forget_intermediate", - "input_to_cell_intermediate", "input_to_output_intermediate", - "effective_hidden_scale_intermediate"}; - for (const std::string& intermediate : intermediate_names) { - auto intermediate_attr = inst.getAttr(intermediate); - if (auto attr = intermediate_attr.dyn_cast_or_null()) { - Type qtype = attr.getValue(); - auto tensor_or = BuildTensorFromType( - qtype, name_mapper_.GetUniqueName(intermediate).str()); - if (!tensor_or.hasValue()) { - continue; - } else { - intermediates.push_back(tensors.size()); - tensors.push_back(tensor_or.getValue()); - } - } - } - } - - for (auto val : inst.getResults()) { - std::string name = UniqueName(val); - if (!build_tensor_and_buffer(val, name)) return llvm::None; - } - - // Skip constant ops as they don't represent a TFLite operator. - if (IsConst(&inst)) continue; - - // Fetch operand and result tensor indices. - std::vector operands; - operands.reserve(inst.getNumOperands()); - for (auto operand : inst.getOperands()) { - if (operand.getType().isa()) - operands.push_back(kTfLiteOptionalTensor); - else - operands.push_back(tensor_index_map.lookup(operand)); - } - std::vector results; - results.reserve(inst.getNumOperands()); - for (auto result : inst.getResults()) { - results.push_back(tensor_index_map.lookup(result)); - } - - if (auto tfl_operator = - BuildOperator(&inst, operands, results, intermediates)) - operators.push_back(*tfl_operator); - else - failed_once = true; - } - - if (failed_once) return llvm::None; - - // Get input and output tensor indices for the subgraph. 
- std::vector inputs, outputs; - for (auto arg : bb.getArguments()) { - inputs.push_back(tensor_index_map[arg]); - } - for (auto result : bb.getTerminator()->getOperands()) { - outputs.push_back(tensor_index_map[result]); - } - - return tflite::CreateSubGraph( - builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), - builder_.CreateVector(outputs), builder_.CreateVector(operators), - /*name=*/builder_.CreateString(name)); -} - -BufferOffset Translator::BuildMetadata(StringRef name, - StringRef content) { - auto buffer_index = buffers_.size(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(content.data()), content.size()); - buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); - return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); -} - -Optional>> -Translator::CreateMetadataVector() { - auto dict_attr = module_.getAttrOfType("tfl.metadata"); - std::vector> metadata; - if (dict_attr) { - for (const auto& named_attr : dict_attr) { - StringRef name = named_attr.first; - mlir::Attribute attr = named_attr.second; - if (auto content = attr.dyn_cast()) { - metadata.push_back(BuildMetadata(name, content.getValue())); - } else { - module_.emitError( - "all values in tfl.metadata's dictionary key-value pairs should be " - "string attributes"); - return llvm::None; - } - } - } - // Runtime version string is generated after we update the op - // versions. Here we put a 16-byte dummy string as a placeholder. We choose - // 16-byte because it's the alignment of buffers in flatbuffer, so it won't - // cause any waste of space if the actual string is shorter than 16 bytes. - metadata.push_back( - BuildMetadata("min_runtime_version", std::string(16, '\0'))); - return builder_.CreateVector(metadata); -} - -bool UpdateEntryFunction(ModuleOp module) { - if (module.lookupSymbol("main") != nullptr) { - // We already have an entry function. - return true; - } - - int entry_func_count = 0; - FuncOp entry_func = nullptr; - for (auto fn : module.getOps()) { - auto attrs = fn.getAttrOfType("tf.entry_function"); - if (attrs && !attrs.empty()) { - entry_func_count++; - entry_func = fn; - } - } - - // We should have one & only have one entry function. - if (entry_func_count != 1) return false; - - // Update the entry func to main. - entry_func.setName("main"); - return true; -} - -Optional Translator::Translate( - ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { - if (!UpdateEntryFunction(module)) return llvm::None; - if (!IsValidTFLiteMlirModule(module)) return llvm::None; - Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - return translator.TranslateInternal(); -} - -Optional Translator::TranslateInternal() { - // A list of named regions in the module with main function being the first in - // the list. The main function is required as the first subgraph in the model - // is entry point for the model. - std::vector> named_regions; - named_regions.reserve(std::distance(module_.begin(), module_.end())); - - int subgraph_idx = 0; - FuncOp main_fn = module_.lookupSymbol("main"); - subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back("main", &main_fn.getBody()); - // Walk over the module collection ops with functions and while ops. 
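  // Every non-main FuncOp becomes its own subgraph; control-flow ops such as
  // tf.If, tf.While and tfl.while later reference these entries through
  // subgraph_index_map_ when their operators are built.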
- module_.walk([&](FuncOp fn) { - if (fn != main_fn) { - subgraph_index_map_[fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back(fn.getName().str(), &fn.getBody()); - } - }); - - // Build subgraph for each of the named regions. - std::vector> subgraphs; - subgraphs.reserve(named_regions.size()); - int first_failed_func = -1; - for (auto it : llvm::enumerate(named_regions)) { - auto subgraph_or = BuildSubGraph(it.value().first, it.value().second); - if (!subgraph_or) { - if (first_failed_func == -1) - // Record the index of the first region that cannot be converted. - // Keep looping through all subgraphs in the module to make sure that - // we collect the list of missing ops from the entire module. - first_failed_func = it.index(); - } else { - subgraphs.push_back(*subgraph_or); - } - } - - if (first_failed_func != -1) { - std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t"); - std::string failed_custom_ops_list = - absl::StrJoin(failed_custom_ops_, "\n\t"); - std::string err; - if (!failed_flex_ops_list.empty()) - err += - "Ops that can be supported by the flex runtime (enabled via setting " - "the -emit-select-tf-ops flag):\n\t" + - failed_flex_ops_list; - if (!failed_custom_ops_list.empty()) - err += - "Ops that need custom implementation (enabled via setting the " - "-emit-custom-ops flag):\n\t" + - failed_custom_ops_list; - - auto& failed_region = named_regions[first_failed_func]; - return failed_region.second->getParentOp()->emitError() - << "failed while converting: '" << failed_region.first - << "': " << err, - llvm::None; - } - - std::string model_description; - if (auto attr = module_.getAttrOfType("tfl.description")) { - model_description = attr.getValue().str(); - } else { - model_description = "MLIR Converted."; - } - - // Build the model and finish the model building process. - auto description = builder_.CreateString(model_description.data()); - VectorBufferOffset metadata_buffer = 0; // Deprecated - auto metadata = CreateMetadataVector(); - if (!metadata) return llvm::None; - - auto model = tflite::CreateModel( - builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), - builder_.CreateVector(subgraphs), description, - builder_.CreateVector(buffers_), metadata_buffer, *metadata); - tflite::FinishModelBuffer(builder_, model); - tflite::UpdateOpVersion(builder_.GetBufferPointer()); - tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); - - // Return serialized string for the built FlatBuffer. - return std::string(reinterpret_cast(builder_.GetBufferPointer()), - builder_.GetSize()); -} - -} // namespace - -// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer -// format. Returns false on success. 
-// -// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting -// the following: -// -// * Quantization -// * Ops with variable tensors -// -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - OpOrArgNameMapper* op_or_arg_name_mapper) { - auto maybe_translated = - Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - if (!maybe_translated) return true; - *serialized_flatbuffer = std::move(*maybe_translated); - return false; -} - -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops) { - OpOrArgLocNameMapper op_or_arg_name_mapper; - return MlirToFlatBufferTranslateFunction( - module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); -} diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 1eec402d35a..4b888764053 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -63,16 +63,20 @@ limitations under the License. #include "mlir/Support/LLVM.h" // TF:llvm-project #include "mlir/Translation.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -96,6 +100,45 @@ using xla::StatusOr; namespace errors = tensorflow::errors; namespace tfl = mlir::TFL; +using llvm::cl::opt; + +// Commandline flag to enable the control of flatbuffer import. +bool use_external_constant; + +// Commandline flag to enable graph pruning. +bool experimental_prune_unreachable_nodes_unconditionally; + +// NOLINTNEXTLINE +static opt use_external_constant_flag( + "use-external-constant", + llvm::cl::desc("Use external constant during flatbuffer import"), + llvm::cl::location(use_external_constant), llvm::cl::init(false)); + +// TODO(b/147111261): After the importer supports generic custom ops, we should +// change the flag to a more lightwise flag, e.g. +// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune +// the operations. 
+// NOLINTNEXTLINE +static opt experimental_prune_unreachable_nodes_unconditionally_flg( + "experimental-prune-unreachable-nodes-unconditionally", + llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."), + llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static opt input_arrays_flag( + "input-arrays", + llvm::cl::desc( + "List of input tensors, if different from the default inputs"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +static opt output_arrays_flag( + "output-arrays", + llvm::cl::desc( + "List of output tensors, if different from the default outputs"), + llvm::cl::init("")); + namespace { bool IsScalar(const TensorT& tensor) { // TODO(b/138222071) We can't distinguish scalars and unranked tensors @@ -1020,3 +1063,42 @@ OwningModuleRef tflite::FlatBufferToMlir( return OwningModuleRef(module); } + +static OwningModuleRef FlatBufferFileToMlirTrans( + llvm::SourceMgr* source_mgr, MLIRContext* context, + bool use_external_constant, + bool experimental_prune_unreachable_nodes_unconditionally) { + const llvm::MemoryBuffer* input = + source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); + std::string error; + auto loc = + mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context); + + // Parses input/output names from command line options. + std::vector inputs; + std::vector outputs; + // Use output parser since we only have tensor names. + if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) { + return emitError(loc, "parsing input array info failed ") + << input_arrays_flag, + nullptr; + } + if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) { + return emitError(loc, "parsing output array info failed ") + << output_arrays_flag, + nullptr; + } + + return tflite::FlatBufferToMlir( + absl::string_view(input->getBufferStart(), input->getBufferSize()), + context, loc, use_external_constant, inputs, outputs, + experimental_prune_unreachable_nodes_unconditionally); +} + +static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( + "tflite-flatbuffer-to-mlir", + [](llvm::SourceMgr& source_mgr, MLIRContext* context) { + return FlatBufferFileToMlirTrans( + &source_mgr, context, use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); + }); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index ee7ac81dce9..e8337d4a79f 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -13,6 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" @@ -31,48 +56,67 @@ limitations under the License. #include "mlir/IR/Value.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "mlir/Translation.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" +#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/tools/versioning/op_version.h" +#include "tensorflow/lite/tools/versioning/runtime_version.h" +#include "tensorflow/lite/version.h" -using llvm::cl::opt; +using llvm::dyn_cast; +using llvm::formatv; +using llvm::isa; +using llvm::Optional; +using llvm::StringRef; +using llvm::Twine; +using mlir::Dialect; +using mlir::ElementsAttr; +using mlir::FuncOp; +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::NoneType; +using mlir::Operation; +using mlir::Region; +using mlir::StringAttr; +using mlir::TensorType; +using mlir::TranslateFromMLIRRegistration; +using mlir::Type; +using mlir::UnknownLoc; +using mlir::Value; +using tensorflow::OpOrArgLocNameMapper; +using tensorflow::OpOrArgNameMapper; +using tensorflow::Status; +using tflite::flex::IsWhitelistedFlexOp; +using xla::StatusOr; -// Commandline flag to enable the control of flatbuffer import. -bool use_external_constant; +template +using BufferOffset = flatbuffers::Offset; -// Commandline flag to enable graph pruning. 
-bool experimental_prune_unreachable_nodes_unconditionally; +template +using VectorBufferOffset = flatbuffers::Offset>; -// NOLINTNEXTLINE -static opt use_external_constant_flag( - "use-external-constant", - llvm::cl::desc("Use external constant during flatbuffer import"), - llvm::cl::location(use_external_constant), llvm::cl::init(false)); +using CustomOptionsOffset = VectorBufferOffset; -// TODO(b/147111261): After the importer supports generic custom ops, we should -// change the flag to a more lightwise flag, e.g. -// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune -// the operations. -// NOLINTNEXTLINE -static opt experimental_prune_unreachable_nodes_unconditionally_flg( - "experimental-prune-unreachable-nodes-unconditionally", - llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."), - llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally), - llvm::cl::init(false)); +namespace error = tensorflow::error; +namespace tfl = mlir::TFL; -// NOLINTNEXTLINE -static opt input_arrays_flag( - "input-arrays", - llvm::cl::desc( - "List of input tensors, if different from the default inputs"), - llvm::cl::init("")); - -// NOLINTNEXTLINE -static opt output_arrays_flag( - "output-arrays", - llvm::cl::desc( - "List of output tensors, if different from the default outputs"), - llvm::cl::init("")); using llvm::cl::opt; // These command line flags enable control of the translation implementation. @@ -113,48 +157,1353 @@ static opt strip_debug_info_flag( "strip-debug-info", llvm::cl::desc("Strip debug info during export"), llvm::cl::location(strip_debug_info), llvm::cl::init(false)); -namespace mlir { -namespace { -static OwningModuleRef FlatBufferFileToMlirTrans( - llvm::SourceMgr* source_mgr, MLIRContext* context, - bool use_external_constant, - bool experimental_prune_unreachable_nodes_unconditionally) { - const llvm::MemoryBuffer* input = - source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); - std::string error; - auto loc = - mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context); +ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; - // Parses input/output names from command line options. - std::vector inputs; - std::vector outputs; - // Use output parser since we only have tensor names. - if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) { - return emitError(loc, "parsing input array info failed ") - << input_arrays_flag, - nullptr; +// Use initial buffer size in flatbuffer builder to be same as the initial size +// used by the TOCO export. (It does not explain rationale for this choice.) +constexpr size_t kInitialBufferSize = 10240; + +// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. +// Since tflite doesn't support unsigned for other types, returns error if +// `isSigned` is set to false for other types. 
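// For instance (assuming an mlir::Builder `b` on this module's context):
//   GetTFLiteType(b.getF32Type())                           -> TensorType_FLOAT32
//   GetTFLiteType(b.getIntegerType(8), /*is_signed=*/false) -> TensorType_UINT8
//   GetTFLiteType(b.getIntegerType(1))                       -> TensorType_BOOL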
+static StatusOr GetTFLiteType(Type type, + bool is_signed = true) { + if (!is_signed && type.isSignlessInteger(8)) { + return tflite::TensorType_UINT8; } - if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) { - return emitError(loc, "parsing output array info failed ") - << output_arrays_flag, - nullptr; + if (!is_signed) { + return Status(error::INVALID_ARGUMENT, + "'isSigned' can only be set for 8-bits integer type"); + } + switch (type.getKind()) { + case mlir::StandardTypes::F32: + return tflite::TensorType_FLOAT32; + case mlir::StandardTypes::F16: + return tflite::TensorType_FLOAT16; + case mlir::TF::TensorFlowTypes::STRING: + return tflite::TensorType_STRING; + case mlir::TF::TensorFlowTypes::QUINT8: + return tflite::TensorType_UINT8; + case mlir::StandardTypes::Complex: { + auto ftype = type.cast().getElementType(); + if (ftype && ftype.isF32()) { + return tflite::TensorType_COMPLEX64; + } + return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } + case mlir::StandardTypes::Integer: { + const auto& itype = type.cast(); + switch (itype.getWidth()) { + case 1: + return tflite::TensorType_BOOL; + case 8: + return itype.isUnsigned() ? tflite::TensorType_UINT8 + : tflite::TensorType_INT8; + case 16: + return tflite::TensorType_INT16; + case 32: + return tflite::TensorType_INT32; + case 64: + return tflite::TensorType_INT64; + } + } + case mlir::quant::QuantizationTypes::UniformQuantized: { + auto qtype = type.cast(); + return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + } + case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { + auto qtype = type.cast(); + return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + } + case mlir::TF::TensorFlowTypes::RESOURCE: { + // Treat tf.resource values as integer values in flatbuffer. + // TODO(b/146131919): Maybe need to have a detailed design for supporting + // other resource types beyonds hash table resources and resource + // variables. + return tflite::TensorType_INT32; + } + default: + // TFLite export fills FLOAT32 for unknown data types. Returning an error + // for now for safety and this could be revisited when required. + return Status(error::INVALID_ARGUMENT, "Unsupported type"); } - return tflite::FlatBufferToMlir( - absl::string_view(input->getBufferStart(), input->getBufferSize()), - context, loc, use_external_constant, inputs, outputs, - experimental_prune_unreachable_nodes_unconditionally); } -static LogicalResult MlirToFlatBufferFileTranslateFunction( +static bool IsConst(Operation* op) { + return isa(op) || isa(op) || + isa(op) || isa(op); +} + +template +static bool HasValidTFLiteType(Value value, T& error_handler) { + // None type is allowed to represent unspecified operands. + if (value.getType().isa()) return true; + + auto type = value.getType().dyn_cast(); + if (!type) { + if (auto op = value.getDefiningOp()) { + error_handler.emitError() + << '\'' << op << "' should produce value of tensor type instead of " + << value.getType(); + return false; + } + error_handler.emitError("expected tensor type, got ") << value.getType(); + return false; + } + + Type element_type = type.getElementType(); + auto status = GetTFLiteType(element_type); + if (!status.ok()) { + return error_handler.emitError( + formatv("Failed to convert element type '{0}': {1}", + element_type, status.status().error_message())), + false; + } + return true; +} + +// Returns true if the module holds all the invariants expected by the +// Translator class. 
+// TODO(hinsu): Now that translation is done by making a single pass over the +// MLIR module, consider inlining these validation checks at the place where +// these invariants are assumed instead of checking upfront. +static bool IsValidTFLiteMlirModule(ModuleOp module) { + MLIRContext* context = module.getContext(); + + // Verify that module has a function named main. + FuncOp main_fn = module.lookupSymbol("main"); + if (!main_fn) { + return emitError(UnknownLoc::get(context), + "should have a function named 'main'"), + false; + } + + for (auto fn : module.getOps()) { + if (fn.getBlocks().size() != 1) { + return fn.emitError("should have exactly one basic block"), false; + } + auto& bb = fn.getBlocks().front(); + + for (auto arg : bb.getArguments()) { + if (!HasValidTFLiteType(arg, fn)) + return fn.emitError("invalid TFLite type: ") << arg.getType(), false; + } + + // Verify that all operations except the terminator have exactly one + // result of type supported by TFLite. + for (auto& inst : bb) { + if (inst.isKnownTerminator()) break; + + for (auto result : inst.getResults()) { + if (!HasValidTFLiteType(result, inst)) + return fn.emitError("invalid TFLite type: ") << result.getType(), + false; + } + } + } + + return true; +} + +static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( + ::mlir::Operation* inst) { + // We pass empty string for the original node_def name since Flex runtime + // does not care about this being set correctly on node_def. There is no + // "easy" (see b/120948529) way yet to get this from MLIR inst. + auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( + inst, /*name=*/"", /*ignore_unregistered_attrs=*/true); + if (!status_or_node_def.ok()) { + inst->emitOpError( + Twine("failed to obtain TensorFlow nodedef with status: " + + status_or_node_def.status().ToString())); + return {}; + } + return std::move(status_or_node_def.ValueOrDie()); +} + +// Converts a mlir padding StringRef to TfLitePadding. +// Returns llvm::None if conversion fails. +static Optional GetTflitePadding(Operation* inst, + llvm::StringRef padding) { + const tflite::Padding padding_attr = + std::move(llvm::StringSwitch(padding) + .Case("SAME", tflite::Padding_SAME) + .Case("VALID", tflite::Padding_VALID)); + if (padding_attr == tflite::Padding_SAME) { + return kTfLitePaddingSame; + } + if (padding_attr == tflite::Padding_VALID) { + return kTfLitePaddingValid; + } + + return inst->emitOpError() << "Invalid padding attribute: " << padding, + llvm::None; +} + +// Extracts TfLitePoolParams from a TFL custom op. +// Template parameter, TFLOp, should be a TFL custom op containing attributes +// generated from TfLitePoolParams. +// Returns llvm::None if conversion fails. +template +static Optional GetTflitePoolParams(Operation* inst, + TFLOp op) { + TfLitePoolParams pool_params; + pool_params.stride_height = op.stride_h().getSExtValue(); + pool_params.stride_width = op.stride_w().getSExtValue(); + pool_params.filter_height = op.filter_h().getSExtValue(); + pool_params.filter_width = op.filter_w().getSExtValue(); + const auto padding = GetTflitePadding(inst, op.padding()); + if (padding) { + pool_params.padding = *padding; + pool_params.activation = kTfLiteActNone; + pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; + return pool_params; + } + + return llvm::None; +} + +namespace { + +// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. 
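// A typical call looks like the following sketch (`module` and the flag
// values are assumptions for illustration):
//   tensorflow::OpOrArgLocNameMapper mapper;
//   auto serialized = Translator::Translate(module,
//                                            /*emit_builtin_tflite_ops=*/true,
//                                            /*emit_select_tf_ops=*/false,
//                                            /*emit_custom_ops=*/false, &mapper);
//   // `serialized` is llvm::None on failure, the flatbuffer bytes otherwise.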
+class Translator { + public: + // Translates the given MLIR module into TFLite FlatBuffer format and returns + // the serialized output. Returns llvm::None on unsupported, invalid inputs or + // internal error. + static Optional Translate( + ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); + + private: + enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; + explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, + bool emit_select_tf_ops, bool emit_custom_ops, + OpOrArgNameMapper* op_or_arg_name_mapper) + : module_(module), + name_mapper_(*op_or_arg_name_mapper), + builder_(kInitialBufferSize) { + // The first buffer must be empty according to the schema definition. + empty_buffer_ = tflite::CreateBuffer(builder_); + buffers_.push_back(empty_buffer_); + if (emit_builtin_tflite_ops) { + enabled_op_types_.emplace(OpType::kTfliteBuiltin); + } + if (emit_select_tf_ops) { + enabled_op_types_.emplace(OpType::kSelectTf); + } + if (emit_custom_ops) { + enabled_op_types_.emplace(OpType::kCustomOp); + } + tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); + tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); + } + + Optional TranslateInternal(); + + // Returns TFLite buffer populated with constant value if the operation is + // TFLite constant operation. Otherwise, returns an empty buffer. Emits error + // and returns llvm::None on failure. + Optional> BuildBuffer(Operation* inst); + + // Build TFLite tensor from the given type. This function is for tfl.lstm + // intermediates, which should have UniformQuantizedType. + Optional> BuildTensorFromType( + mlir::Type type, const std::string& name); + + // Builds TFLite tensor from the given value. `buffer_idx` is index of the + // corresponding buffer. Emits error and returns llvm::None on failure. + Optional> BuildTensor(Value value, + const std::string& name, + unsigned buffer_idx); + + // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove + // these 2 functions here. + BufferOffset BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results); + BufferOffset BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results); + + // Build while operator where cond & body are regions. + Optional> BuildWhileOperator( + mlir::TFL::WhileOp op, const std::vector& operands, + const std::vector& results); + + // Builds custom operators. + // Templated on a) data type of custom_option to be stored into flatbuffer, + // and b) TFL custom op type. 
+ template + BufferOffset BuildCustomOperator( + const CustomOptionType& custom_option, const std::string& opcode_name, + TFLOp op, const std::vector& operands, + const std::vector& results); + + BufferOffset BuildNumericVerifyOperator( + mlir::TFL::NumericVerifyOp op, const std::vector& operands, + const std::vector& results); + Optional> + BuildConvolution2DTransposeBiasOperator( + Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, + const std::vector& operands, + const std::vector& results); + Optional> BuildMaxPoolingWithArgMax2DOperator( + Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, + const std::vector& operands, + const std::vector& results); + Optional> BuildMaxUnpooling2DOperator( + Operation* inst, mlir::TFL::MaxUnpooling2DOp op, + const std::vector& operands, + const std::vector& results); + + Optional CreateFlexOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + Optional CreateCustomOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + std::unique_ptr CreateFlexBuilderWithNodeAttrs( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + // Returns opcode index for op identified by the op_name, if already + // available. Otherwise, creates a new OperatorCode using the given `builtin` + // operator and associates it with `op_name`. + uint32_t GetOpcodeIndex(const std::string& op_name, + tflite::BuiltinOperator builtin); + + // Builds operator for the given operation with specified operand and result + // tensor indices. Emits an error and returns llvm::None on failure. + Optional> BuildOperator( + Operation* inst, const std::vector& operands, + const std::vector& results, + const std::vector& intermediates); + + // Build a subgraph with a given name out of the region either corresponding + // to a function's body or while op. + Optional> BuildSubGraph( + const std::string& name, Region* region); + + // Builds Metadata with the given `name` and buffer `content`. + BufferOffset BuildMetadata(StringRef name, + StringRef content); + + // Encodes the `tfl.metadata` dictionary attribute of the module to the + // metadata section in the final model. + Optional>> + CreateMetadataVector(); + + // Uses the tf.entry_function attribute (if set) to initialize the op to name + // mapping. + void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); + + // Determines if the specified operation op's operand at operand_index + // is marked as a stateful operand. + bool IsStatefulOperand(mlir::Operation* op, int operand_index); + + // Returns a unique name for `val`. + std::string UniqueName(mlir::Value val); + + ModuleOp module_; + + tensorflow::OpOrArgNameMapper& name_mapper_; + + flatbuffers::FlatBufferBuilder builder_; + BufferOffset empty_buffer_; + + std::vector> buffers_; + + // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. + absl::flat_hash_map opcode_index_map_; + std::vector> opcodes_; + + // Maps function name to index of the corresponding subgraph in the FlatBuffer + // model. + absl::flat_hash_map subgraph_index_map_; + absl::flat_hash_set enabled_op_types_; + + // Points to TensorFlow and TFLite dialects, respectively. nullptr if the + // dialect is not registered. + const Dialect* tf_dialect_; + const Dialect* tfl_dialect_; + + // The failed ops during legalization. 
+ std::set failed_flex_ops_; + std::set failed_custom_ops_; +}; + +std::string Translator::UniqueName(mlir::Value val) { + return std::string(name_mapper_.GetUniqueName(val)); +} + +Optional> Translator::BuildBuffer( + Operation* inst) { + ElementsAttr attr; + if (auto cst = dyn_cast(inst)) { + // ConstantOp have ElementAttr at this point due to validation of the TFLite + // module. + attr = cst.getValue().cast(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else { + return empty_buffer_; + } + + tensorflow::Tensor tensor; + auto status = tensorflow::ConvertToTensor(attr, &tensor); + if (!status.ok()) { + inst->emitError( + Twine("failed to convert value attribute to tensor with error: " + + status.ToString())); + return llvm::None; + } + + // TensorFlow and TensorFlow Lite use different string encoding formats. + // Convert to TensorFlow Lite format is it's a constant string tensor. + if (tensor.dtype() == tensorflow::DT_STRING) { + ::tflite::DynamicBuffer dynamic_buffer; + auto flat = tensor.flat<::tensorflow::tstring>(); + for (int i = 0; i < flat.size(); ++i) { + const auto& str = flat(i); + dynamic_buffer.AddString(str.c_str(), str.length()); + } + char* tensor_buffer; + int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer); + auto buffer_data = + builder_.CreateVector(reinterpret_cast(tensor_buffer), bytes); + free(tensor_buffer); + return tflite::CreateBuffer(builder_, buffer_data); + } + + absl::string_view tensor_data = tensor.tensor_data(); + auto buffer_data = builder_.CreateVector( + reinterpret_cast(tensor_data.data()), tensor_data.size()); + return tflite::CreateBuffer(builder_, buffer_data); +} + +Optional> Translator::BuildTensorFromType( + mlir::Type type, const std::string& name) { + auto tensor_type = type.cast(); + + if (!tensor_type.hasStaticShape()) { + return llvm::None; + } + llvm::ArrayRef shape_ref = tensor_type.getShape(); + std::vector shape(shape_ref.begin(), shape_ref.end()); + + auto element_type = tensor_type.getElementType(); + tflite::TensorType tflite_element_type = + GetTFLiteType(tensor_type.getElementType()).ValueOrDie(); + BufferOffset q_params; + auto qtype = element_type.dyn_cast(); + if (!qtype) { + return llvm::None; + } + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector({static_cast(qtype.getScale())}), + builder_.CreateVector({qtype.getZeroPoint()})); + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + /*buffer=*/0, builder_.CreateString(name), q_params, + /*is_variable=*/false); +} + +Optional> Translator::BuildTensor( + Value value, const std::string& name, unsigned buffer_idx) { + auto type = value.getType().cast(); + + // TFLite requires tensor shape only for the inputs and constants. 
+ // However, we output all known shapes for better round-tripping + auto check_shape = + [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { + auto is_out_of_range = [](int64_t dim) { + return dim > std::numeric_limits::max(); + }; + + if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) + return mlir::emitError( + value.getLoc(), + "result shape dimensions out of 32 bit int type range"); + + return mlir::success(); + }; + + std::vector shape; + std::vector shape_signature; + if (type.hasStaticShape()) { + llvm::ArrayRef shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape = std::vector(shape_ref.begin(), shape_ref.end()); + } else if (auto* inst = value.getDefiningOp()) { + if (IsConst(inst)) { + // Const op can have a result of dynamic shaped type (e.g. due to constant + // folding), but we can still derive the shape of a constant tensor for + // its attribute type. + mlir::Attribute tensor_attr = inst->getAttr("value"); + llvm::ArrayRef shape_ref = + tensor_attr.getType().cast().getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape = std::vector(shape_ref.begin(), shape_ref.end()); + } + } else if (type.hasRank()) { + llvm::ArrayRef shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape.reserve(shape_ref.size()); + for (auto& dim : shape_ref) { + shape.push_back(dim == -1 ? 1 : dim); + } + shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); + } + + if (auto* inst = value.getDefiningOp()) { + if (auto cst = dyn_cast(inst)) { + // CreateSparsityParameters(cst.s_param()); + } else if (auto cst = dyn_cast(inst)) { + // CreateSparsityParameters(cst.s_param()); + } + } + + Type element_type = type.getElementType(); + tflite::TensorType tflite_element_type = + GetTFLiteType(type.getElementType()).ValueOrDie(); + + BufferOffset q_params; + if (auto qtype = element_type.dyn_cast()) { + q_params = tflite::CreateQuantizationParameters( + // TODO(fengliuai): min and max values are not stored in the + // quantized type, so both are set to 0. The model couldn't be imported + // to TensorFlow because of this. + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector({static_cast(qtype.getScale())}), + builder_.CreateVector({qtype.getZeroPoint()})); + } else if (auto qtype = + element_type + .dyn_cast()) { + std::vector scales(qtype.getScales().begin(), + qtype.getScales().end()); + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), + builder_.CreateVector(qtype.getZeroPoints()), + tflite::QuantizationDetails_NONE, /*details=*/0, + qtype.getQuantizedDimension()); + } else { + q_params = tflite::CreateQuantizationParameters(builder_); + } + // Check if the value's uses includes an op and usage at an operand index + // marked as a stateful. If so, set the tensor's is_variable as true + // This is v1 ref variable semantics in the TFLite runtime. + bool is_variable = false; + for (auto& use : value.getUses()) { + is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); + if (is_variable) { + break; + } + } + + if (shape_signature.empty()) { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable); + } else { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 
0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable, /*sparsity=*/0, + /*shape_signature=*/builder_.CreateVector(shape_signature)); + } +} + +BufferOffset Translator::BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); + int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); + int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); + auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, + else_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_IfOptions, + builtin_options); +} + +BufferOffset Translator::BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); + int body_subgraph_index = subgraph_index_map_.at(op.body().str()); + auto builtin_options = tflite::CreateWhileOptions( + builder_, cond_subgraph_index, body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); +} + +Optional> Translator::BuildWhileOperator( + mlir::TFL::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + auto get_call_index = [&](mlir::Block& b) -> Optional { + if (b.getOperations().size() != 2) return llvm::None; + if (auto call_op = dyn_cast(b.front())) + return subgraph_index_map_.at(call_op.callee().str()); + return llvm::None; + }; + auto body_subgraph_index = get_call_index(op.body().front()); + auto cond_subgraph_index = get_call_index(op.cond().front()); + if (!body_subgraph_index || !cond_subgraph_index) + return op.emitOpError("only single call cond/body while export supported"), + llvm::None; + auto builtin_options = + tflite::CreateWhileOptions(builder_, *cond_subgraph_index, + *body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); +} + +template +BufferOffset Translator::BuildCustomOperator( + const CustomOptionType& custom_option, const std::string& opcode_name, + TFLOp op, const std::vector& operands, + const std::vector& results) { + std::vector custom_option_vector(sizeof(CustomOptionType)); + memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType)); + auto opcode_index = + GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM); + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + builder_.CreateVector(custom_option_vector), + tflite::CustomOptionsFormat_FLEXBUFFERS); +} + +BufferOffset Translator::BuildNumericVerifyOperator( + mlir::TFL::NumericVerifyOp op, const std::vector& operands, + const std::vector& results) { + float tolerance = op.tolerance().convertToFloat(); 
+ return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); +} + +Optional> +Translator::BuildConvolution2DTransposeBiasOperator( + Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, + const std::vector& operands, const std::vector& results) { + TfLiteTransposeConvParams conv_params; + conv_params.stride_height = op.stride_h().getSExtValue(); + conv_params.stride_width = op.stride_w().getSExtValue(); + const auto padding = GetTflitePadding(inst, op.padding()); + if (padding) { + conv_params.padding = *padding; + return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, + operands, results); + } + + return llvm::None; +} + +Optional> +Translator::BuildMaxPoolingWithArgMax2DOperator( + Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, + const std::vector& operands, const std::vector& results) { + const auto pool_params = GetTflitePoolParams(inst, op); + if (pool_params) { + return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, + operands, results); + } + + return llvm::None; +} + +Optional> +Translator::BuildMaxUnpooling2DOperator(Operation* inst, + mlir::TFL::MaxUnpooling2DOp op, + const std::vector& operands, + const std::vector& results) { + const auto pool_params = GetTflitePoolParams(inst, op); + if (pool_params) { + return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, + results); + } + + return llvm::None; +} + +Optional Translator::CreateFlexOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + std::string node_def_str; + if (!node_def.SerializeToString(&node_def_str)) { + return emitError(loc, "failed to serialize tensorflow node_def"), + llvm::None; + } + + auto flex_builder = absl::make_unique(); + flex_builder->Vector([&]() { + flex_builder->String(node_def.op()); + flex_builder->String(node_def_str); + }); + flex_builder->Finish(); + return builder_.CreateVector(flex_builder->GetBuffer()); +} + +Optional Translator::CreateCustomOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + std::string node_def_str; + if (!node_def.SerializeToString(&node_def_str)) { + return emitError(loc, "failed to serialize tensorflow node_def"), + llvm::None; + } + auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); + return builder_.CreateVector(flex_builder->GetBuffer()); +} + +std::unique_ptr +Translator::CreateFlexBuilderWithNodeAttrs( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + auto flex_builder = absl::make_unique(); + size_t map_start = flex_builder->StartMap(); + for (const auto& pair : node_def.attr()) { + const char* key = pair.first.c_str(); + const auto& attr = pair.second; + switch (attr.value_case()) { + case ::tensorflow::AttrValue::kS: + flex_builder->String(key, attr.s()); + break; + case ::tensorflow::AttrValue::kType: { + auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); + if (status_or_tfl_type.ok()) { + flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); + } else { + emitWarning(loc, "ignoring unsupported tensorflow type: ") + << std::to_string(attr.type()); + } + break; + } + case ::tensorflow::AttrValue::kI: + flex_builder->Int(key, attr.i()); + break; + case ::tensorflow::AttrValue::kF: + flex_builder->Float(key, attr.f()); + break; + case ::tensorflow::AttrValue::kB: + flex_builder->Bool(key, attr.b()); + break; + case tensorflow::AttrValue::kList: + if (attr.list().s_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const std::string& v : 
attr.list().s()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else if (attr.list().i_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const int64_t v : attr.list().i()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else if (attr.list().f_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const float v : attr.list().f()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else { + emitWarning(loc, + "ignoring unsupported type in list attribute with key: ") + << key; + } + break; + default: + emitWarning(loc, "ignoring unsupported attribute type with key: ") + << key; + break; + } + } + flex_builder->EndMap(map_start); + flex_builder->Finish(); + return flex_builder; +} + +uint32_t Translator::GetOpcodeIndex(const std::string& op_name, + tflite::BuiltinOperator builtin) { + auto it = opcode_index_map_.insert({op_name, 0}); + + // If the insert succeeded, the opcode has not been created already. Create a + // new operator code and update its index value in the map. + if (it.second) { + it.first->second = opcodes_.size(); + auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM + ? builder_.CreateString(op_name) + : BufferOffset(); + // Use version 0 for builtin op. This is a way to serialize version field to + // flatbuffer (since 0 is non default) and it will be corrected later. + int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; + opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, + custom_code, op_version)); + } + return it.first->second; +} + +Optional> Translator::BuildOperator( + Operation* inst, const std::vector& operands, + const std::vector& results, + const std::vector& intermediates) { + const auto* dialect = inst->getDialect(); + if (!dialect) { + inst->emitOpError("dialect is not registered"); + return llvm::None; + } + + // If TFLite built in op, create operator as a builtin op. + if (dialect == tfl_dialect_) { + // Only if built-in TFLite op emission is enabled, would legalization have + // converted any TF->TFL. 
+ if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { + return inst->emitOpError( + "is a TFLite builtin op but builtin emission is not enabled"), + llvm::None; + } + + auto builtin_code = GetBuiltinOpCode(inst); + if (!builtin_code) { + if (auto verify_op = dyn_cast(inst)) { + return BuildNumericVerifyOperator(verify_op, operands, results); + } + if (auto conv_transpose_bias_op = + dyn_cast(inst)) { + return BuildConvolution2DTransposeBiasOperator( + inst, conv_transpose_bias_op, operands, results); + } + if (auto max_pooling_with_arg_max_op = + dyn_cast(inst)) { + return BuildMaxPoolingWithArgMax2DOperator( + inst, max_pooling_with_arg_max_op, operands, results); + } + if (auto max_unpooling_op = dyn_cast(inst)) { + return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, + results); + } + if (auto whileOp = dyn_cast(inst)) { + if (inst->getNumOperands() != inst->getNumResults()) { + inst->emitOpError( + "number of operands and results don't match, only canonical " + "TFL While supported"); + return llvm::None; + } + return BuildWhileOperator(whileOp, operands, results); + } + + inst->emitOpError("is not a supported TFLite op"); + return llvm::None; + } + + std::string op_name = inst->getName().getStringRef().str(); + uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); + auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, + results, intermediates, &builder_); + if (!offset) { + inst->emitOpError("is not a supported TFLite op"); + } + return offset; + } + + if (dialect == tf_dialect_) { + std::string op_name; + if (auto ifOp = dyn_cast(inst)) { + return BuildIfOperator(ifOp, operands, results); + } else if (auto whileOp = dyn_cast(inst)) { + return BuildWhileOperator(whileOp, operands, results); + } + + CustomOptionsOffset custom_options; + + // Ops in TF dialect can either be custom ops or flex ops. + // The reason we go directly from TensorFlow dialect MLIR to tensorflow + // node instead of going to TF table gen'd ops via generated code is that + // we do not want to restrict custom and flex op conversion support to + // only those TF ops that are currently registered in MLIR. The current + // model is of an open op system. + // + // The following algorithm is followed: + // if flex is enabled and the op is whitelisted as flex + // we emit op as flex. + // if custom is enabled + // we emit the op as custom. + auto node_def = GetTensorFlowNodeDef(inst); + if (!node_def) { + return llvm::None; + } + + // Flex op case + // Eventually, the whitelist will go away and we will rely on some TF op + // trait (e.g. No side effect) to determine if it is a supported "Flex" + // op or not. + if (enabled_op_types_.contains(OpType::kSelectTf) && + IsWhitelistedFlexOp(node_def->op())) { + // Construct ops as flex op encoding TensorFlow node definition + // as custom options. + // Flex ops are named with the kFlexOpNamePrefix prefix to the actual + // TF op name. + op_name = std::string(kFlexOpNamePrefix) + node_def->op(); + if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { + return llvm::None; + } + } else if (enabled_op_types_.contains(OpType::kCustomOp)) { + // Generic case of custom ops - write using flex buffers since that + // is the only custom options supported by TFLite today. 
+ op_name = node_def->op(); + if (auto options = + CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { + return llvm::None; + } + } else { + // Create description of operation that could not be converted. + const int kLargeElementsAttr = 16; + std::string op_str; + llvm::raw_string_ostream os(op_str); + inst->getName().print(os); + // Print out attributes except for large elementsattributes (which should + // rarely be the cause why the legalization didn't happen). + if (!inst->getAttrList().getAttrs().empty()) { + os << " {"; + bool first = true; + for (auto& named_attr : inst->getAttrList().getDictionary()) { + os << (!first ? ", " : ""); + first = false; + named_attr.first.print(os); + os << " = "; + if (auto element_attr = named_attr.second.dyn_cast()) { + if (element_attr.getNumElements() <= kLargeElementsAttr) { + element_attr.print(os); + } else { + os << ""; + } + } else { + named_attr.second.print(os); + } + } + os << "}"; + } + + // Insert failed op to `flex_ops` or `custom_ops`. + if (IsWhitelistedFlexOp(node_def->op())) { + failed_flex_ops_.insert(os.str()); + } else { + failed_custom_ops_.insert(os.str()); + } + return inst->emitOpError("is neither a custom op nor a flex op"), + llvm::None; + } + + uint32_t opcode_index = + GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + /*custom_options=*/custom_options, + tflite::CustomOptionsFormat_FLEXBUFFERS, + /*mutating_variable_inputs=*/0); + } + + return inst->emitOpError( + "is not any of a builtin TFLite op, a flex TensorFlow op or a " + "custom TensorFlow op"), + llvm::None; +} + +void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { + auto dict_attr = fn.getAttrOfType("tf.entry_function"); + if (!dict_attr) return; + + llvm::SmallVector input_names; + llvm::SmallVector output_names; + if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { + str.getValue().split(input_names, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + if (input_names.size() != fn.getNumArguments()) { + fn.emitWarning() << "invalid entry function specification"; + return; + } + for (auto it : llvm::enumerate(fn.getArguments())) { + name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); + } + *has_input_attr = true; + } + + if (auto str = + dict_attr.get("outputs").dyn_cast_or_null()) { + str.getValue().split(output_names, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + auto term = fn.getBlocks().back().getTerminator(); + if (output_names.size() != term->getNumOperands()) { + fn.emitWarning() << "output names (" << output_names.size() + << ") != terminator operands (" << term->getNumOperands() + << ")"; + return; + } + for (const auto& it : llvm::enumerate(term->getOperands())) { + name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); + } + } +} + +bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { + std::vector operand_indices; + if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; + return absl::c_find(operand_indices, operand_index) != operand_indices.end(); +} + +Optional> Translator::BuildSubGraph( + const std::string& name, Region* region) { + bool has_input_attr = false; + if (auto fn = dyn_cast(region->getParentOp())) { + InitializeNamesFromAttribute(fn, &has_input_attr); + 
} + std::vector> tensors; + llvm::DenseMap tensor_index_map; + + // Builds tensor and buffer for argument or operation result. Returns false + // on failure. + auto build_tensor_and_buffer = [&](Value value, const std::string& name) { + // NoneType represents optional and may be skipped here. + if (value.getType().isa()) { + return true; + } + + tensor_index_map.insert({value, tensors.size()}); + auto tensor_or = BuildTensor(value, name, buffers_.size()); + if (!tensor_or) return false; + tensors.push_back(*tensor_or); + + // TODO(ashwinm): Check if for stateful tensors, if it is also needed to + // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. + // This does not seem to affect runtime behavior for RNN/LSTM, but would be + // good for reducing memory footprint. + if (auto* inst = value.getDefiningOp()) { + auto buffer_or = BuildBuffer(inst); + if (!buffer_or) return false; + buffers_.push_back(*buffer_or); + } else { + buffers_.push_back(empty_buffer_); + } + return true; + }; + + std::vector> operators; + auto& bb = region->front(); + + // Main function's arguments are first passed to `input` op so they don't + // have associated tensor and buffer. Build FlatBuffer tensor and buffer for + // other functions. + for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { + mlir::BlockArgument arg = bb.getArgument(i); + std::string name; + if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg)); + if (name.empty()) name = absl::StrCat("arg", i); + if (!build_tensor_and_buffer(arg, name)) return llvm::None; + } + + bool failed_once = false; + for (auto& inst : bb) { + if (inst.isKnownTerminator()) break; + std::vector intermediates; + // Build intermediate tensors for tfl.lstm and insert these tensors into + // flatbuffer. + if (llvm::isa(inst)) { + std::vector intermediate_names = { + "input_to_input_intermediate", "input_to_forget_intermediate", + "input_to_cell_intermediate", "input_to_output_intermediate", + "effective_hidden_scale_intermediate"}; + for (const std::string& intermediate : intermediate_names) { + auto intermediate_attr = inst.getAttr(intermediate); + if (auto attr = intermediate_attr.dyn_cast_or_null()) { + Type qtype = attr.getValue(); + auto tensor_or = BuildTensorFromType( + qtype, name_mapper_.GetUniqueName(intermediate).str()); + if (!tensor_or.hasValue()) { + continue; + } else { + intermediates.push_back(tensors.size()); + tensors.push_back(tensor_or.getValue()); + } + } + } + } + + for (auto val : inst.getResults()) { + std::string name = UniqueName(val); + if (!build_tensor_and_buffer(val, name)) return llvm::None; + } + + // Skip constant ops as they don't represent a TFLite operator. + if (IsConst(&inst)) continue; + + // Fetch operand and result tensor indices. + std::vector operands; + operands.reserve(inst.getNumOperands()); + for (auto operand : inst.getOperands()) { + if (operand.getType().isa()) + operands.push_back(kTfLiteOptionalTensor); + else + operands.push_back(tensor_index_map.lookup(operand)); + } + std::vector results; + results.reserve(inst.getNumOperands()); + for (auto result : inst.getResults()) { + results.push_back(tensor_index_map.lookup(result)); + } + + if (auto tfl_operator = + BuildOperator(&inst, operands, results, intermediates)) + operators.push_back(*tfl_operator); + else + failed_once = true; + } + + if (failed_once) return llvm::None; + + // Get input and output tensor indices for the subgraph. 
+ std::vector inputs, outputs; + for (auto arg : bb.getArguments()) { + inputs.push_back(tensor_index_map[arg]); + } + for (auto result : bb.getTerminator()->getOperands()) { + outputs.push_back(tensor_index_map[result]); + } + + return tflite::CreateSubGraph( + builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), + builder_.CreateVector(outputs), builder_.CreateVector(operators), + /*name=*/builder_.CreateString(name)); +} + +BufferOffset Translator::BuildMetadata(StringRef name, + StringRef content) { + auto buffer_index = buffers_.size(); + auto buffer_data = builder_.CreateVector( + reinterpret_cast(content.data()), content.size()); + buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); + return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); +} + +Optional>> +Translator::CreateMetadataVector() { + auto dict_attr = module_.getAttrOfType("tfl.metadata"); + std::vector> metadata; + if (dict_attr) { + for (const auto& named_attr : dict_attr) { + StringRef name = named_attr.first; + mlir::Attribute attr = named_attr.second; + if (auto content = attr.dyn_cast()) { + metadata.push_back(BuildMetadata(name, content.getValue())); + } else { + module_.emitError( + "all values in tfl.metadata's dictionary key-value pairs should be " + "string attributes"); + return llvm::None; + } + } + } + // Runtime version string is generated after we update the op + // versions. Here we put a 16-byte dummy string as a placeholder. We choose + // 16-byte because it's the alignment of buffers in flatbuffer, so it won't + // cause any waste of space if the actual string is shorter than 16 bytes. + metadata.push_back( + BuildMetadata("min_runtime_version", std::string(16, '\0'))); + return builder_.CreateVector(metadata); +} + +bool UpdateEntryFunction(ModuleOp module) { + if (module.lookupSymbol("main") != nullptr) { + // We already have an entry function. + return true; + } + + int entry_func_count = 0; + FuncOp entry_func = nullptr; + for (auto fn : module.getOps()) { + auto attrs = fn.getAttrOfType("tf.entry_function"); + if (attrs && !attrs.empty()) { + entry_func_count++; + entry_func = fn; + } + } + + // We should have one & only have one entry function. + if (entry_func_count != 1) return false; + + // Update the entry func to main. + entry_func.setName("main"); + return true; +} + +Optional Translator::Translate( + ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { + if (!UpdateEntryFunction(module)) return llvm::None; + if (!IsValidTFLiteMlirModule(module)) return llvm::None; + Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, + emit_custom_ops, op_or_arg_name_mapper); + return translator.TranslateInternal(); +} + +Optional Translator::TranslateInternal() { + // A list of named regions in the module with main function being the first in + // the list. The main function is required as the first subgraph in the model + // is entry point for the model. + std::vector> named_regions; + named_regions.reserve(std::distance(module_.begin(), module_.end())); + + int subgraph_idx = 0; + FuncOp main_fn = module_.lookupSymbol("main"); + subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; + named_regions.emplace_back("main", &main_fn.getBody()); + // Walk over the module collection ops with functions and while ops. 
+ module_.walk([&](FuncOp fn) { + if (fn != main_fn) { + subgraph_index_map_[fn.getName().str()] = subgraph_idx++; + named_regions.emplace_back(fn.getName().str(), &fn.getBody()); + } + }); + + // Build subgraph for each of the named regions. + std::vector> subgraphs; + subgraphs.reserve(named_regions.size()); + int first_failed_func = -1; + for (auto it : llvm::enumerate(named_regions)) { + auto subgraph_or = BuildSubGraph(it.value().first, it.value().second); + if (!subgraph_or) { + if (first_failed_func == -1) + // Record the index of the first region that cannot be converted. + // Keep looping through all subgraphs in the module to make sure that + // we collect the list of missing ops from the entire module. + first_failed_func = it.index(); + } else { + subgraphs.push_back(*subgraph_or); + } + } + + if (first_failed_func != -1) { + std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t"); + std::string failed_custom_ops_list = + absl::StrJoin(failed_custom_ops_, "\n\t"); + std::string err; + if (!failed_flex_ops_list.empty()) + err += + "Ops that can be supported by the flex runtime (enabled via setting " + "the -emit-select-tf-ops flag):\n\t" + + failed_flex_ops_list; + if (!failed_custom_ops_list.empty()) + err += + "Ops that need custom implementation (enabled via setting the " + "-emit-custom-ops flag):\n\t" + + failed_custom_ops_list; + + auto& failed_region = named_regions[first_failed_func]; + return failed_region.second->getParentOp()->emitError() + << "failed while converting: '" << failed_region.first + << "': " << err, + llvm::None; + } + + std::string model_description; + if (auto attr = module_.getAttrOfType("tfl.description")) { + model_description = attr.getValue().str(); + } else { + model_description = "MLIR Converted."; + } + + // Build the model and finish the model building process. + auto description = builder_.CreateString(model_description.data()); + VectorBufferOffset metadata_buffer = 0; // Deprecated + auto metadata = CreateMetadataVector(); + if (!metadata) return llvm::None; + + auto model = tflite::CreateModel( + builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), + builder_.CreateVector(subgraphs), description, + builder_.CreateVector(buffers_), metadata_buffer, *metadata); + tflite::FinishModelBuffer(builder_, model); + tflite::UpdateOpVersion(builder_.GetBufferPointer()); + tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); + + // Return serialized string for the built FlatBuffer. + return std::string(reinterpret_cast(builder_.GetBufferPointer()), + builder_.GetSize()); +} + +} // namespace + +// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer +// format. Returns false on success. 
+// +// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting +// the following: +// +// * Quantization +// * Ops with variable tensors +// +bool tflite::MlirToFlatBufferTranslateFunction( + ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + OpOrArgNameMapper* op_or_arg_name_mapper) { + auto maybe_translated = + Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, + emit_custom_ops, op_or_arg_name_mapper); + if (!maybe_translated) return true; + *serialized_flatbuffer = std::move(*maybe_translated); + return false; +} + +bool tflite::MlirToFlatBufferTranslateFunction( + ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops) { + OpOrArgLocNameMapper op_or_arg_name_mapper; + return MlirToFlatBufferTranslateFunction( + module, serialized_flatbuffer, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); +} + +static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction( ModuleOp module, llvm::raw_ostream& output) { std::string serialized_flatbuffer; - std::unique_ptr op_or_arg_name_mapper; + std::unique_ptr op_or_arg_name_mapper; if (strip_debug_info) { op_or_arg_name_mapper = std::make_unique(); } else { - op_or_arg_name_mapper = - std::make_unique(); + op_or_arg_name_mapper = std::make_unique(); } if (tflite::MlirToFlatBufferTranslateFunction( module, &serialized_flatbuffer, emit_builtin_tflite_ops, @@ -162,18 +1511,8 @@ static LogicalResult MlirToFlatBufferFileTranslateFunction( return mlir::failure(); output << serialized_flatbuffer; - return success(); + return mlir::success(); } -} // namespace - -static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( - "tflite-flatbuffer-to-mlir", - [](llvm::SourceMgr& source_mgr, MLIRContext* context) { - return FlatBufferFileToMlirTrans( - &source_mgr, context, use_external_constant, - experimental_prune_unreachable_nodes_unconditionally); - }); static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate( "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction); -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h similarity index 90% rename from tensorflow/compiler/mlir/lite/flatbuffer_export.h rename to tensorflow/compiler/mlir/lite/flatbuffer_translate.h index f89893d5c87..03f92ddbf03 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
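For reference, the two MlirToFlatBufferTranslateFunction overloads above are the public entry points of this file, and they keep the pre-existing convention of returning false on success. A minimal caller sketch, assuming `module` is a ModuleOp already lowered to the TFLite dialect and using placeholder flag values (not part of this patch):

    // Sketch only: `module` and the flag settings below are illustrative.
    std::string serialized_flatbuffer;
    tensorflow::OpOrArgLocNameMapper op_or_arg_name_mapper;
    bool failed = tflite::MlirToFlatBufferTranslateFunction(
        module, &serialized_flatbuffer,
        /*emit_builtin_tflite_ops=*/true,
        /*emit_select_tf_ops=*/false,
        /*emit_custom_ops=*/false, &op_or_arg_name_mapper);
    if (failed) {
      // Diagnostics have already been emitted on `module`; abort the export.
    } else {
      // `serialized_flatbuffer` now holds the bytes of the .tflite model.
    }
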
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ #include @@ -40,4 +40,4 @@ bool MlirToFlatBufferTranslateFunction( tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); } // namespace tflite -#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h b/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h similarity index 84% rename from tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h rename to tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h index 4e891a5b266..6c8f80d4e05 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ #include @@ -28,4 +28,4 @@ extern bool lower_tensor_list_ops; // The flag to control whether debug info gets stripped on export. extern bool strip_debug_info; -#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index d17215566a1..6f8292308a4 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -34,8 +34,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project #include "mlir/Parser.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/delegates/flex/delegate.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 7557ff5223c..2f677397109 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index f04dc9c2961..c05337918f2 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -23,8 +23,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index bb82988def1..74e48cd6d91 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -28,8 +28,8 @@ limitations under the License. #include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 1ba0c025613..b05dcaadab2 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "mlir/Transforms/Passes.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 6cd058a15d2..8ac33c906bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -811,8 +811,7 @@ cc_library( srcs = ["utils/error_util.cc"], hdrs = ["utils/error_util.h"], deps = [ - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:status", + "//tensorflow/core:lib", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", ], diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc index 5514a788996..60646ae764e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/lib/core/errors.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index 1bc0a23e359..7eb30ee2c46 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -21,7 +21,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/lib/core/status.h" // Error utilities for MLIR when interacting with code using Status returns. namespace mlir { From 9278bbfc249233190c3901e9e6922d4077372899 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Sun, 22 Mar 2020 22:55:34 -0700 Subject: [PATCH 402/492] Support to pass in a list of supported node indices directly to GraphPartitionHelper. PiperOrigin-RevId: 302365857 Change-Id: I3debf09a5cc033b07e56b08e8345d548fb556cb7 --- tensorflow/lite/delegates/utils.cc | 4 ++-- tensorflow/lite/delegates/utils.h | 12 ++++++++--- tensorflow/lite/delegates/utils_test.cc | 27 +++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/delegates/utils.cc b/tensorflow/lite/delegates/utils.cc index 75839d53560..fba8bec39a5 100644 --- a/tensorflow/lite/delegates/utils.cc +++ b/tensorflow/lite/delegates/utils.cc @@ -18,9 +18,7 @@ limitations under the License. 
#include #include -#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" -#include "tensorflow/lite/util.h" namespace tflite { namespace delegates { @@ -98,6 +96,8 @@ GraphPartitionHelper::GetFirstNLargestPartitions( TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes( std::set* unsupported_nodes_info) { + if (!is_node_supported_fn_) return kTfLiteOk; + TfLiteIntArray* execution_plan = nullptr; auto status = context_->GetExecutionPlan(context_, &execution_plan); if (status != kTfLiteOk) { diff --git a/tensorflow/lite/delegates/utils.h b/tensorflow/lite/delegates/utils.h index f894cae30fd..d6d22c4efa2 100644 --- a/tensorflow/lite/delegates/utils.h +++ b/tensorflow/lite/delegates/utils.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace delegates { @@ -43,12 +44,17 @@ using IsNodeSupportedFn = // Note the class *needs* to be used in TfLiteDelegate::Prepare. class GraphPartitionHelper { public: - // TODO(b/151152967): Support use-cases where a list of supported nodes are - // directly passed-in. GraphPartitionHelper(TfLiteContext* context, IsNodeSupportedFn is_node_supported_fn) : context_(context), is_node_supported_fn_(is_node_supported_fn) {} + GraphPartitionHelper(TfLiteContext* context, + const std::vector& supported_node_indices) + : context_(context), + num_total_nodes_(supported_node_indices.size()), + supported_nodes_( + ConvertVectorToTfLiteIntArray(supported_node_indices)) {} + virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); } // Partition the graph into node subsets such that each subset could be @@ -98,7 +104,7 @@ class GraphPartitionHelper { int num_total_nodes_ = 0; // Tells if a node is supported as it could be delegated. - const IsNodeSupportedFn is_node_supported_fn_; + const IsNodeSupportedFn is_node_supported_fn_ = nullptr; // Contains an array of supported node indices. TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory diff --git a/tensorflow/lite/delegates/utils_test.cc b/tensorflow/lite/delegates/utils_test.cc index a67778fee1f..5d308a0b546 100644 --- a/tensorflow/lite/delegates/utils_test.cc +++ b/tensorflow/lite/delegates/utils_test.cc @@ -223,6 +223,33 @@ TEST(GraphPartitionHelper, CheckPrepareErrors) { EXPECT_EQ(kTfLiteError, helper.Partition(nullptr)); } +TEST(GraphPartitionHelper, CheckPartitionsWithSupportedNodeList) { + // The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}. + // So, we simply create a list of supported nodes as {0,1,2,...,8,9} + MockTfLiteContext mocked_context; + std::vector supported_nodes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + GraphPartitionHelper helper(&mocked_context, supported_nodes); + EXPECT_EQ(kTfLiteOk, helper.Partition(nullptr)); + EXPECT_EQ(10, helper.num_total_nodes()); + EXPECT_EQ(4, helper.num_partitions()); + + auto partitions = helper.GetFirstNLargestPartitions(1, 0); + EXPECT_EQ(1, partitions.size()); + auto nodes = GetNodesToReplaceFromPartitions(partitions); + EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8})); + + // Get the largest partition but requiring at least 5 nodes, so empty result. 
+ partitions = helper.GetFirstNLargestPartitions(1, 5); + EXPECT_TRUE(partitions.empty()); + + partitions = helper.GetFirstNLargestPartitions(10, 3); + EXPECT_EQ(2, partitions.size()); + EXPECT_EQ(4, partitions[0]->nodes_to_replace->size); + EXPECT_EQ(3, partitions[1]->nodes_to_replace->size); + nodes = GetNodesToReplaceFromPartitions(partitions); + EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9})); +} + } // namespace } // namespace delegates } // namespace tflite From 00ba3b81eb46fb8cb34e889659553d0a1753cadd Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Mon, 23 Mar 2020 01:08:31 -0700 Subject: [PATCH 403/492] Clear owned delegates before they are being populated, which avoids leaving unused delegates in multi-option runs. PiperOrigin-RevId: 302379673 Change-Id: I95ec448c5793f05778e9f6fee55b935dd44a97ad --- tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index d6965619689..825879693f3 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -615,6 +615,7 @@ TfLiteStatus BenchmarkTfLiteModel::Init() { interpreter_->UseNNAPI(params_.Get("use_legacy_nnapi")); interpreter_->SetAllowFp16PrecisionForFp32(params_.Get("allow_fp16")); + owned_delegates_.clear(); for (const auto& delegate_provider : GetRegisteredDelegateProviders()) { auto delegate = delegate_provider->CreateTfLiteDelegate(params_); // It's possible that a delegate of certain type won't be created as From 2e1b15de42c25fcd0969f6b7f973fcf121ddd193 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 02:02:23 -0700 Subject: [PATCH 404/492] compat: Update forward compatibility horizon to 2020-03-23 PiperOrigin-RevId: 302385481 Change-Id: I2c6c43db3f14531063104c94af1236cee9b1ca9c --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 81a7d03f110..da05db6f7f4 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 22) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 23) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From b4d23b0da5ae7e8c1354274f51de3570a2ac221c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 03:41:35 -0700 Subject: [PATCH 405/492] Update clang to latest version in the RBE docker image. 
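To make the GraphPartitionHelper change in PATCH 402 above concrete: a delegate that has already computed its supported node set can now hand the indices straight to the helper from TfLiteDelegate::Prepare instead of supplying a per-node callback. A minimal sketch, assuming `context` comes from the Prepare callback and the index list is produced elsewhere by the delegate (all names here are illustrative):

    // Sketch only: partition with an explicit supported-node list, then pick
    // the single largest partition to delegate.
    std::vector<int> supported_nodes = /* computed by the delegate */ {0, 3, 7, 8};
    tflite::delegates::GraphPartitionHelper helper(context, supported_nodes);
    if (helper.Partition(/*unsupported_nodes_info=*/nullptr) != kTfLiteOk) {
      return kTfLiteError;
    }
    auto partitions = helper.GetFirstNLargestPartitions(
        /*n=*/1, /*min_nodes_per_partition=*/0);
    auto nodes_to_delegate = GetNodesToReplaceFromPartitions(partitions);
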
PiperOrigin-RevId: 302399776 Change-Id: I44bfc2b6a30205f4474a8c005f550df48021d43f --- .../Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 index c7e84936bf5..d2713e8805b 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010 @@ -47,7 +47,7 @@ RUN apt-get update && apt-get install -y \ rm -rf /var/lib/apt/lists/* # Copy and run the install scripts. -ENV CLANG_VERSION="ra21beccea2020f950845cbb68db663d0737e174c" +ENV CLANG_VERSION="r42cab985fd95ba4f3f290e7bb26b93805edb447d" COPY install/*.sh /install/ ARG DEBIAN_FRONTEND=noninteractive RUN /install/install_bootstrap_deb_packages.sh From f05de3fdee4f292841ccade5689ffef4bb211e80 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 03:45:24 -0700 Subject: [PATCH 406/492] Use latest RBE docker image and latest clang release. PiperOrigin-RevId: 302400255 Change-Id: I68afba780aeab479980b71cfa900afaa06ac98e6 --- third_party/toolchains/preconfig/generate/containers.bzl | 2 +- third_party/toolchains/remote_config/configs.bzl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/toolchains/preconfig/generate/containers.bzl b/third_party/toolchains/preconfig/generate/containers.bzl index be19af8ceeb..bf2a655acc4 100644 --- a/third_party/toolchains/preconfig/generate/containers.bzl +++ b/third_party/toolchains/preconfig/generate/containers.bzl @@ -8,7 +8,7 @@ container_digests = { "cuda10.0-cudnn7-centos6": "sha256:a1909ba09c703340ee0074ce63dd94fe8fea48035a25264677907a609e2375e0", "cuda10.1-cudnn7-centos6": "sha256:454b899657e87893ee5e68dc0f87df59b6a0a7418ae09cafcc3dd65ac71feca9", "cuda10.0-cudnn7-ubuntu16.04-manylinux2010": "sha256:5812d9d0ef0a3276fc5faaf4cd01f3d6e03d635893a6e2d2e04f6f01d626c432", - "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:177e1e55894b3c6edcfd7aa5d6db53716924b02553922bbf907e16b3d319e18c", + "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:cc7f760195d7bbe283b45ae740409751d0b74d8ffbdc2f7a3cb62c71a71fbe25", "rocm-ubuntu16.04": "sha256:e645447dd6127325f3e97b8bf23424f637a8579d963b34fcc6772cf7cfaa0ebe", "windows-1803": "sha256:f109576c7c0c8a1783ff22b666e8923b52dbbe7933f69a1c7a7275202c304a12", } diff --git a/third_party/toolchains/remote_config/configs.bzl b/third_party/toolchains/remote_config/configs.bzl index 4ebf5c1c068..4c94abf45b3 100644 --- a/third_party/toolchains/remote_config/configs.bzl +++ b/third_party/toolchains/remote_config/configs.bzl @@ -36,7 +36,7 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0", - compiler = "/clang_ra21beccea2020f950845cbb68db663d0737e174c/bin/clang", + compiler = "/clang_r42cab985fd95ba4f3f290e7bb26b93805edb447d/bin/clang", cuda_version = "10.1", cudnn_version = "7", os = "ubuntu16.04-manylinux2010", From 4abab73a1074301ca601e127c4abec7b10ae125c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 03:46:19 -0700 Subject: [PATCH 407/492] Go: Update generated wrapper functions for TensorFlow ops. 
PiperOrigin-RevId: 302400337 Change-Id: I0119674270464404c1abbe2409047d318aa72e6b --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 75d86f71b78..68bb1dc49f5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From cbc5ab117445ac0328311df8abea61be6256b92f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 05:46:30 -0700 Subject: [PATCH 408/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302415042 Change-Id: Ic75e656ca1aa380265e68e9e552d78ab50983f21 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 68bb1dc49f5..75d86f71b78 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From a6f0afd2ebbd8c6abb40f308cdf8ac4394b5326d Mon Sep 17 00:00:00 2001 From: Lu Wang Date: Mon, 23 Mar 2020 07:03:09 -0700 Subject: [PATCH 409/492] Fix typos in metadata_schema.fbs PiperOrigin-RevId: 302425414 Change-Id: I3b44db6fefc5a5f8bb4baaef5ae046139566d70f --- .../lite/experimental/support/metadata/metadata_schema.fbs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs b/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs index f3f3bbcc6ff..7806899c906 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs +++ b/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs @@ -294,8 +294,8 @@ table Content { // Parameters that are used when normalizing the tensor. table NormalizationOptions{ - // mean and std are normalization parameters. Tensor values are normailzed - // per-channelly by, + // mean and std are normalization parameters. Tensor values are normalized + // on a per-channel basis, by the formula // (x - mean) / std. // For example, a float MobileNet model will have // mean = 127.5f and std = 127.5f. @@ -404,7 +404,7 @@ table TensorMetadata { // A description of the tensor. description:string; - // A list of names of the dimensions in this tentor. The length of + // A list of names of the dimensions in this tensor. The length of // dimension_names need to match the number of dimensions in this tensor. // // : From 99c1898d7b4193d4ef4e735cb2c4e250a19dfd6c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 07:32:05 -0700 Subject: [PATCH 410/492] Allow adding compiler specific flags to the GPU crosstool. For now, only support "clang", "msvc" and "unknown" (implying a variety of gcc compatible compilers). We hard-code the MSVC toolchain to "msvc". The darwin toolchain is kept at "unknown", as it currently isn't tested with any of our clang-specific flags. The linux toolchain will use "clang" if compiling with cuda-clang, as that's currently the only situation in which we care about using clang-specific flags. In the future, we do optimally want to use compiler detection and setting the compiler attribute accordingly. 
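The normalization comment corrected in PATCH 409 above describes a per-channel mapping, normalized = (x - mean) / std, with mean = std = 127.5 for the float MobileNet example. A small self-contained sketch of that mapping; the function name and interleaved RGB layout are assumptions made for illustration:

#include <array>
#include <cstdio>

// normalized = (x - mean[c]) / std[c], applied per channel as described by
// NormalizationOptions. With mean = std = 127.5, uint8 values map onto [-1, 1].
void NormalizePerChannel(const unsigned char* pixels, int num_pixels,
                         const std::array<float, 3>& mean,
                         const std::array<float, 3>& stddev, float* out) {
  for (int i = 0; i < num_pixels; ++i) {
    for (int c = 0; c < 3; ++c) {
      out[i * 3 + c] = (pixels[i * 3 + c] - mean[c]) / stddev[c];
    }
  }
}

int main() {
  const unsigned char pixel[3] = {0, 128, 255};
  float normalized[3];
  NormalizePerChannel(pixel, 1, {127.5f, 127.5f, 127.5f},
                      {127.5f, 127.5f, 127.5f}, normalized);
  std::printf("%.3f %.3f %.3f\n", normalized[0], normalized[1], normalized[2]);
  // Prints approximately -1.000 0.004 1.000.
  return 0;
}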
PiperOrigin-RevId: 302429367 Change-Id: I3f47582e76d86d119d3fca4a6855cfc377667922 --- third_party/gpus/crosstool/BUILD.tpl | 2 + .../crosstool/cc_toolchain_config.bzl.tpl | 120 ++++++++++++------ third_party/gpus/cuda_configure.bzl | 2 + 3 files changed, 82 insertions(+), 42 deletions(-) diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl index 5a78654a90f..bc92f91a777 100644 --- a/third_party/gpus/crosstool/BUILD.tpl +++ b/third_party/gpus/crosstool/BUILD.tpl @@ -68,6 +68,7 @@ cc_toolchain_config( linker_bin_path = "%{linker_bin_path}", builtin_sysroot = "%{builtin_sysroot}", cuda_path = "%{cuda_toolkit_path}", + compiler = "%{compiler}", ) cc_toolchain( @@ -124,6 +125,7 @@ cc_toolchain_config( msvc_lib_path = "%{msvc_lib_path}", msvc_link_path = "%{msvc_link_path}", msvc_ml_path = "%{msvc_ml_path}", + compiler = "msvc", ) filegroup( diff --git a/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl b/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl index 3d4d41aa2b1..7b249c0c606 100644 --- a/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl +++ b/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl @@ -626,48 +626,82 @@ def _impl(ctx): ], ) - default_compile_flags_feature = feature( - name = "default_compile_flags", - enabled = True, - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.linkstamp_compile, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ACTION_NAMES.lto_backend, - ACTION_NAMES.clif_match, - ], - flag_groups = [ - flag_group( - flags = [ - "/DCOMPILER_MSVC", - "/DNOMINMAX", - "/D_WIN32_WINNT=0x0600", - "/D_CRT_SECURE_NO_DEPRECATE", - "/D_CRT_SECURE_NO_WARNINGS", - "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS", - "/bigobj", - "/Zm500", - "/J", - "/Gy", - "/GF", - "/EHsc", - "/wd4351", - "/wd4291", - "/wd4250", - "/wd4996", - ], - ), - ], - ), - ], - ) + if ctx.attr.compiler == "clang": + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = [ + "-fexperimental-new-pass-manager", + ], + ), + ], + ), + ], + ) + + elif ctx.attr.compiler == "msvc": + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = [ + "/DCOMPILER_MSVC", + "/DNOMINMAX", + "/D_WIN32_WINNT=0x0600", + "/D_CRT_SECURE_NO_DEPRECATE", + "/D_CRT_SECURE_NO_WARNINGS", + "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS", + "/bigobj", + "/Zm500", + "/J", + "/Gy", + "/GF", + "/EHsc", + "/wd4351", + "/wd4291", + "/wd4250", + "/wd4996", + ], + ), + ], + ), + ], + ) + + else: + default_compile_flags_feature = feature( + name = 
"default_compile_flags") static_link_msvcrt_debug_feature = feature( name = "static_link_msvcrt_debug", @@ -1320,6 +1354,7 @@ def _impl(ctx): if (ctx.attr.cpu == "local"): features = [ + default_compile_flags_feature, cpp11_feature, stdlib_feature, determinism_feature, @@ -1510,6 +1545,7 @@ cc_toolchain_config = rule( "msvc_lib_path": attr.string(default = "msvc_not_used"), "msvc_link_path": attr.string(default = "msvc_not_used"), "msvc_ml_path": attr.string(default = "msvc_not_used"), + "compiler": attr.string(values = ["clang", "msvc", "unknown"], default="unknown"), }, provides = [CcToolchainConfigInfo], executable = True, diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index a4eccc4d235..8fa64f264dc 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1024,8 +1024,10 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines = {} cuda_defines["%{builtin_sysroot}"] = tf_sysroot cuda_defines["%{cuda_toolkit_path}"] = "" + cuda_defines["%{compiler}"] = "unknown" if is_cuda_clang: cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"] + cuda_defines["%{compiler}"] = "clang" host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX) if not host_compiler_prefix: From d3f59200aac4cc29b5181fa79ba68e865ca80968 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 07:46:05 -0700 Subject: [PATCH 411/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302431216 Change-Id: I61f2e6bbccad25756c75d030d2cc62f2a3358378 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 75d86f71b78..68bb1dc49f5 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. 
-// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 204b93a1d5ab465ec6807bb245960b9044b0cdf4 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Mon, 23 Mar 2020 09:09:36 -0700 Subject: [PATCH 412/492] Break tf_export and its few dependencies into smaller standalone targets on which the test of TensorFlow can depend. PiperOrigin-RevId: 302446648 Change-Id: I72b00774bae8705de2125a80dd4ddc4b0c9eb0e6 --- tensorflow/python/BUILD | 113 +++++++++++++++++++++++++++------------- 1 file changed, 78 insertions(+), 35 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4b5910d05a4..46696993f99 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3,6 +3,7 @@ # Public targets: # ":platform" - Low-level and platform-specific Python code. 
+load("//tensorflow:tensorflow.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") load("//tensorflow:tensorflow.bzl", "pybind_extension") @@ -19,6 +20,8 @@ load( "if_ngraph", ) +# TODO(mdan): Break into per-directory files. + visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//third_party/cloud_tpu/convergence_tools:__subpackages__", @@ -1065,17 +1068,6 @@ tf_py_test( ], ) -tf_py_test( - name = "tf_export_test", - srcs = ["util/tf_export_test.py"], - python_version = "PY3", - deps = [ - ":client_testlib", - ":platform", - ":util", - ], -) - tf_py_test( name = "deprecation_test", srcs = ["util/deprecation_test.py"], @@ -1823,19 +1815,6 @@ tf_py_test( ], ) -# This target is maintained separately from :util to provide separate visibility -# for legacy users who were granted visibility when the functions were private -# members of ops.Graph. -py_library( - name = "tf_stack", - srcs = ["util/tf_stack.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":_tf_stack", - ], -) - py_library( name = "tensor_shape", srcs = ["framework/tensor_shape.py"], @@ -5398,6 +5377,63 @@ py_library( ], ) +# Leaf library: may not depend on anything else inside TensorFlow. +py_strict_library( + name = "tf_export", + srcs = ["util/tf_export.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":tf_decorator", + ], +) + +tf_py_test( + name = "tf_export_test", + srcs = ["util/tf_export_test.py"], + python_version = "PY3", + deps = [ + ":client_testlib", + ":platform", + ":util", + ], +) + +# Leaf library: may not depend on anything else inside TensorFlow. +# TODO(mdan): Move this utility outside of TF. +py_strict_library( + name = "tf_decorator", + srcs = [ + "util/tf_contextlib.py", + "util/tf_decorator.py", + "util/tf_inspect.py", + ], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow:__subpackages__", + # TODO(mdan): Remove these dependencies. + "//third_party/py/tf_slim:__subpackages__", + "//learning/deepmind/research/language/translation/lm:__subpackages__", + ], + deps = [ + ":tf_stack", + "@six_archive//:six", + ], +) + +# Leaf library: may not depend on anything else inside TensorFlow. +py_strict_library( + name = "tf_stack", + srcs = ["util/tf_stack.py"], + srcs_version = "PY2AND3", + # TODO(mdan): Remove public visibility. 
+ visibility = ["//visibility:public"], + deps = [ + ":_tf_stack", + "@six_archive//:six", + ], +) + pybind_extension( name = "_tf_stack", srcs = ["util/tf_stack.cc"], @@ -5409,13 +5445,28 @@ pybind_extension( ], ) +tf_py_test( + name = "tf_stack_test", + srcs = ["util/tf_stack_test.py"], + python_version = "PY3", + deps = [ + ":client_testlib", + ":tf_export", + ":tf_stack", + ], +) + py_library( name = "util", srcs = glob( ["util/**/*.py"], exclude = [ "util/example_parser*", + "util/tf_contextlib.py", "util/tf_should_use.py", + "util/tf_export.py", + "util/tf_stack.py", + "util/tf_decorator.py", "util/**/*_test.py", ], ), @@ -5426,7 +5477,9 @@ py_library( "//third_party/py/tf_agents:__subpackages__", ], deps = [ - ":_tf_stack", + ":tf_decorator", + ":tf_export", + ":tf_stack", "@org_python_pypi_backports_weakref", "@com_google_protobuf//:protobuf_python", "//third_party/py/numpy", @@ -5436,16 +5489,6 @@ py_library( ] + if_mlir(["//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration"]), ) -tf_py_test( - name = "tf_stack_test", - srcs = ["util/tf_stack_test.py"], - python_version = "PY3", - deps = [ - ":client_testlib", - ":util", - ], -) - tf_py_test( name = "object_identity_test", size = "small", From 19f2bd062289fe43f3bfc6fbabb740ca1459f8c4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 09:41:05 -0700 Subject: [PATCH 413/492] Make TPUStrategy API for enabling maximal sharding use sharding op instead of attributes. _XlaSharding attribute can be lost if ops are removed during graph transformation. To better preserve this information, default to using XlaShardingOp for Maximal sharding. PiperOrigin-RevId: 302452752 Change-Id: Ie7e61e99c6f23babf7505b88034d921fc28012cc --- tensorflow/python/distribute/tpu_strategy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index a45ac5785ee..05c9f75f09e 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -553,7 +553,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): "logical device id {} but there are only total of {} " "logical devices in replica.".format( logical_device_id, num_logical_devices_per_replica)) - return xla_sharding.assign_device(tensor, logical_device_id) + return xla_sharding.assign_device( + tensor, logical_device_id, use_sharding_op=True) def _experimental_split_to_logical_devices(self, tensor, partition_dimensions): From b8054c93d03fa35e77e8da174857cc544311b2bf Mon Sep 17 00:00:00 2001 From: Jian Li Date: Mon, 23 Mar 2020 09:42:04 -0700 Subject: [PATCH 414/492] Create a helper function to change a float models's interface to uint8. This is for users to use on inputs, rather than relying on infererence_input and inference_output type in the 2.0 converter. 
PiperOrigin-RevId: 302452948 Change-Id: I4b5a71f48046c3392b09675ddef3e30d845ce4ca --- tensorflow/lite/tools/optimize/BUILD | 1 + .../tools/optimize/modify_model_interface.cc | 104 +++++++++++++++ .../tools/optimize/modify_model_interface.h | 18 +++ .../optimize/modify_model_interface_test.cc | 125 +++++++++++++++++- 4 files changed, 244 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index c74c3f495d3..c3318f1ab26 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -17,6 +17,7 @@ cc_library( srcs = ["modify_model_interface.cc"], hdrs = ["modify_model_interface.h"], deps = [ + ":model_utils", "//tensorflow/lite:framework", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:compatibility", diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc index 7d51bc03434..bc1e9cbe5a3 100644 --- a/tensorflow/lite/tools/optimize/modify_model_interface.cc +++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/tools/optimize/model_utils.h" namespace tflite { namespace optimize { @@ -360,5 +361,108 @@ TfLiteStatus ModifyModelInterface(const string& input_file, return kTfLiteOk; } +namespace { +void AddUint8Dequant( + const std::unordered_map>& quant_params, + ModelT* model) { + for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); + subgraph_idx++) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); + // Add dequant to input tensors. + for (size_t input_idx = 0; input_idx < subgraph->inputs.size(); + input_idx++) { + const int32_t tensor_idx = subgraph->inputs[input_idx]; + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + if (tensor->type != TensorType_FLOAT32) { + continue; + } + if (quant_params.find(tensor->name) != quant_params.end()) { + // Add uint8 tensor + const string added_tensor_name = tensor->name + "_uint8"; + std::unique_ptr leading_op_input; + const std::pair& provided_quant_params = + quant_params.at(string(tensor->name)); + utils::MakeTensorWithQuantParam( + added_tensor_name, tensor->shape, TensorType_UINT8, + provided_quant_params.first, provided_quant_params.second, + &leading_op_input); + const int32_t leading_op_input_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(leading_op_input)); + + // Create the leading op, which is deqantize Op. + std::unique_ptr leading_op; + utils::MakeDequantizeOperator(model, &leading_op, leading_op_input_idx, + tensor_idx); + + // Insert the new op at the start of the model. + subgraph->operators.insert(subgraph->operators.begin(), + std::move(leading_op)); + } + } + } +} + +void AddUint8Quant( + const std::unordered_map>& quant_params, + ModelT* model) { + for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); + subgraph_idx++) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); + // Add quant to output tensors. 
+ for (size_t output_idx = 0; output_idx < subgraph->outputs.size(); + output_idx++) { + const int32_t tensor_idx = subgraph->outputs[output_idx]; + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + if (tensor->type != TensorType_FLOAT32) { + continue; + } + if (quant_params.find(tensor->name) != quant_params.end()) { + // Add uint8 tensor + const string added_tensor_name = tensor->name + "_uint8"; + std::unique_ptr tailing_op_output; + const std::pair& provided_quant_params = + quant_params.at(string(tensor->name)); + utils::MakeTensorWithQuantParam( + added_tensor_name, tensor->shape, TensorType_UINT8, + provided_quant_params.first, provided_quant_params.second, + &tailing_op_output); + const int32_t tailing_op_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(tailing_op_output)); + + // Create the tailing op, which is Qantize Op. + std::unique_ptr tailing_op; + utils::MakeQuantizeOperator(model, &tailing_op, tensor_idx, + tailing_op_output_idx); + + // Insert the new op at the end of the model. + subgraph->operators.push_back(std::move(tailing_op)); + } + } + } +} +} // namespace + +TfLiteStatus Uint8QuantizeModelInputsOutputs( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + const std::unordered_map>& + input_quant_params, + const std::unordered_map>& + output_quant_params) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + // Add Dequant for inputs. + AddUint8Dequant(input_quant_params, model.get()); + + // Add Quant for outputs. + AddUint8Quant(output_quant_params, model.get()); + + // Output model. + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + + return kTfLiteOk; +} + } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.h b/tensorflow/lite/tools/optimize/modify_model_interface.h index cfe4f41ff90..170e0e73a67 100644 --- a/tensorflow/lite/tools/optimize/modify_model_interface.h +++ b/tensorflow/lite/tools/optimize/modify_model_interface.h @@ -39,6 +39,24 @@ TfLiteStatus ModifyModelInterface(const string& input_file, const TensorType& input_type, const TensorType& output_type); +// Adds uint8 quantize ops for specified inputs and uint8 dequantize ops for +// specified outputs for a float model. The scale and zero point of uint8 +// tensors are provided through quant_params. +// - input_quant_params has a map between tensor name and the +// pair for inputs. +// - output_quant_params has a map between tensor name and the +// pair for inputs. +// For the inputs/output tensors for the model, if its quantization parameters +// are not provided, that tensor is not affected. +// +// Note: This is a private API, subject to change. +TfLiteStatus Uint8QuantizeModelInputsOutputs( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + const std::unordered_map>& + input_quant_params, + const std::unordered_map>& + output_quant_params); + } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc index 01d0775c953..7e2744fb1e1 100644 --- a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc +++ b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc @@ -14,6 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/tools/optimize/modify_model_interface.h" +#include + +#include #include #include "absl/memory/memory.h" #include "tensorflow/lite/model.h" @@ -23,6 +26,8 @@ namespace tflite { namespace optimize { namespace { +using ::testing::ElementsAreArray; + // Create a model with 1 quant, 1 FC, 1 dequant std::unique_ptr CreateModelSingleInputOutput() { auto model = absl::make_unique(); @@ -238,7 +243,53 @@ std::unique_ptr CreateModelMultipleInputOutput() { return model; } -TEST(ModelInference, Uint8SingleInputOutput) { +// Create a model with 1 FC. +std::unique_ptr CreateFloatModel() { + auto model = absl::make_unique(); + auto subgraph = absl::make_unique(); + auto buffer = absl::make_unique(); + auto fc_op_code = absl::make_unique(); + auto fc_op = absl::make_unique(); + + model->subgraphs.push_back(std::move(subgraph)); + + // Op code + fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED; + fc_op_code->version = 2; + + // Op. + fc_op->opcode_index = 0; + fc_op->inputs = {0}; + fc_op->outputs = {1}; + + model->subgraphs[0]->operators.push_back(std::move(fc_op)); + model->operator_codes.push_back(std::move(fc_op_code)); + + // Model input/otuput. + model->subgraphs[0]->inputs = {0}; + model->subgraphs[0]->outputs = {1}; + + // Tensors + auto tensor_0 = absl::make_unique(); + tensor_0->name = "tensor_0"; + tensor_0->shape = {}; + tensor_0->type = TensorType_FLOAT32; + + auto tensor_1 = absl::make_unique(); + tensor_1->name = "tensor_1"; + tensor_1->shape = {}; + tensor_1->type = TensorType_FLOAT32; + + model->subgraphs[0]->tensors.push_back(std::move(tensor_0)); + model->subgraphs[0]->tensors.push_back(std::move(tensor_1)); + + // Buffer + model->buffers.push_back(std::move(buffer)); + + return model; +} + +TEST(ModelInterface, Uint8SingleInputOutput) { auto model = CreateModelSingleInputOutput(); // Ops. @@ -277,7 +328,7 @@ TEST(ModelInference, Uint8SingleInputOutput) { EXPECT_EQ(model->subgraphs[0]->operators[2]->opcode_index, 0); } -TEST(ModelInference, Int8SingleInputOutput) { +TEST(ModelInterface, Int8SingleInputOutput) { auto model = CreateModelSingleInputOutput(); // Change model type. @@ -299,7 +350,7 @@ TEST(ModelInference, Int8SingleInputOutput) { EXPECT_EQ(model->subgraphs[0]->outputs[0], 2); } -TEST(ModelInference, Uint8MutipleInputOutput) { +TEST(ModelInterface, Uint8MutipleInputOutput) { auto model = CreateModelMultipleInputOutput(); // Ops. @@ -362,7 +413,7 @@ TEST(ModelInference, Uint8MutipleInputOutput) { EXPECT_EQ(model->subgraphs[0]->operators[4]->opcode_index, 0); } -TEST(ModelInference, Int8MutipleInputOutput) { +TEST(ModelInterface, Int8MutipleInputOutput) { auto model = CreateModelMultipleInputOutput(); // Change model type. @@ -413,6 +464,72 @@ TEST(ModelInference, Int8MutipleInputOutput) { EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1); } +TEST(ModelInterface, Float) { + // Create the model. + std::unique_ptr input_model_t = CreateFloatModel(); + flatbuffers::FlatBufferBuilder builder_temp; + flatbuffers::Offset output_model_location = + Model::Pack(builder_temp, input_model_t.get()); + FinishModelBuffer(builder_temp, output_model_location); + const uint8_t* buffer_temp = builder_temp.GetBufferPointer(); + const Model* input_model = GetModel(buffer_temp); + + // Change model type. 
+ flatbuffers::FlatBufferBuilder builder; + EXPECT_EQ(Uint8QuantizeModelInputsOutputs(&builder, input_model, + {{"tensor_0", {0.4, 2}}}, + {{"tensor_1", {0.5, -5}}}), + kTfLiteOk); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + std::unique_ptr model; + model.reset(output_model->UnPack()); + + // Verify results. + EXPECT_EQ(model->operator_codes.size(), 3); + EXPECT_EQ(model->subgraphs.size(), 1); + EXPECT_EQ(model->subgraphs[0]->operators.size(), 3); + EXPECT_EQ(model->subgraphs[0]->tensors.size(), 4); + EXPECT_EQ(model->buffers.size(), 1); + + // Ops. + EXPECT_EQ(model->operator_codes[0]->builtin_code, + BuiltinOperator_FULLY_CONNECTED); + EXPECT_EQ(model->operator_codes[1]->builtin_code, BuiltinOperator_DEQUANTIZE); + EXPECT_EQ(model->operator_codes[2]->builtin_code, BuiltinOperator_QUANTIZE); + + EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1); + EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 0); + EXPECT_EQ(model->subgraphs[0]->operators[2]->opcode_index, 2); + + EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({2})); + EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs, + ElementsAreArray({0})); + EXPECT_THAT(model->subgraphs[0]->operators[1]->inputs, ElementsAreArray({0})); + EXPECT_THAT(model->subgraphs[0]->operators[1]->outputs, + ElementsAreArray({1})); + EXPECT_THAT(model->subgraphs[0]->operators[2]->inputs, ElementsAreArray({1})); + EXPECT_THAT(model->subgraphs[0]->operators[2]->outputs, + ElementsAreArray({3})); + + // Tensors. + EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "tensor_0"); + EXPECT_EQ(model->subgraphs[0]->tensors[0]->type, TensorType_FLOAT32); + EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "tensor_1"); + EXPECT_EQ(model->subgraphs[0]->tensors[1]->type, TensorType_FLOAT32); + + EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "tensor_0_uint8"); + EXPECT_EQ(model->subgraphs[0]->tensors[2]->type, TensorType_UINT8); + EXPECT_FLOAT_EQ(model->subgraphs[0]->tensors[2]->quantization->scale[0], 0.4); + EXPECT_EQ(model->subgraphs[0]->tensors[2]->quantization->zero_point[0], 2); + + EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "tensor_1_uint8"); + EXPECT_EQ(model->subgraphs[0]->tensors[3]->type, TensorType_UINT8); + EXPECT_FLOAT_EQ(model->subgraphs[0]->tensors[3]->quantization->scale[0], 0.5); + EXPECT_EQ(model->subgraphs[0]->tensors[3]->quantization->zero_point[0], -5); +} + } // namespace } // namespace optimize } // namespace tflite From 8e9b8a438eb409a05e594f5e6020f113c299d09b Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Mon, 23 Mar 2020 09:53:33 -0700 Subject: [PATCH 415/492] Improved convolution performance in Metal backend. Improved performance on small sizes. Better convolution selection for A11. 
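The per-device dispatch in api.cc (SelectConvolution picking between Convolution, Convolution1x1, ConvolutionPrecise and ConvolutionPrecise1x1PowerVR) collapses into a single entry point; roughly, the CONVOLUTION_2D case in the diff below reduces to:

    const auto dst_shape = graph.FindOutputs(node_id)[0]->tensor.shape;
    *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
                                attr, options);

ConvolutionGeneric derives ConvParams (block size, work group size and launch order, src_depth_loop_size, weights upload type) from the destination shape and the GPU generation via GetConvParams.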
PiperOrigin-RevId: 302455205 Change-Id: I5e01385da9354eea28e325b7c19201b6db169318 --- tensorflow/lite/delegates/gpu/metal/api.cc | 45 +- .../lite/delegates/gpu/metal/kernels/BUILD | 1 + .../lite/delegates/gpu/metal/kernels/conv.cc | 1750 +++++++---------- .../lite/delegates/gpu/metal/kernels/conv.h | 59 +- 4 files changed, 735 insertions(+), 1120 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index 5eb7d284ad1..4c0af17090e 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -51,39 +51,6 @@ namespace tflite { namespace gpu { namespace metal { namespace { - -std::vector SelectConvolution( - const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) { - // Special precise version, in case we cover dst_shape poorly with standard - // work group size. - auto gpu_type = GetGpuType(); - bool a11_12 = gpu_type == GpuType::kA11 || gpu_type == GpuType::kA12; - const auto dst_shape = graph.FindOutputs(id)[0]->tensor.shape; - if (GetThreadsRatioUsualToPreciseConvolution(dst_shape) >= 1.2f) { - // Special version for PowerVR >= IPhone6S/SE - // Metal has bad driver for PowerVR in IPhone6, so for Iphone6 we should use - // default kernel with shared memory. - if ((gpu_type == GpuType::kA9 || gpu_type == GpuType::kA10) && - CheckConvolutionPrecise1x1Support(attr)) { - return ConvolutionPrecise1x1PowerVR(id, input_id, output_id, attr, - options); - } - if (a11_12 && GetThreadsRatioUsualToPreciseConvolution(dst_shape) >= 1.2f) { - return ConvolutionPrecise(id, input_id, output_id, attr, options); - } - } - if (a11_12) { - if (CheckConvolution1x1Support(attr)) { - return Convolution1x1(id, input_id, output_id, attr, options); - } else { - return ConvolutionGeneric(id, input_id, output_id, attr, options); - } - } else { - return Convolution(id, input_id, output_id, attr, options); - } -} - std::vector SelectDepthWiseConv( int id, ValueId input_id, ValueId output_id, const DepthwiseConvolution2DAttributes& attr, @@ -182,12 +149,14 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, input_shapes); break; } - case OperationType::CONVOLUTION_2D: - *tasks = SelectConvolution( - graph, node_id, inputs[0], outputs[0], - absl::any_cast(node->operation.attributes), - options); + case OperationType::CONVOLUTION_2D: { + const auto dst_shape = graph.FindOutputs(node_id)[0]->tensor.shape; + auto attr = + absl::any_cast(node->operation.attributes); + *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape, + attr, options); break; + } case OperationType::CONVOLUTION_TRANSPOSED: *tasks = SelectConvolutionTransposed( node_id, inputs[0], outputs[0], diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index 70d882bb05b..f22fe642ca3 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -127,6 +127,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:environment", "//tensorflow/lite/delegates/gpu/metal:runtime_options", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc index 
60ac73abfaa..73f152412a9 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc @@ -30,519 +30,181 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { namespace metal { + +enum class WeightsUploadType { + LOCAL_MEM_BY_THREADS, + GLOBAL_MEM, + CONSTANT_MEM, +}; + +struct ConvParams { + int3 block_size; + int3 work_group_size; + int3 work_group_launch_order; + int src_depth_loop_size; + bool need_src_loop = true; + bool need_dst_loop = true; + bool linear_wh; + bool linear_whs; + WeightsUploadType weights_upload_type; + bool x_kernel_is_1; + bool y_kernel_is_1; +}; + namespace { int GetNumOutputSlices(int dst_channels) { const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); - if (dst_depth % 4 == 0) { + if (dst_depth % 4 == 0 || dst_depth >= 16) { return 4; - } else if (dst_depth % 2 == 0) { + } else if (dst_depth % 2 == 0 || dst_depth >= 4) { return 2; } else { return 1; } } -int GetSrcBatchSize(int dst_channels) { - const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); - if (dst_depth % 4 == 0) { - return 2; - } else if (dst_depth % 2 == 0) { - return 4; - } else { - return 8; - } -} - -std::string GetValuesDeclarationPart(int num_output_slices, bool is_1x1) { - std::string code; - for (int d = 0; d < num_output_slices; ++d) { - code += absl::Substitute(R"( - float4 sum$0 = float4(0.0f, 0.0f, 0.0f, 0.0f); - )", - d); - } - if (is_1x1) { - code += absl::Substitute(R"( - threadgroup FLT4 temp[32]; - device FLT4* f_offseted = weights + (gid.z + params.z_offset.x) * $0 * src_offset; - )", - num_output_slices * 4); - } else { - code += absl::Substitute(R"( - threadgroup FLT4 temp[32]; - device FLT4* f_offseted = weights + (gid.z + params.z_offset.x) * $0 * src_offset * - kernel_y * kernel_x; - )", - num_output_slices * 4); - } - return code; -} - -std::string GetLocalMemoryUploadPart() { - std::string code = R"( - BARRIER(mem_flags::mem_none); - temp[tid] = f_offseted[tid]; - f_offseted += 32; - BARRIER(mem_flags::mem_threadgroup); - )"; - return code; -} - -std::string GetSummationPart(int num_output_slices, int index) { - std::string code = R"( - { - const FLT4 src = src_buffer[src_address]; - src_address += params.dilation_layer_offsets.z; - )"; - for (int d = 0; d < num_output_slices; ++d) { - code += absl::Substitute(R"( - sum$6.x += dot(temp[$0 * $1 + $2], src) * multiplier; - sum$6.y += dot(temp[$0 * $1 + $3], src) * multiplier; - sum$6.z += dot(temp[$0 * $1 + $4], src) * multiplier; - sum$6.w += dot(temp[$0 * $1 + $5], src) * multiplier; - )", - index, num_output_slices * 4, d * 4 + 0, d * 4 + 1, - d * 4 + 2, d * 4 + 3, d); - } - code += "}"; - return code; -} - -std::string GetBiasReadingPart(int num_output_slices) { - std::string code = absl::Substitute(R"( - { - gid.z = (gid.z + params.z_offset.x) * $0; - BARRIER(mem_flags::mem_none); - if (tid < $0) { - temp[tid] = biases[gid.z + tid]; - } - BARRIER(mem_flags::mem_threadgroup); - if (outside) { - return; - } - })", - num_output_slices); - return code; -} - -std::string GetWritingPart(int num_output_slices) { - std::string code; - for (int d = 0; d < num_output_slices; ++d) { - code += absl::Substitute(R"( - { - int dst_address = 
int(gid.y) * params.size.z + int(gid.x); - FLT4 value = FLT4(sum$0) + temp[$0]; - const int linear_index = gid.z * params.dilation_layer_offsets.w + dst_address; - $$2 - dst_buffer[linear_index + params.z_offset.y] = value; - gid.z += 1; - })", - d); - } - return code; -} - -std::string GetKernelForConv(const Convolution2DAttributes& params) { - const int num_output_slices = GetNumOutputSlices(params.weights.shape.o); - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. - const bool is_1x1 = - params.weights.shape.w == 1 && params.weights.shape.h == 1; - const bool is_strided = params.strides.w > 1 || params.strides.h > 1; - const int src_group_size = GetSrcBatchSize(params.weights.shape.o); - - const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); - const int src_groups = src_depth / src_group_size; - const int src_depth_aligned = AlignByN(src_depth, src_group_size); - const int reminder_src_depth = src_depth - src_groups * src_group_size; - - code = absl::Substitute(R"( - #include - using namespace metal; - constant int src_depth_groups = $0; - constant int src_offset = $1; - constant int kernel_x = $2; - constant int kernel_y = $3; - struct uniforms { - int4 stride_padding; - int4 dilation_layer_offsets; - int4 size; - int4 z_offset; - }; - $$0 - kernel void ComputeFunction( - $$1 - uint tid[[thread_index_in_threadgroup]], - uint3 gid[[thread_position_in_grid]]) - { - const bool outside = static_cast(gid.x) >= params.size.z || - static_cast(gid.y) >= params.size.w; - )", - src_groups, src_depth_aligned, params.weights.shape.w, - params.weights.shape.h); - code += GetValuesDeclarationPart(num_output_slices, is_1x1); - - if (!is_1x1) { - code += R"( - for(int ky = 0; ky < kernel_y; ++ky) { - for(int kx = 0; kx < kernel_x; ++kx) { - int2 coords = int2(gid.xy) * params.stride_padding.xy + int2(kx, ky) * - params.dilation_layer_offsets.xy - params.stride_padding.zw; - const bool el_outside = coords.x < 0 || coords.y < 0 || coords.x >= params.size.x || - coords.y >= params.size.y; - const FLT multiplier = el_outside ? 
0.0f : 1.0f; - )"; - } else { - code += "const FLT multiplier = 1.0f;\n"; - code += "int2 coords = int2(gid.xy)"; - if (is_strided) { - code += " * params.stride_padding.xy"; - } - code += ";\n"; - } - code += R"( - coords = clamp(coords, int2(0, 0), int2(params.size.x - 1, params.size.y - 1)); - int src_address = coords.y * params.size.x + coords.x; - for(int s = 0; s < src_depth_groups; ++s) { - )"; - code += GetLocalMemoryUploadPart(); - for (int sub_s = 0; sub_s < src_group_size; ++sub_s) { - code += GetSummationPart(num_output_slices, sub_s); - } - code += R"( - } - )"; - if (reminder_src_depth != 0) { - code += GetLocalMemoryUploadPart(); - for (int sub_s = 0; sub_s < reminder_src_depth; ++sub_s) { - code += GetSummationPart(num_output_slices, sub_s); - } - } - if (!is_1x1) { - code += R"( - } - } - )"; - } - code += GetBiasReadingPart(num_output_slices); - code += GetWritingPart(num_output_slices); - code += " }"; - return code; -} - -// Reorder weights to make the weights memory access pattern cache friendly for -// GPU -std::vector ReorderWeightsForConvShared( - const Convolution2DAttributes& params) { - const int dst_batch_size = GetNumOutputSlices(params.weights.shape.o) * 4; - const int src_batch_size = GetSrcBatchSize(params.weights.shape.o); - BHWC input_dimensions{params.weights.shape.o, params.weights.shape.h, - params.weights.shape.w, params.weights.shape.i}; - const int gpu_simd_size = dst_batch_size * src_batch_size; - const int weights_width = AlignByN(input_dimensions.c, gpu_simd_size); - const int weights_height = AlignByN(input_dimensions.b, dst_batch_size); - const int weights_channels = params.weights.shape.w * params.weights.shape.h; - const int weights_aligned_size = - weights_width * weights_height * weights_channels; - std::vector weights_reordered(weights_aligned_size); - float* destination = weights_reordered.data(); - const int dst_groups = - IntegralDivideRoundUp(input_dimensions.b, dst_batch_size); - const int src_sub_groups = - IntegralDivideRoundUp(input_dimensions.c, 4 * src_batch_size); - for (int group = 0; group < dst_groups; ++group) { - for (int y = 0; y < params.weights.shape.h; ++y) { - for (int x = 0; x < params.weights.shape.w; ++x) { - for (int sub_group = 0; sub_group < src_sub_groups; ++sub_group) { - for (int s = 0; s < src_batch_size; ++s) { - for (int d = 0; d < dst_batch_size; ++d) { - int output_index = group * dst_batch_size + d; - for (int i = 0; i < 4; ++i) { - int input_index = (sub_group * src_batch_size + s) * 4 + i; - if (input_index >= input_dimensions.c || - output_index >= input_dimensions.b) { - // Padding with zero - *destination++ = 0.0f; - } else { - int linear_index = - input_index + - input_dimensions.c * - (x + input_dimensions.w * - (y + input_dimensions.h * output_index)); - *destination++ = params.weights.data[linear_index]; - } - } - } - } - } - } - } - } - return weights_reordered; -} - -std::vector GetUniformBufferForConvShared( - const BHWC& input_dimensions, const BHWC& output_dimensions, - const Convolution2DAttributes& params) { - std::vector uniform_params = { - params.strides.w, - params.strides.h, - params.padding.prepended.w, - params.padding.prepended.h, - params.dilations.w, - params.dilations.h, - input_dimensions.w * input_dimensions.h, - output_dimensions.w * output_dimensions.h, - input_dimensions.w, - input_dimensions.h, - output_dimensions.w, - output_dimensions.h, - // TODO(chirkov): use z_offset for concat table optimization - /*z_offset.x=*/0, - /*z_offset.y=*/0, - /*z_offset.z=*/0, - 
/*z_offset.w=*/0, - }; - return GetByteBuffer(uniform_params); -} - -std::string GetKernelForConv1x1(const Convolution2DAttributes& params, - int z_out) { - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. - std::string channels[4] = {"x", "y", "z", "w"}; - code += R"( -#include -using namespace metal; - -struct uniforms { - int4 src_size; - int4 dst_size; - int4 stride_padding; - int4 kernel_dilation; - uint4 work_group_size; +struct GlobalIdsParams { + std::vector global_ids; + std::vector group_ids; + std::vector local_sizes; + std::vector local_ids; + int3 block_size; + int3 launch_order; + bool linear_wh; + bool linear_whs; + std::string task_size_w; // must be filled if linear_wh or linear_whs enabled + std::string task_size_wh; // must be filled if linear_whs enabled }; -$0 -kernel void ComputeFunction( - $1 - uint3 group_id[[threadgroup_position_in_grid]], - uint3 tid3d[[thread_position_in_threadgroup]]) -{ - int gid_x = group_id.y * params.work_group_size.x + tid3d.x; - int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) << 1u; - )"; - code += " int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " + - std::to_string(z_out) + "u;\n"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; +std::string GlobalIdsGen(const GlobalIdsParams& params) { + std::string c; + int3 launch_remap; + launch_remap[params.launch_order.x] = 0; + launch_remap[params.launch_order.y] = 1; + launch_remap[params.launch_order.z] = 2; + if (params.linear_whs) { + c += " int linear_whs = " + params.global_ids[0] + ";\n"; + c += " int Z = (linear_whs / " + params.task_size_wh + ") * " + + std::to_string(params.block_size.z) + ";\n"; + c += " int linear_wh = linear_whs % " + params.task_size_wh + ";\n"; + c += " int Y = (linear_wh / " + params.task_size_w + ") * " + + std::to_string(params.block_size.y) + ";\n"; + c += " int X = (linear_wh % " + params.task_size_w + ") * " + + std::to_string(params.block_size.x) + ";\n"; + } else if (params.linear_wh) { + if (params.launch_order.x == 0) { + c += " int linear_wh = " + params.global_ids[0] + ";\n"; + } else { + c += " int linear_wh = " + params.group_ids[launch_remap.x] + " * " + + params.local_sizes[0] + " + " + params.local_ids[0] + ";\n"; + } + c += " int Y = (linear_wh / " + params.task_size_w + ") * " + + std::to_string(params.block_size.y) + ";\n"; + c += " int X = (linear_wh % " + params.task_size_w + ") * " + + std::to_string(params.block_size.x) + ";\n"; + if (params.launch_order.y == 1) { + c += " int Z = " + params.global_ids[1] + " * " + + std::to_string(params.block_size.z) + ";\n"; + } else { + c += " int Z = (" + params.group_ids[launch_remap.y] + " * " + + params.local_sizes[1] + " + " + params.local_ids[1] + ") * " + + std::to_string(params.block_size.z) + ";\n"; + } + } else { + if (params.launch_order.x == 0) { + c += " int X = " + params.global_ids[0] + " * " + + std::to_string(params.block_size.x) + ";\n"; + } else { + c += " int X = (" + params.group_ids[launch_remap.x] + " * " + + params.local_sizes[0] + " + " + params.local_ids[0] + ") * " + + std::to_string(params.block_size.x) + ";\n"; + } + if (params.launch_order.y == 1) { + c += " int Y = " + params.global_ids[1] + " * " + + std::to_string(params.block_size.y) + ";\n"; + } else { + c += " int Y = (" + params.group_ids[launch_remap.y] + " * " + + 
params.local_sizes[1] + " + " + params.local_ids[1] + ") * " + + std::to_string(params.block_size.y) + ";\n"; + } + if (params.launch_order.z == 2) { + c += " int Z = " + params.global_ids[2] + " * " + + std::to_string(params.block_size.z) + ";\n"; + } else { + c += " int Z = (" + params.group_ids[launch_remap.z] + " * " + + params.local_sizes[2] + " + " + params.local_ids[2] + ") * " + + std::to_string(params.block_size.z) + ";\n"; + } } - code += R"( - device FLT4* tmp = filters + gid_z * 4 * params.src_size.w; - - int y0 = clamp(gid_y, 0, params.src_size.y - 1); - int y1 = clamp(gid_y + 1, 0, params.src_size.y - 1); - int x0 = clamp(gid_x, 0, params.src_size.x - 1); - - int s = 0; - - device FLT4* src_loc_0 = src_buffer + y0 * params.src_size.x + x0; - device FLT4* src_loc_1 = src_buffer + y1 * params.src_size.x + x0; - do { - FLT4 src_0 = *src_loc_0; - FLT4 src_1 = *src_loc_1; - src_loc_0 += params.src_size.z; - src_loc_1 += params.src_size.z; - )"; - for (int i = 0; i < z_out * 4; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + - " += dot(tmp[" + s_i + "], src_0);\n"; - code += " l" + std::to_string(i / 4) + "." + channels[i % 4] + - " += dot(tmp[" + s_i + "], src_1);\n"; - } - - code += " tmp += " + std::to_string(z_out * 4) + ";\n"; - code += R"( - s += 1; - } while (s < params.src_size.w); - const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x; - const int offset_1 = offset_0 + params.dst_size.x; - bool y0_in = gid_y < params.dst_size.y; - bool y1_in = gid_y + 1 < params.dst_size.y; - - device FLT4* bias_loc = biases + gid_z; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - } - code += R"( - if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) { - return; - } - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; - code += " if (y0_in) {\n"; - code += " FLT4 value = FLT4(r" + s_i + ");\n"; - code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " if (y1_in) {\n"; - code += " FLT4 value = FLT4(l" + s_i + ");\n"; - code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " }\n"; - } - code += " }\n"; - return code; + return c; } -std::string GetKernelForConvGeneric(const Convolution2DAttributes& params, - int z_out) { - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. 
- std::string channels[4] = {"x", "y", "z", "w"}; - code += R"( -#include -using namespace metal; - -struct uniforms { - int4 src_size; - int4 dst_size; - int4 stride_padding; - int4 kernel_dilation; - uint4 work_group_size; -}; -$0 - -kernel void ComputeFunction( - $1 - uint3 group_id[[threadgroup_position_in_grid]], - uint3 tid3d[[thread_position_in_threadgroup]]) -{ - int gid_x = group_id.y * params.work_group_size.x + tid3d.x; - int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) * 2; - )"; - code += " int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " + - std::to_string(z_out) + "u;\n"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; +std::string GenerateUploadByThreads(const std::string& local_ptr_name, + const std::string& global_ptr_name, + const std::string& global_offset_name, + const std::string& lid_name, + int total_work_items, + int elements_to_upload) { + std::string c; + std::string offset = + global_offset_name.empty() ? "" : global_offset_name + " + "; + const int groups = elements_to_upload / total_work_items; + const int reminder = elements_to_upload % total_work_items; + for (int i = 0; i < groups; ++i) { + c += " " + local_ptr_name + "[" + lid_name + " + " + + std::to_string(total_work_items * i) + "] = " + global_ptr_name + "[" + + offset + lid_name + " + " + std::to_string(total_work_items * i) + + "];\n"; } - code += R"( - device FLT4* tmp = filters + gid_z * 4 * params.src_size.w * params.kernel_dilation.x * params.kernel_dilation.y; - - int y0 = gid_y * params.stride_padding.y + params.stride_padding.w; - int y1 = (gid_y + 1) * params.stride_padding.y + params.stride_padding.w; - int x0 = gid_x * params.stride_padding.x + params.stride_padding.z; - - int y = 0; - do { - int coord_y0 = y * params.kernel_dilation.w + y0; - int coord_y1 = y * params.kernel_dilation.w + y1; - bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y; - bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y; - coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1); - coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1); - int x = 0; - do { - int coord_x0 = x * params.kernel_dilation.z + x0; - bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x; - coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1); - FLT m0 = !(y0_out || x0_out); - FLT m1 = !(y1_out || x0_out); - int s = 0; - device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0; - device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x0; - do { - FLT4 src_0 = *src_loc_0 * m0; - FLT4 src_1 = *src_loc_1 * m1; - src_loc_0 += params.src_size.z; - src_loc_1 += params.src_size.z; - )"; - for (int i = 0; i < z_out * 4; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + - " += dot(tmp[" + s_i + "], src_0);\n"; - code += " l" + std::to_string(i / 4) + "." 
+ channels[i % 4] + - " += dot(tmp[" + s_i + "], src_1);\n"; + if (reminder != 0) { + c += " if (" + lid_name + " < " + std::to_string(reminder) + ") {\n"; + c += " " + local_ptr_name + "[" + lid_name + " + " + + std::to_string(total_work_items * groups) + "] = " + global_ptr_name + + "[" + offset + lid_name + " + " + + std::to_string(total_work_items * groups) + "];\n"; + c += " }\n"; } - - code += " tmp += " + std::to_string(z_out * 4) + ";\n"; - code += R"( - s += 1; - } while (s < params.src_size.w); - x++; - } while (x < params.kernel_dilation.x); - y++; - } while (y < params.kernel_dilation.y); - const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x; - const int offset_1 = offset_0 + params.dst_size.x; - bool p0_in = gid_x < params.dst_size.x && gid_y < params.dst_size.y; - bool p1_in = gid_x < params.dst_size.x && gid_y + 1 < params.dst_size.y; - - device FLT4* bias_loc = biases + gid_z; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - } - code += R"( - if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) { - return; - } - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; - code += " if (p0_in) {\n"; - code += " FLT4 value = FLT4(r" + s_i + ");\n"; - code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " if (p1_in) {\n"; - code += " FLT4 value = FLT4(l" + s_i + ");\n"; - code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " }\n"; - } - code += " }\n"; - return code; + return c; } -std::string GetKernelForConvPrecise(int z_out) { +std::string GenerateConvolution(const ConvParams& params) { + GlobalIdsParams ids_params; + ids_params.group_ids = {"group_id.x", "group_id.y", "group_id.z"}; + ids_params.global_ids = {"ugid.x", "ugid.y", "ugid.z"}; + ids_params.local_ids = {"tid3d.x", "tid3d.y", "tid3d.z"}; + ids_params.local_sizes = {"params.work_group_size.x", + "params.work_group_size.y", + "params.work_group_size.z"}; + ids_params.linear_wh = params.linear_wh; + ids_params.task_size_w = "params.task_sizes.x"; + ids_params.task_size_wh = "params.task_sizes.y"; + ids_params.linear_whs = params.linear_whs; + ids_params.block_size = params.block_size; + ids_params.launch_order = params.work_group_launch_order; + + std::string addr_space = + params.weights_upload_type == WeightsUploadType::CONSTANT_MEM ? "constant" + : "device"; + const bool use_local_mem = + params.weights_upload_type == WeightsUploadType::LOCAL_MEM_BY_THREADS; + const int local_mem_size = + params.block_size.z * 4 * params.src_depth_loop_size; + + const bool use_filters_constants = + !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 && + params.y_kernel_is_1; + std::string channels[4] = {"x", "y", "z", "w"}; - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. - code += R"( + std::string c; + c.reserve(16 * 1024); // Reserve large enough buffer. 
+ c += R"( #include using namespace metal; @@ -551,209 +213,298 @@ struct uniforms { int4 dst_size; int4 stride_padding; int4 kernel_dilation; - int4 slices; + int4 task_sizes; + uint4 work_group_size; }; $0 kernel void ComputeFunction( $1 + uint tid[[thread_index_in_threadgroup]], + uint3 group_id[[threadgroup_position_in_grid]], + uint3 tid3d[[thread_position_in_threadgroup]], uint3 ugid[[thread_position_in_grid]]) { - int linear_id = ugid.x; - int gid_z = linear_id / params.slices.y; - int linear_xy = (linear_id - gid_z * params.slices.y) << 1; - )"; - code += " gid_z *= " + std::to_string(z_out) + ";\n"; - code += R"( - int gid_y0 = linear_xy / params.slices.x; - int gid_x0 = linear_xy - gid_y0 * params.slices.x; - linear_xy += 1; - int gid_y1 = linear_xy / params.slices.x; - int gid_x1 = linear_xy - gid_y1 * params.slices.x; - - if (gid_z >= params.dst_size.w) return; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - } - code += R"( - device FLT4* tmp = filters + gid_z * 4 * params.src_size.w * - params.kernel_dilation.x * params.kernel_dilation.y; - - int y0 = gid_y0 * params.stride_padding.y + params.stride_padding.w; - int y1 = gid_y1 * params.stride_padding.y + params.stride_padding.w; - int x0 = gid_x0 * params.stride_padding.x + params.stride_padding.z; - int x1 = gid_x1 * params.stride_padding.x + params.stride_padding.z; )"; - code += R"( - int y = 0; - do { - int coord_y0 = y * params.kernel_dilation.w + y0; - int coord_y1 = y * params.kernel_dilation.w + y1; - bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y; - bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y; - coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1); - coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1); - int x = 0; - do { - int coord_x0 = x * params.kernel_dilation.z + x0; - int coord_x1 = x * params.kernel_dilation.z + x1; - bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x; - bool x1_out = coord_x1 < 0 || coord_x1 >= params.src_size.x; - coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1); - coord_x1 = clamp(coord_x1, 0, params.src_size.x - 1); - FLT m0 = !(y0_out || x0_out); - FLT m1 = !(y1_out || x1_out); - device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0; - device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x1; - int s = 0; - do { - FLT4 src_0 = *src_loc_0 * m0; - FLT4 src_1 = *src_loc_1 * m1; - src_loc_0 += params.src_size.z; - src_loc_1 += params.src_size.z; -)"; - for (int i = 0; i < z_out * 4; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + - " += dot(tmp[" + s_i + "], src_0);\n"; - code += " l" + std::to_string(i / 4) + "." 
+ channels[i % 4] + - " += dot(tmp[" + s_i + "], src_1);\n"; + c += GlobalIdsGen(ids_params); + c += " if (Z >= params.dst_size.w) return;\n"; + if (!use_local_mem && !params.linear_whs) { + c += " if (X >= params.dst_size.x || Y >= params.dst_size.y) return;\n"; + } + for (int z = 0; z < params.block_size.z; ++z) { + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_i = + std::to_string(z) + std::to_string(y) + std::to_string(x); + c += + " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } + } + } + auto for_every_yx = + [&](std::function + lambda) { + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + const std::string s_yx = s_y + s_x; + c += lambda(s_yx, s_x, s_y, x, y) + "\n"; + } + } + }; + if (!use_filters_constants) { + std::string kern_x = + params.x_kernel_is_1 ? "" : " * params.kernel_dilation.x"; + std::string kern_y = + params.y_kernel_is_1 ? "" : " * params.kernel_dilation.y"; + std::string dst_offset = + params.need_dst_loop ? " + Z * 4 * params.src_size.w" : ""; + if (!params.need_dst_loop) { + c += " " + addr_space + " FLT4* tmp = filters;\n"; + } else { + c += " " + addr_space + + " FLT4* tmp = filters + Z * 4 * params.src_size.w" + kern_x + + kern_y + ";\n"; + } + } + if (!params.x_kernel_is_1) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + c += " int x" + s_x + " = (X + " + s_x + + ") * params.stride_padding.x + params.stride_padding.z;\n"; + } + } + if (!params.y_kernel_is_1) { + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + c += " int y" + s_y + " = (Y + " + s_y + + ") * params.stride_padding.y + params.stride_padding.w;\n"; + } + } + if (use_local_mem) { + c += " threadgroup FLT4 weights_cache[" + std::to_string(local_mem_size) + + "];\n"; + } + if (!params.y_kernel_is_1) { + c += " int y = 0;\n"; + c += " do {\n"; + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + c += " int c_y" + s_y + " = y * params.kernel_dilation.w + y" + s_y + + ";\n"; + c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y + + " >= params.src_size.y;\n"; + c += " c_y" + s_y + " = clamp(c_y" + s_y + + ", 0, params.src_size.y - 1);\n"; + } + } else { + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + c += " int c_y" + s_y + " = clamp(Y + " + s_y + + ", 0, params.src_size.y - 1);\n"; + } + } + if (!params.x_kernel_is_1) { + c += " int x = 0;\n"; + c += " do {\n"; + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + c += " int c_x" + s_x + " = x * params.kernel_dilation.z + x" + s_x + + ";\n"; + c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x + + " >= params.src_size.x;\n"; + c += " c_x" + s_x + " = clamp(c_x" + s_x + + ", 0, params.src_size.x - 1);\n"; + } + } else { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + c += " int c_x" + s_x + " = clamp(X + " + s_x + + ", 0, params.src_size.x - 1);\n"; + } + } + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + const std::string s_yx = s_y + s_x; + if (!params.y_kernel_is_1 && 
!params.x_kernel_is_1) { + c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x + "_out);\n"; + } else if (!params.y_kernel_is_1) { + c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n"; + } else if (!params.x_kernel_is_1) { + c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n"; + } + } + } + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + const std::string s_yx = s_y + s_x; + c += " device FLT4* src_loc_" + s_yx + " = src_buffer + c_y" + s_y + + " * params.src_size.x + c_x" + s_x + ";\n"; + } + } + c += " int s = 0;\n"; + if (params.need_src_loop) { + c += " do {\n"; + } + if (use_local_mem) { + const int total_work_items = params.work_group_size.x * + params.work_group_size.y * + params.work_group_size.z; + c += " BARRIER(mem_flags::mem_none);\n"; + c += GenerateUploadByThreads("weights_cache", "tmp", + /*global_offset_name*/ "", "tid", + total_work_items, local_mem_size); + c += " BARRIER(mem_flags::mem_threadgroup);\n"; + } + auto declare_src = [&]() { + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_yx = std::to_string(y) + std::to_string(x); + c += " FLT4 src" + s_yx + ";\n"; + } + } + }; + auto read_src = [&]() { + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_yx = std::to_string(y) + std::to_string(x); + if (!params.y_kernel_is_1 || !params.x_kernel_is_1) { + c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx + ";\n"; + } else { + c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n"; + } + } + } + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_yx = std::to_string(y) + std::to_string(x); + c += " src_loc_" + s_yx + " += params.src_size.z;\n"; + } + } + }; + auto conv_core = [&](int offset) { + std::string name = use_local_mem ? "weights_cache" : "tmp"; + if (use_filters_constants) { + name = "filters"; + } + for (int z = 0; z < params.block_size.z; ++z) { + for (int ch = 0; ch < 4; ++ch) { + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + std::string s_id = std::to_string(y) + std::to_string(x); + std::string r_id = + std::to_string(z) + std::to_string(y) + std::to_string(x); + c += " r" + r_id + "." 
+ channels[ch] + " += dot(" + name + "[" + + std::to_string(z * 4 + ch + offset) + "], src" + s_id + ");\n"; + } + } + } + } + }; + declare_src(); + read_src(); + c += " s += 1;\n"; + conv_core(0); + for (int i = 1; i < params.src_depth_loop_size; ++i) { + read_src(); + conv_core(i * params.block_size.z * 4); + c += " s += 1;\n"; + } + if (!use_filters_constants) { + c += " tmp += " + + std::to_string(params.block_size.z * 4 * params.src_depth_loop_size) + + ";\n"; + } + if (params.need_src_loop) { + c += " } while (s < params.src_size.w);\n"; + } + if (!params.x_kernel_is_1) { + c += " x++;\n"; + c += " } while (x < params.kernel_dilation.x);\n"; + } + if (!params.y_kernel_is_1) { + c += " y++;\n"; + c += " } while (y < params.kernel_dilation.y);\n"; } - code += " tmp += " + std::to_string(z_out * 4) + ";\n"; - code += R"( - s += 1; - } while (s < params.src_size.w); - x++; - } while (x < params.kernel_dilation.x); - y++; - } while (y < params.kernel_dilation.y); - const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0; - const int offset_1 = gid_z * params.dst_size.z + gid_y1 * params.dst_size.x + gid_x1; - bool p0_in = gid_x0 < params.dst_size.x && gid_y0 < params.dst_size.y; - bool p1_in = gid_x1 < params.dst_size.x && gid_y1 < params.dst_size.y; + if (use_local_mem && !params.linear_whs) { + c += " if (X >= params.dst_size.x || Y >= params.dst_size.y) return;\n"; + } - device FLT4* bias_loc = biases + gid_z; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; + for_every_yx([](const std::string& s_yx, const std::string& s_x, + const std::string& s_y, int x, int y) { + return " const int offset_" + s_yx + " = Z * params.dst_size.z + (Y + " + + s_y + ") * params.dst_size.x + X + " + s_x + ";"; + }); + + std::string bias_name = "biases"; + if (params.need_dst_loop) { + c += " device FLT4* bias_loc = biases + Z;\n"; + bias_name = "bias_loc"; } - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; - code += " if (p0_in) {\n"; - code += " FLT4 value = FLT4(r" + s_i + ");\n"; - code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " if (p1_in) {\n"; - code += " FLT4 value = FLT4(l" + s_i + ");\n"; - code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x1, gid_y1, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " }\n"; + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + for (int z = 0; z < params.block_size.z; ++z) { + std::string r_id = + std::to_string(z) + std::to_string(y) + std::to_string(x); + c += " r" + r_id + " += TO_ACCUM4_TYPE(" + bias_name + "[" + + std::to_string(z) + "]);\n"; + } + } } - code += " }\n"; - return code; + for (int z = 0; z < params.block_size.z; ++z) { + const std::string s_z = std::to_string(z); + c += " if (Z + " + s_z + " < params.dst_size.w) {\n"; + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + for (int x = 0; x < params.block_size.x; ++x) { + const 
std::string s_x = std::to_string(x); + const std::string s_yx = s_y + s_x; + const std::string s_zyx = s_z + s_yx; + bool need_check_x = x >= 1; + bool need_check_y = y >= 1; + std::string check; + if (need_check_x) { + check += "(X + " + s_x + ") < params.dst_size.x"; + } + if (need_check_y) { + check += check.empty() ? "" : " && "; + check += "(Y + " + s_y + ") < params.dst_size.y"; + } + if (!check.empty()) { + c += " if (" + check + ") {\n"; + } else { + c += " {\n"; + } + c += " FLT4 value = FLT4(r" + s_zyx + ");\n"; + c += " int linear_index = offset_" + s_yx + + " + params.dst_size.z * " + s_z + ";\n"; + c += " uint3 gid = uint3(X + " + s_x + ", Y + " + s_y + ", Z + " + + s_z + ");\n"; + c += " $2\n"; + c += " dst_buffer[linear_index] = value;\n"; + c += " }\n"; + } + } + c += " }\n"; + } + c += "}\n"; + return c; } -std::string GetKernelForConvPrecise1x1PowerVR(int z_out) { - std::string channels[4] = {"x", "y", "z", "w"}; - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. - code += R"( -#include -using namespace metal; - -struct uniforms { - int4 src_size; - int4 dst_size; - int4 slices; - int4 dummy0; -}; -$0 - -kernel void ComputeFunction( - $1 - uint3 ugid[[thread_position_in_grid]]) -{ - int linear_id = ugid.x; - int gid_z = linear_id / params.slices.y; - int linear_xy = linear_id - gid_z * params.slices.y; -)"; - code += " gid_z *= " + std::to_string(z_out) + ";\n"; - code += R"( - int gid_y0 = linear_xy / params.slices.x; - int gid_x0 = linear_xy - gid_y0 * params.slices.x; - - if (gid_z >= params.dst_size.w) return; -)"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " float4 r" + s_i + " = float4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - } - code += R"( - device FLT4* tmp = filters + gid_z * 4 * params.src_size.w; - - device FLT4* src_loc_0 = src_buffer + gid_y0 * params.src_size.x + gid_x0; - int s = 0; - do { - FLT4 src_0 = *src_loc_0; - src_loc_0 += params.src_size.z; -)"; - for (int i = 0; i < z_out * 4; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + std::to_string(i / 4) + "." 
+ channels[i % 4] + - " += dot(tmp[" + s_i + "], src_0);\n"; - } - - code += " tmp += " + std::to_string(z_out * 4) + ";\n"; - code += R"( - s += 1; - } while (s < params.src_size.w); - const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0; - - device FLT4* bias_loc = biases + gid_z; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + s_i + " += float4(bias_loc[" + s_i + "]);\n"; - } - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; - code += " FLT4 value = FLT4(r" + s_i + ");\n"; - code += - " int linear_index = offset_0 + params.dst_size.z * " + s_i + ";\n"; - code += " uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - } - code += " }\n"; - return code; -} - -// Reorder weights to make the weights memory access pattern cache friendly for -// Convolution1x1/ConvolutionGeneric std::vector ReorderWeightsForConv(const Convolution2DAttributes& params, int z_out) { const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); - std::vector weights_reordered(params.weights.shape.w * - params.weights.shape.h * dst_depth * 4 * - src_depth * 4); + std::vector weights_reordered( + params.weights.shape.w * params.weights.shape.h * + AlignByN(dst_depth, z_out) * 4 * src_depth * 4); int counter = 0; for (int d = 0; d < IntegralDivideRoundUp(dst_depth, z_out); ++d) { for (int y = 0; y < params.weights.shape.h; ++y) { @@ -768,7 +519,7 @@ std::vector ReorderWeightsForConv(const Convolution2DAttributes& params, dst_ch >= params.weights.shape.o) { weights_reordered[counter++] = 0.0f; } else { - const int f_index = + const size_t f_index = params.weights.shape.LinearIndex({dst_ch, y, x, src_ch}); weights_reordered[counter++] = params.weights.data[f_index]; } @@ -782,13 +533,12 @@ std::vector ReorderWeightsForConv(const Convolution2DAttributes& params, return weights_reordered; } -uint3 GetWorkGroupForConv() { return {8, 4, 1}; } -uint3 GetWorkGroupForConvPrecise() { return {32, 1, 1}; } - -std::vector GetUniformBufferForConv( - const BHWC& src_size, const BHWC& dst_size, - const Convolution2DAttributes& params) { - const int3 group_size = GetWorkGroupForConv(); +std::vector GetUniformBuffer(const BHWC& src_size, + const BHWC& dst_size, + const Convolution2DAttributes& attr, + const ConvParams& params) { + const int grid_x = IntegralDivideRoundUp(dst_size.w, params.block_size.x); + const int grid_y = IntegralDivideRoundUp(dst_size.h, params.block_size.y); std::vector uniform_params = { src_size.w, src_size.h, @@ -798,240 +548,280 @@ std::vector GetUniformBufferForConv( dst_size.h, dst_size.w * dst_size.h, IntegralDivideRoundUp(dst_size.c, 4), - params.strides.w, - params.strides.h, - -params.padding.prepended.w, - -params.padding.prepended.h, - params.weights.shape.w, - params.weights.shape.h, - params.dilations.w, - params.dilations.h, - group_size.x, - group_size.y, - group_size.z, - 1u, // dummy, for alignment + attr.strides.w, + attr.strides.h, + -attr.padding.prepended.w, + -attr.padding.prepended.h, + attr.weights.shape.w, + attr.weights.shape.h, + attr.dilations.w, + attr.dilations.h, + grid_x, + grid_x * grid_y, + 0, // dummy, for alignment + 0, // dummy, for alignment + params.work_group_size.x, + params.work_group_size.y, + 
params.work_group_size.z, + 0, // dummy, for alignment }; return GetByteBuffer(uniform_params); } -std::vector GetUniformBufferForConvPrecise( - const BHWC& src_size, const BHWC& dst_size, - const Convolution2DAttributes& params) { - std::vector uniform_params = { - src_size.w, - src_size.h, - src_size.w * src_size.h, - IntegralDivideRoundUp(src_size.c, 4), - dst_size.w, - dst_size.h, - dst_size.w * dst_size.h, - IntegralDivideRoundUp(dst_size.c, 4), - params.strides.w, - params.strides.h, - -params.padding.prepended.w, - -params.padding.prepended.h, - params.weights.shape.w, - params.weights.shape.h, - params.dilations.w, - params.dilations.h, - dst_size.w, - IntegralDivideRoundUp(dst_size.w * dst_size.h, 2), - 0u, // dummy, for alignment - 0u, // dummy, for alignment - }; - return GetByteBuffer(uniform_params); +int GetGroupsCount(const BHWC& dst_shape, const int3& wg_size, + const int3& block_size) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + + int grid_x = IntegralDivideRoundUp(dst_shape.w, block_size.x); + int grid_y = IntegralDivideRoundUp(dst_shape.h, block_size.y); + int grid_z = IntegralDivideRoundUp(dst_slices, block_size.z); + + return IntegralDivideRoundUp(grid_x, wg_size.x) * + IntegralDivideRoundUp(grid_y, wg_size.y) * + IntegralDivideRoundUp(grid_z, wg_size.z); } -std::vector GetUniformBufferForConvPrecise1x1( - const BHWC& src_size, const BHWC& dst_size, - const Convolution2DAttributes& params) { - std::vector uniform_params = { - src_size.w, - src_size.h, - src_size.w * src_size.h, - IntegralDivideRoundUp(src_size.c, 4), - dst_size.w, - dst_size.h, - dst_size.w * dst_size.h, - IntegralDivideRoundUp(dst_size.c, 4), - dst_size.w, - IntegralDivideRoundUp(dst_size.w * dst_size.h, 1), - 0u, // dummy, for alignment - 0u, // dummy, for alignment - 0u, // dummy, for alignment - 0u, // dummy, for alignment - 0u, // dummy, for alignment - 0u, // dummy, for alignment - }; - return GetByteBuffer(uniform_params); +int GetGroupsCountForLinearWH(const BHWC& dst_shape, const int3& wg_size, + const int3& block_size) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + + int grid_x = IntegralDivideRoundUp(dst_shape.w, block_size.x); + int grid_y = IntegralDivideRoundUp(dst_shape.h, block_size.y); + int grid_z = IntegralDivideRoundUp(dst_slices, block_size.z); + + return IntegralDivideRoundUp(grid_x * grid_y, wg_size.x) * + IntegralDivideRoundUp(grid_z, wg_size.y); } -uint3 GetGroupsCountForConv(const uint3& group_size, const BHWC& dst_shape) { - const int dst_depth = IntegralDivideRoundUp(dst_shape.c, 4); - int groups_x = IntegralDivideRoundUp(dst_shape.w, group_size.x); - int groups_y = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_shape.h, 2), - group_size.y); - const int z_out = GetNumOutputSlices(dst_shape.c); - int groups_z = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_depth, z_out), - group_size.z); - return {groups_x, groups_y, groups_z}; +int GetGroupsCountForLinearWHS(const BHWC& dst_shape, const int3& wg_size, + const int3& block_size) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + + int grid_x = IntegralDivideRoundUp(dst_shape.w, block_size.x); + int grid_y = IntegralDivideRoundUp(dst_shape.h, block_size.y); + int grid_z = IntegralDivideRoundUp(dst_slices, block_size.z); + + return IntegralDivideRoundUp(grid_x * grid_y * grid_z, wg_size.x); } -uint3 GetGroupsCountForConvPrecise(const uint3& group_size, - const BHWC& dst_shape, int xy_pixels) { - const int z_out = GetNumOutputSlices(dst_shape.c); - const int 
dst_depth = IntegralDivideRoundUp(dst_shape.c, 4); - int xy_size = IntegralDivideRoundUp(dst_shape.w * dst_shape.h, xy_pixels); - int z_size = IntegralDivideRoundUp(dst_depth, z_out); - int task_size = xy_size * z_size; - return {IntegralDivideRoundUp(task_size, group_size.x), 1, 1}; -} - -int GetConvolutionThreadsCount(const BHWC& dst_shape) { - const uint3 group_size = GetWorkGroupForConv(); - const uint3 groups_count = GetGroupsCountForConv(group_size, dst_shape); - return groups_count.x * groups_count.y * groups_count.z * group_size.x * - group_size.y * group_size.z; -} - -int GetConvolutionPreciseThreadsCount(const BHWC& dst_shape, int xy_pixels) { - const uint3 group_size = GetWorkGroupForConvPrecise(); - const uint3 groups_count = - GetGroupsCountForConvPrecise(group_size, dst_shape, xy_pixels); - return groups_count.x * groups_count.y * groups_count.z * group_size.x * - group_size.y * group_size.z; -} - -bool IsConv1x1(const Convolution2DAttributes& attr) { - return attr.weights.shape.h == 1 && attr.weights.shape.w == 1 && - attr.strides.h == 1 && attr.strides.w == 1 && attr.dilations.h == 1 && - attr.dilations.w == 1 && attr.padding.prepended.h == 0 && - attr.padding.prepended.w == 0 && attr.padding.appended.h == 0 && +bool IsKernelXIs1(const Convolution2DAttributes& attr) { + return attr.weights.shape.w == 1 && attr.strides.w == 1 && + attr.dilations.w == 1 && attr.padding.prepended.w == 0 && attr.padding.appended.w == 0; } +bool IsKernelYIs1(const Convolution2DAttributes& attr) { + return attr.weights.shape.h == 1 && attr.strides.h == 1 && + attr.dilations.h == 1 && attr.padding.prepended.h == 0 && + attr.padding.appended.h == 0; +} + +int GetMaximumPossibleWavesCount(const BHWC& dst_shape, GpuType gpu) { + if (gpu == GpuType::kA7 || gpu == GpuType::kA8) { + return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1}); + } else { + return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, {1, 1, 1}); + } +} + +int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) { + const int max_waves = GetMaximumPossibleWavesCount(dst_shape, gpu); + int base_threshold; + if (gpu == GpuType::kA7 || gpu == GpuType::kA8) { + base_threshold = 32; + } else if (gpu == GpuType::kA11) { + base_threshold = 48; + } else { + base_threshold = 64; + } + if (max_waves >= base_threshold * 4) { + return 8; + } else if (max_waves >= base_threshold * 2) { + return 4; + } else if (max_waves >= base_threshold) { + return 2; + } else { + return 1; + } +} + +ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr, + const BHWC& dst_shape, GpuType gpu) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4); + + ConvParams params; + params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; + params.x_kernel_is_1 = IsKernelXIs1(attr); + params.y_kernel_is_1 = IsKernelYIs1(attr); + params.src_depth_loop_size = 1; + params.block_size = int3(1, 1, 1); + params.linear_wh = false; + params.linear_whs = false; + params.work_group_launch_order = int3(0, 1, 2); + + int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu); + + if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) { + params.block_size.z = 4; + blk_total_size /= 4; + } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) { + params.block_size.z = 2; + blk_total_size /= 2; + } + if (blk_total_size >= 4) { + params.block_size.x = 2; + params.block_size.y = 2; + blk_total_size /= 4; + } else if 
(blk_total_size >= 2) { + if (dst_shape.w % 2 != 0 && dst_shape.h % 2 == 0) { + params.block_size.y = 2; + } else { + params.block_size.x = 2; + } + blk_total_size /= 2; + } + + params.work_group_size = params.block_size.x <= params.block_size.y + ? int3(8, 4, 1) + : int3(4, 8, 1); + + int g1 = GetGroupsCount(dst_shape, params.work_group_size, params.block_size); + int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, params.block_size); + int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, params.block_size); + + if (g2 < g1) { + params.linear_wh = true; + params.work_group_size = int3(32, 1, 1); + params.work_group_launch_order = int3(0, 1, 2); + } + float precise_threshold = 3.1f; + float precise_ratio = static_cast(g2) / static_cast(g3); + if (precise_ratio > precise_threshold) { + params.linear_wh = false; + params.linear_whs = true; + params.work_group_size = int3(32, 1, 1); + params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + } + + if (params.src_depth_loop_size == src_slices) { + params.need_src_loop = false; + } + if (params.block_size.z == dst_slices) { + params.need_dst_loop = false; + } + const bool use_filters_constants = + !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 && + params.y_kernel_is_1; + if (use_filters_constants) { + params.weights_upload_type = WeightsUploadType::CONSTANT_MEM; + } + + return params; +} + +ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr, + const BHWC& dst_shape, GpuType gpu) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4); + int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu); + bool apple_gpu = gpu == GpuType::kA11 || gpu == GpuType::kA12; + int3 block_size = int3(1, 1, 1); + if (blk_total_size >= 2 && apple_gpu) { + if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) { + block_size.x = 2; + } else { + block_size.y = 2; + } + blk_total_size /= 2; + } + if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) { + block_size.z = 4; + blk_total_size /= 4; + } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) { + block_size.z = 2; + blk_total_size /= 2; + } + if (blk_total_size >= 4 && dst_slices == 3) { + block_size.z = 3; + blk_total_size /= 4; + } + + ConvParams params; + params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + params.x_kernel_is_1 = IsKernelXIs1(attr); + params.y_kernel_is_1 = IsKernelYIs1(attr); + params.src_depth_loop_size = 1; + params.block_size = block_size; + params.linear_wh = false; + params.linear_whs = false; + params.work_group_size = int3(8, 4, 1); + params.work_group_launch_order = int3(2, 0, 1); + int g1 = GetGroupsCount(dst_shape, {8, 4, 1}, block_size); + int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, block_size); + int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, block_size); + if (g2 < g1) { + params.linear_wh = true; + params.work_group_size = int3(32, 1, 1); + params.work_group_launch_order = int3(0, 1, 2); + } + float precise_threshold = gpu == GpuType::kA12 ? 
1.0f : 1.04f; + float precise_ratio = static_cast(g2) / static_cast(g3); + if (precise_ratio > precise_threshold) { + params.linear_wh = false; + params.linear_whs = true; + params.work_group_size = int3(32, 1, 1); + } + int total_elements = + params.block_size.x * params.block_size.y * params.block_size.z; + if (total_elements == 1) { + if (src_slices % 4 == 0) { + params.src_depth_loop_size = 4; + } else if (src_slices % 2 == 0) { + params.src_depth_loop_size = 2; + } + } else if (total_elements == 2) { + if (src_slices % 2 == 0) { + params.src_depth_loop_size = 2; + } + } + if (params.src_depth_loop_size == src_slices) { + params.need_src_loop = false; + } + if (params.block_size.z == dst_slices) { + params.need_dst_loop = false; + } + const bool use_filters_constants = + !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 && + params.y_kernel_is_1; + if (use_filters_constants) { + params.weights_upload_type = WeightsUploadType::CONSTANT_MEM; + } + + return params; +} + +ConvParams GetConvParams(const Convolution2DAttributes& attr, + const BHWC& dst_shape) { + auto gpu_type = GetGpuType(); + if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) { + return GetConvParamsForA7A8(attr, dst_shape, gpu_type); + } else { + return GetConvParamsForA9AndHigher(attr, dst_shape, gpu_type); + } +} + } // namespace -std::vector Convolution( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, const RuntimeOptions& options) { - auto desc = std::make_shared(); - desc->id = id; - desc->is_linkable = false; - desc->shader_source = GetKernelForConv(params); - - desc->input_buffers = { - {input_id, "device FLT4* const src_buffer"}, - }; - - desc->output_buffer = { - output_id, "device FLT4* dst_buffer", - [input_id, params](const std::map& buffers) { - return CalculateOutputShape(buffers.find(input_id)->second, params); - }}; - - auto weights_reordered = ReorderWeightsForConvShared(params); - desc->immutable_buffers = { - {"device FLT4* const weights", - GetByteBufferConverted(weights_reordered, options.storage_precision)}, - {"device FLT4* const biases", - GetByteBufferConvertedResized(params.bias.data, - options.storage_precision, - params.weights.shape.o)}, - }; - - desc->uniform_buffers = { - {"constant uniforms& params", - [input_id, output_id, params](const std::map& buffers) { - const auto& input_dimensions = buffers.find(input_id)->second; - const auto& output_dimensions = buffers.find(output_id)->second; - return GetUniformBufferForConvShared(input_dimensions, - output_dimensions, params); - }}, - }; - - desc->resize_function = [output_id, - params](const std::map& buffers) { - const auto& output_dims = buffers.find(output_id)->second; - const int num_output_slices = GetNumOutputSlices(params.weights.shape.o); - const uint3 group_size{8, 4, 1}; - int groups_x = IntegralDivideRoundUp(output_dims.w, group_size.x); - int groups_y = IntegralDivideRoundUp(output_dims.h, group_size.y); - const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4); - int groups_z = IntegralDivideRoundUp(dst_depth, num_output_slices); - return std::make_pair(group_size, uint3{groups_x, groups_y, groups_z}); - }; - - return {desc}; -} - -std::vector Convolution1x1( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, - const metal::RuntimeOptions& options) { - auto desc = std::make_shared(); - desc->id = id; - desc->is_linkable = false; - const int z_out = GetNumOutputSlices(params.weights.shape.o); - desc->shader_source 
= GetKernelForConv1x1(params, z_out); - - desc->input_buffers = { - {input_id, "device FLT4* const src_buffer"}, - }; - - desc->output_buffer = { - output_id, "device FLT4* dst_buffer", - [input_id, params](const std::map& buffers) { - auto out_shape = - CalculateOutputShape(buffers.find(input_id)->second, params); - return out_shape; - }}; - - auto weights_reordered = ReorderWeightsForConv(params, z_out); - desc->immutable_buffers = { - {"device FLT4* const filters", - GetByteBufferConverted(weights_reordered, options.storage_precision)}, - {"device FLT4* const biases", - GetByteBufferConvertedResized(params.bias.data, - options.storage_precision, - params.weights.shape.o)}, - }; - - desc->uniform_buffers = { - {"constant uniforms& params", - [input_id, output_id, params](const std::map& buffers) { - const auto& input_dimensions = buffers.find(input_id)->second; - const auto& output_dimensions = buffers.find(output_id)->second; - return GetUniformBufferForConv(input_dimensions, output_dimensions, - params); - }}, - }; - - desc->resize_function = [output_id, - params](const std::map& buffers) { - const auto& output_dims = buffers.find(output_id)->second; - const uint3 group_size = GetWorkGroupForConv(); - const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims); - return std::make_pair( - group_size, uint3{groups_count.z, groups_count.x, groups_count.y}); - }; - - return {desc}; -} - -bool CheckConvolution1x1Support(const Convolution2DAttributes& attr) { - return IsConv1x1(attr); -} - std::vector ConvolutionGeneric( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, - const metal::RuntimeOptions& options) { + int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape, + const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) { + ConvParams params = GetConvParams(attr, dst_shape); + auto desc = std::make_shared(); desc->id = id; desc->is_linkable = false; - const int z_out = GetNumOutputSlices(params.weights.shape.o); - desc->shader_source = GetKernelForConvGeneric(params, z_out); + desc->shader_source = GenerateConvolution(params); desc->input_buffers = { {input_id, "device FLT4* const src_buffer"}, @@ -1039,160 +829,72 @@ std::vector ConvolutionGeneric( desc->output_buffer = { output_id, "device FLT4* dst_buffer", - [input_id, params](const std::map& buffers) { + [input_id, attr](const std::map& buffers) { auto out_shape = - CalculateOutputShape(buffers.find(input_id)->second, params); + CalculateOutputShape(buffers.find(input_id)->second, attr); return out_shape; }}; - auto weights_reordered = ReorderWeightsForConv(params, z_out); + auto weights_reordered = ReorderWeightsForConv(attr, params.block_size.z); + std::string addr_space = + params.weights_upload_type == WeightsUploadType::CONSTANT_MEM ? 
"constant" + : "device"; + const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); desc->immutable_buffers = { - {"device FLT4* const filters", + {addr_space + " FLT4* const filters", GetByteBufferConverted(weights_reordered, options.storage_precision)}, - {"device FLT4* const biases", - GetByteBufferConvertedResized(params.bias.data, - options.storage_precision, - params.weights.shape.o)}, + {addr_space + " FLT4* const biases", + GetByteBufferConvertedResized( + attr.bias.data, options.storage_precision, + AlignByN(dst_depth, params.block_size.z) * 4)}, }; desc->uniform_buffers = { {"constant uniforms& params", - [input_id, output_id, params](const std::map& buffers) { - const auto& input_dimensions = buffers.find(input_id)->second; - const auto& output_dimensions = buffers.find(output_id)->second; - return GetUniformBufferForConv(input_dimensions, output_dimensions, - params); + [input_id, output_id, attr, + params](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + const auto& dst_shape = buffers.find(output_id)->second; + return GetUniformBuffer(src_shape, dst_shape, attr, params); }}, }; desc->resize_function = [output_id, params](const std::map& buffers) { const auto& output_dims = buffers.find(output_id)->second; - const uint3 group_size = GetWorkGroupForConv(); - const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims); - return std::make_pair( - group_size, uint3{groups_count.z, groups_count.x, groups_count.y}); - }; + const int dst_slices = IntegralDivideRoundUp(output_dims.c, 4); - return {desc}; -} + int grid_x = IntegralDivideRoundUp(output_dims.w, params.block_size.x); + int grid_y = IntegralDivideRoundUp(output_dims.h, params.block_size.y); + int grid_z = IntegralDivideRoundUp(dst_slices, params.block_size.z); -std::vector ConvolutionPrecise( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, - const metal::RuntimeOptions& options) { - auto desc = std::make_shared(); - desc->id = id; - desc->is_linkable = false; - const int z_out = GetNumOutputSlices(params.weights.shape.o); - desc->shader_source = GetKernelForConvPrecise(z_out); - - desc->input_buffers = { - {input_id, "device FLT4* const src_buffer"}, - }; - - desc->output_buffer = { - output_id, "device FLT4* dst_buffer", - [input_id, params](const std::map& buffers) { - auto out_shape = - CalculateOutputShape(buffers.find(input_id)->second, params); - return out_shape; - }}; - - auto weights_reordered = ReorderWeightsForConv(params, z_out); - desc->immutable_buffers = { - {"device FLT4* const filters", - GetByteBufferConverted(weights_reordered, options.storage_precision)}, - {"device FLT4* const biases", - GetByteBufferConvertedResized(params.bias.data, - options.storage_precision, - params.weights.shape.o)}, - }; - - desc->uniform_buffers = { - {"constant uniforms& params", - [input_id, output_id, params](const std::map& buffers) { - const auto& input_dimensions = buffers.find(input_id)->second; - const auto& output_dimensions = buffers.find(output_id)->second; - return GetUniformBufferForConvPrecise(input_dimensions, - output_dimensions, params); - }}, - }; - - desc->resize_function = [output_id, - params](const std::map& buffers) { - const auto& output_dims = buffers.find(output_id)->second; - const uint3 group_size = GetWorkGroupForConvPrecise(); - const uint3 groups_count = - GetGroupsCountForConvPrecise(group_size, output_dims, 2); + const uint3 group_size(params.work_group_size.x, params.work_group_size.y, + 
params.work_group_size.z); + int3 wg; + uint3 groups_count; + if (params.linear_whs) { + wg.x = IntegralDivideRoundUp(grid_x * grid_y * grid_z, + params.work_group_size.x); + groups_count = uint3(wg.x, 1, 1); + } else if (params.linear_wh) { + wg.x = IntegralDivideRoundUp(grid_x * grid_y, params.work_group_size.x); + wg.y = IntegralDivideRoundUp(grid_z, params.work_group_size.y); + groups_count = uint3(wg[params.work_group_launch_order.x], + wg[params.work_group_launch_order.y], 1); + } else { + wg.x = IntegralDivideRoundUp(grid_x, params.work_group_size.x); + wg.y = IntegralDivideRoundUp(grid_y, params.work_group_size.y); + wg.z = IntegralDivideRoundUp(grid_z, params.work_group_size.z); + groups_count = uint3(wg[params.work_group_launch_order.x], + wg[params.work_group_launch_order.y], + wg[params.work_group_launch_order.z]); + } return std::make_pair(group_size, groups_count); }; return {desc}; } -float GetThreadsRatioUsualToPreciseConvolution(const BHWC& dst_shape) { - return static_cast(GetConvolutionThreadsCount(dst_shape)) / - static_cast(GetConvolutionPreciseThreadsCount(dst_shape, 2)); -} - -std::vector ConvolutionPrecise1x1PowerVR( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, const RuntimeOptions& options) { - auto desc = std::make_shared(); - desc->id = id; - desc->is_linkable = false; - const int z_out = GetNumOutputSlices(params.weights.shape.o); - desc->shader_source = GetKernelForConvPrecise1x1PowerVR(z_out); - - desc->input_buffers = { - {input_id, "device FLT4* const src_buffer"}, - }; - - desc->output_buffer = { - output_id, "device FLT4* dst_buffer", - [input_id, params](const std::map& buffers) { - auto out_shape = - CalculateOutputShape(buffers.find(input_id)->second, params); - return out_shape; - }}; - - auto weights_reordered = ReorderWeightsForConv(params, z_out); - desc->immutable_buffers = { - {"device FLT4* const filters", - GetByteBufferConverted(weights_reordered, options.storage_precision)}, - {"device FLT4* const biases", - GetByteBufferConvertedResized(params.bias.data, - options.storage_precision, - params.weights.shape.o)}, - }; - - desc->uniform_buffers = { - {"constant uniforms& params", - [input_id, output_id, params](const std::map& buffers) { - const auto& input_dimensions = buffers.find(input_id)->second; - const auto& output_dimensions = buffers.find(output_id)->second; - return GetUniformBufferForConvPrecise1x1(input_dimensions, - output_dimensions, params); - }}, - }; - - desc->resize_function = [output_id, - params](const std::map& buffers) { - const auto& output_dims = buffers.find(output_id)->second; - const uint3 group_size = GetWorkGroupForConvPrecise(); - const uint3 groups_count = - GetGroupsCountForConvPrecise(group_size, output_dims, 1); - return std::make_pair(group_size, groups_count); - }; - - return {desc}; -} - -bool CheckConvolutionPrecise1x1Support(const Convolution2DAttributes& attr) { - return IsConv1x1(attr); -} - } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h index 692145678cb..2853631abe8 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h @@ -27,67 +27,10 @@ namespace tflite { namespace gpu { namespace metal { -std::vector Convolution( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, - const metal::RuntimeOptions& options); - -// Convolution for 
kernel 1x1 -// require: -// kernel_size = 1x1; -// padding prepended and appended = 0x0 -// dilation = 1x1; -// stride = 1x1; -// Works very good on A12 (IPhoneXS, etc). -// Works good on A9/A10/A11 (IPhone6S, IPhone7, IPhoneX, etc). -// Works bad on A7/A8 (IPhone5S, IPhone6, etc). -std::vector Convolution1x1( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, const RuntimeOptions& options); - -// TODO(impjdi): Move it inside module. -bool CheckConvolution1x1Support(const Convolution2DAttributes& attr); - -// This convolution pass all conv parameters (beside output_channels) -// as dynamic arguments (uniform buffer) to kernel. -// Depending on output_channels can be generated different kernels -// Kernel can proceed 4/8/12/16 output channels per one thread. -// 16 channels output is the fastest but the least flexible. std::vector ConvolutionGeneric( - int id, ValueId input_id, ValueId output_id, + int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape, const Convolution2DAttributes& params, const RuntimeOptions& options); -// This convolution makes more precise mapping of threads on elements. -// For example, if we have output tensor 12x7 and work group = 8x4, -// then we need 4 workgroups to cover this tensor in usual case. -// But in general we have only 84 elements(12*7), and we can cover it with 3 -// workgroups of size 32. So this version of convolution use this precise -// mapping. -// But this convolution, due to some hardware limitations, doesn't work better -// always. In general it works good on A12. -// Each thread process 2 pixels in XY dimension and variable amount of pixels -// in Z dimension(depends on dst_channels). -std::vector ConvolutionPrecise( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, const RuntimeOptions& options); - -// As previous, but specific for 1x1 and each thread process 1 pixel in XY -// dimension. -// This convolution for PowerVR in FP16 mode with FP32 accumulator -// It will work in other modes also, but not with good performance -std::vector ConvolutionPrecise1x1PowerVR( - int id, ValueId input_id, ValueId output_id, - const Convolution2DAttributes& params, const RuntimeOptions& options); - -// TODO(impjdi): Move it inside module. -bool CheckConvolutionPrecise1x1Support(const Convolution2DAttributes& attr); - -// This function calculates amount of threads that should be launched for -// ConvolutionGeneric or Convolution1x1 (threads_count1) and amount of threads -// that should be launched for ConvolutionPrecise (threads_count2) and returns -// threads_count1 / threads_count2. -float GetThreadsRatioUsualToPreciseConvolution(const BHWC& dst_shape); - } // namespace metal } // namespace gpu } // namespace tflite From 00ad1522c8b826d79c6ae067f768818d76fa7c52 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 23 Mar 2020 10:05:32 -0700 Subject: [PATCH 416/492] Improve keras.constraints docstrings. 
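For reviewers, a minimal usage sketch of the constraints whose docstrings are
updated below. This is hypothetical example code that assumes only the public
tf.keras API (Dense, kernel_constraint, and the MaxNorm/max_norm names cited in
the new docstrings); it is illustrative only and not part of this change:

    import tensorflow as tf

    # Constrain each output unit's incoming weight vector (axis=0 of the Dense
    # kernel) to an L2 norm of at most 2; the constraint is re-applied to the
    # kernel after every optimizer update.
    layer = tf.keras.layers.Dense(
        64, kernel_constraint=tf.keras.constraints.MaxNorm(max_value=2, axis=0))

    # The "shortcut functions" mentioned in the new docstrings are the lowercase
    # aliases of the same classes, e.g. max_norm / non_neg / unit_norm:
    layer_alt = tf.keras.layers.Dense(
        64, kernel_constraint=tf.keras.constraints.max_norm(2.0))
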
PiperOrigin-RevId: 302458247 Change-Id: I96722e54807ea3b8a4358b9552ec0233f8399707 --- tensorflow/python/keras/constraints.py | 107 +++++++++++++++---------- 1 file changed, 63 insertions(+), 44 deletions(-) diff --git a/tensorflow/python/keras/constraints.py b/tensorflow/python/keras/constraints.py index 043ceb8dd6d..7cdc00151a6 100644 --- a/tensorflow/python/keras/constraints.py +++ b/tensorflow/python/keras/constraints.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import keras_export +from tensorflow.tools.docs import doc_controls @keras_export('keras.constraints.Constraint') @@ -48,19 +49,21 @@ class MaxNorm(Constraint): Constrains the weights incident to each hidden unit to have a norm less than or equal to a desired value. + Also available via the shortcut function `tf.keras.constraints.max_norm`. + Arguments: - m: the maximum norm for the incoming weights. - axis: integer, axis along which to calculate weight norms. - For instance, in a `Dense` layer the weight matrix - has shape `(input_dim, output_dim)`, - set `axis` to `0` to constrain each weight vector - of length `(input_dim,)`. - In a `Conv2D` layer with `data_format="channels_last"`, - the weight tensor has shape - `(rows, cols, input_depth, output_depth)`, - set `axis` to `[0, 1, 2]` - to constrain the weights of each filter tensor of size - `(rows, cols, input_depth)`. + max_value: the maximum norm value for the incoming weights. + axis: integer, axis along which to calculate weight norms. + For instance, in a `Dense` layer the weight matrix + has shape `(input_dim, output_dim)`, + set `axis` to `0` to constrain each weight vector + of length `(input_dim,)`. + In a `Conv2D` layer with `data_format="channels_last"`, + the weight tensor has shape + `(rows, cols, input_depth, output_depth)`, + set `axis` to `[0, 1, 2]` + to constrain the weights of each filter tensor of size + `(rows, cols, input_depth)`. """ @@ -68,12 +71,14 @@ class MaxNorm(Constraint): self.max_value = max_value self.axis = axis + @doc_controls.do_not_generate_docs def __call__(self, w): norms = K.sqrt( math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True)) desired = K.clip(norms, 0, self.max_value) return w * (desired / (K.epsilon() + norms)) + @doc_controls.do_not_generate_docs def get_config(self): return {'max_value': self.max_value, 'axis': self.axis} @@ -81,6 +86,8 @@ class MaxNorm(Constraint): @keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg') class NonNeg(Constraint): """Constrains the weights to be non-negative. + + Also available via the shortcut function `tf.keras.constraints.non_neg`. """ def __call__(self, w): @@ -91,29 +98,33 @@ class NonNeg(Constraint): class UnitNorm(Constraint): """Constrains the weights incident to each hidden unit to have unit norm. + Also available via the shortcut function `tf.keras.constraints.unit_norm`. + Arguments: - axis: integer, axis along which to calculate weight norms. - For instance, in a `Dense` layer the weight matrix - has shape `(input_dim, output_dim)`, - set `axis` to `0` to constrain each weight vector - of length `(input_dim,)`. - In a `Conv2D` layer with `data_format="channels_last"`, - the weight tensor has shape - `(rows, cols, input_depth, output_depth)`, - set `axis` to `[0, 1, 2]` - to constrain the weights of each filter tensor of size - `(rows, cols, input_depth)`. 
+ axis: integer, axis along which to calculate weight norms. + For instance, in a `Dense` layer the weight matrix + has shape `(input_dim, output_dim)`, + set `axis` to `0` to constrain each weight vector + of length `(input_dim,)`. + In a `Conv2D` layer with `data_format="channels_last"`, + the weight tensor has shape + `(rows, cols, input_depth, output_depth)`, + set `axis` to `[0, 1, 2]` + to constrain the weights of each filter tensor of size + `(rows, cols, input_depth)`. """ def __init__(self, axis=0): self.axis = axis + @doc_controls.do_not_generate_docs def __call__(self, w): return w / ( K.epsilon() + K.sqrt( math_ops.reduce_sum( math_ops.square(w), axis=self.axis, keepdims=True))) + @doc_controls.do_not_generate_docs def get_config(self): return {'axis': self.axis} @@ -125,27 +136,29 @@ class MinMaxNorm(Constraint): Constrains the weights incident to each hidden unit to have the norm between a lower bound and an upper bound. + Also available via the shortcut function `tf.keras.constraints.min_max_norm`. + Arguments: - min_value: the minimum norm for the incoming weights. - max_value: the maximum norm for the incoming weights. - rate: rate for enforcing the constraint: weights will be - rescaled to yield - `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`. - Effectively, this means that rate=1.0 stands for strict - enforcement of the constraint, while rate<1.0 means that - weights will be rescaled at each step to slowly move - towards a value inside the desired interval. - axis: integer, axis along which to calculate weight norms. - For instance, in a `Dense` layer the weight matrix - has shape `(input_dim, output_dim)`, - set `axis` to `0` to constrain each weight vector - of length `(input_dim,)`. - In a `Conv2D` layer with `data_format="channels_last"`, - the weight tensor has shape - `(rows, cols, input_depth, output_depth)`, - set `axis` to `[0, 1, 2]` - to constrain the weights of each filter tensor of size - `(rows, cols, input_depth)`. + min_value: the minimum norm for the incoming weights. + max_value: the maximum norm for the incoming weights. + rate: rate for enforcing the constraint: weights will be + rescaled to yield + `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`. + Effectively, this means that rate=1.0 stands for strict + enforcement of the constraint, while rate<1.0 means that + weights will be rescaled at each step to slowly move + towards a value inside the desired interval. + axis: integer, axis along which to calculate weight norms. + For instance, in a `Dense` layer the weight matrix + has shape `(input_dim, output_dim)`, + set `axis` to `0` to constrain each weight vector + of length `(input_dim,)`. + In a `Conv2D` layer with `data_format="channels_last"`, + the weight tensor has shape + `(rows, cols, input_depth, output_depth)`, + set `axis` to `[0, 1, 2]` + to constrain the weights of each filter tensor of size + `(rows, cols, input_depth)`. 
""" def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0): @@ -154,6 +167,7 @@ class MinMaxNorm(Constraint): self.rate = rate self.axis = axis + @doc_controls.do_not_generate_docs def __call__(self, w): norms = K.sqrt( math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True)) @@ -162,6 +176,7 @@ class MinMaxNorm(Constraint): (1 - self.rate) * norms) return w * (desired / (K.epsilon() + norms)) + @doc_controls.do_not_generate_docs def get_config(self): return { 'min_value': self.min_value, @@ -176,7 +191,10 @@ class MinMaxNorm(Constraint): class RadialConstraint(Constraint): """Constrains `Conv2D` kernel weights to be the same for each radius. - For example, the desired output for the following 4-by-4 kernel:: + Also available via the shortcut function + `tf.keras.constraints.radial_constraint`. + + For example, the desired output for the following 4-by-4 kernel: ``` kernel = [[v_00, v_01, v_02, v_03], @@ -200,6 +218,7 @@ class RadialConstraint(Constraint): shape `(rows, cols, input_depth, output_depth)`. """ + @doc_controls.do_not_generate_docs def __call__(self, w): w_shape = w.shape if w_shape.rank is None or w_shape.rank != 4: From b71e31d545960901dd25526946e57086cbb2770a Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Mon, 23 Mar 2020 10:14:57 -0700 Subject: [PATCH 417/492] Fixed compilation of elementwise_test.cc Reduced epsilon for FP16 in Exp test. PiperOrigin-RevId: 302460743 Change-Id: I5e4e46e45181ad49e009b7b51e25bd9706aec936 --- .../lite/delegates/gpu/cl/kernels/elementwise_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc index d558f2a6bd4..7a3aaecfe7f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc @@ -88,7 +88,7 @@ TEST_F(OpenCLOperationTest, Exp) { for (auto storage : env_.GetSupportedStorages()) { for (auto precision : env_.GetSupportedPrecisions()) { - const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; OperationDef op_def; op_def.precision = precision; auto data_type = DeduceDataTypeFromPrecision(precision); @@ -499,7 +499,7 @@ TEST_F(OpenCLOperationTest, MaxiumumWithScalar) { TensorFloat32 dst_tensor; BroadcastSettings broadcast; ElementwiseTwoInput operation = CreateElementwiseTwoInput( - creation_context_, op_def, OperationType::MAXIMUM, broadcast, attr); + creation_context_, op_def, OperationType::MAXIMUM, broadcast, &attr); ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, BHWC(1, 4, 1, 1), &dst_tensor)); EXPECT_THAT(dst_tensor.data, @@ -555,7 +555,7 @@ TEST_F(OpenCLOperationTest, MinimumWithScalar) { TensorFloat32 dst_tensor; BroadcastSettings broadcast; ElementwiseTwoInput operation = CreateElementwiseTwoInput( - creation_context_, op_def, OperationType::MINIMUM, broadcast, attr); + creation_context_, op_def, OperationType::MINIMUM, broadcast, &attr); ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, BHWC(1, 4, 1, 1), &dst_tensor)); EXPECT_THAT(dst_tensor.data, From b70b14a4624f6a06d394001f8d1fcd5d6d25d531 Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Mon, 23 Mar 2020 10:17:25 -0700 Subject: [PATCH 418/492] Only create TPU replicated variable handle in graph mode. 
PiperOrigin-RevId: 302461289 Change-Id: I4923d3db3e59db45e95a7a52c0c60fb42b3ee911 --- tensorflow/python/distribute/tpu_values.py | 2 +- tensorflow/python/distribute/values_test.py | 33 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py index 5ac2a11f82b..9d0719f34b4 100644 --- a/tensorflow/python/distribute/tpu_values.py +++ b/tensorflow/python/distribute/tpu_values.py @@ -112,7 +112,7 @@ class TPUVariableMixin(object): def handle(self): # If we're in a tpu.rewrite(), return the replicated handle. tpu_context = enclosing_tpu_context() - if tpu_context is None: + if tpu_context is None or context.executing_eagerly(): return self._get_closest().handle else: return tpu_context.get_replicated_var_handle(self._handle_id, diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 0c7b3dffd2b..685dbaf4d40 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -990,6 +990,39 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32), per_replica_results) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + strategy_combinations.central_storage_strategy_with_two_gpus, + ], + mode=["eager"])) + def testInitScope(self, distribution): + + class C(object): + pass + + obj = C() + obj.w = None + obj.v = None + + @def_function.function + def assign(): + with ops.init_scope(): + if obj.w is None: + obj.w = variables_lib.Variable( + 0, aggregation=variables_lib.VariableAggregation.MEAN) + obj.v = variables_lib.Variable( + obj.w.read_value(), + aggregation=variables_lib.VariableAggregation.MEAN) + + return obj.v.assign_add(2) + + per_replica_results = self.evaluate( + distribution.experimental_local_results(distribution.run(assign))) + self.assertAllEqual([2, 2], per_replica_results) + @combinations.generate( combinations.combine( distribution=[ From fa3b55e940cef475c8691487f4c00ac900a67d04 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 23 Mar 2020 10:40:00 -0700 Subject: [PATCH 419/492] Change comment of the external repo #include file origin. 
PiperOrigin-RevId: 302466710 Change-Id: I4d87c4272dbedc2b1ccee3ffbdc3be9690b7b480 --- tensorflow/c/eager/dlpack.cc | 2 +- .../compiler/mlir/lite/converter_gen.cc | 8 +-- .../compiler/mlir/lite/emit_error_reporter.h | 2 +- .../lite/experimental/estimators/estimator.h | 2 +- .../compiler/mlir/lite/flatbuffer_import.cc | 36 ++++++------- .../compiler/mlir/lite/flatbuffer_import.h | 6 +-- .../compiler/mlir/lite/flatbuffer_operator.cc | 8 +-- .../compiler/mlir/lite/flatbuffer_operator.h | 8 +-- .../mlir/lite/flatbuffer_to_string.cc | 4 +- .../mlir/lite/flatbuffer_translate.cc | 32 ++++++------ .../compiler/mlir/lite/flatbuffer_translate.h | 2 +- tensorflow/compiler/mlir/lite/ir/tfl_ops.cc | 26 +++++----- tensorflow/compiler/mlir/lite/ir/tfl_ops.h | 24 ++++----- .../compiler/mlir/lite/mlir_tflite_runner.cc | 8 +-- .../lite/python/graphdef_to_tfl_flatbuffer.cc | 10 ++-- .../python/saved_model_to_tfl_flatbuffer.cc | 10 ++-- .../lite/python/tf_tfl_flatbuffer_helpers.cc | 10 ++-- .../lite/python/tf_tfl_flatbuffer_helpers.h | 2 +- .../quantization/import_quant_stats_pass.cc | 24 ++++----- .../lite/quantization/lite/quantize_model.cc | 10 ++-- .../mlir/lite/quantization/lite/tfl_to_std.cc | 2 +- .../mlir/lite/quantization/lite/tfl_to_std.h | 2 +- .../lite/quantization/quantization_driver.cc | 24 ++++----- .../lite/quantization/quantization_passes.h | 4 +- .../lite/quantization/quantization_traits.h | 4 +- .../lite/quantization/quantization_utils.cc | 18 +++---- .../lite/quantization/quantization_utils.h | 24 ++++----- .../lite/quantization/tensorflow/passes.h | 4 +- .../quantization/tensorflow/tf_to_quant.cc | 6 +-- .../tools/op_quant_spec_getters_gen.cc | 2 +- .../quantization/xla/cpu_kernel_fusion.cc | 24 ++++----- .../mlir/lite/quantization/xla/materialize.cc | 16 +++--- .../mlir/lite/quantization/xla/passes.h | 4 +- .../mlir/lite/quantization/xla/propagate.cc | 8 +-- .../mlir/lite/quantization/xla/quantize.cc | 16 +++--- .../mlir/lite/sparsity/sparsify_model.cc | 10 ++-- .../compiler/mlir/lite/tf_tfl_passes.cc | 10 ++-- tensorflow/compiler/mlir/lite/tf_tfl_passes.h | 4 +- .../compiler/mlir/lite/tf_tfl_translate.cc | 12 ++--- .../mlir/lite/tf_to_tfl_flatbuffer.cc | 12 ++--- .../compiler/mlir/lite/tf_to_tfl_flatbuffer.h | 6 +-- .../lite/transforms/default_quant_params.cc | 6 +-- .../mlir/lite/transforms/dense_to_sparse.cc | 6 +-- .../mlir/lite/transforms/dilated_conv.h | 12 ++--- .../mlir/lite/transforms/extract_ophint.cc | 40 +++++++------- .../transforms/legalize_ophint_func_op.cc | 34 ++++++------ .../mlir/lite/transforms/legalize_tf.cc | 22 ++++---- .../mlir/lite/transforms/legalize_tf_while.cc | 12 ++--- .../transforms/load_quantization_recipe.cc | 10 ++-- .../transforms/lower_static_tensor_list.cc | 44 ++++++++-------- .../compiler/mlir/lite/transforms/optimize.cc | 16 +++--- .../transforms/optimize_functional_ops.cc | 18 +++---- .../mlir/lite/transforms/post_quantize.cc | 4 +- .../prepare_composite_functions_tf.cc | 30 +++++------ .../mlir/lite/transforms/prepare_quantize.cc | 10 ++-- .../mlir/lite/transforms/prepare_tf.cc | 22 ++++---- .../compiler/mlir/lite/transforms/quantize.cc | 22 ++++---- .../lite/transforms/runtime_type_verify.cc | 4 +- .../lite/transforms/split_merged_operands.cc | 36 ++++++------- .../mlir/lite/transforms/trim_functions_tf.cc | 14 ++--- .../lite/transforms/while_loop_outline.cc | 18 +++---- .../mlir/lite/utils/attribute_utils.cc | 4 +- .../mlir/lite/utils/attribute_utils.h | 2 +- .../compiler/mlir/lite/utils/convert_type.cc | 6 +-- 
.../compiler/mlir/lite/utils/convert_type.h | 2 +- .../compiler/mlir/lite/utils/lstm_utils.cc | 28 +++++----- .../compiler/mlir/lite/utils/lstm_utils.h | 12 ++--- .../mlir/lite/utils/lstm_utils_test.cc | 22 ++++---- .../mlir/lite/utils/stateful_ops_utils.cc | 2 +- .../mlir/lite/utils/stateful_ops_utils.h | 2 +- .../compiler/mlir/lite/utils/validators.cc | 4 +- .../compiler/mlir/lite/utils/validators.h | 4 +- .../mlir/mlir_graph_optimization_pass.h | 2 +- .../compiler/mlir/op_or_arg_name_mapper.cc | 6 +-- .../compiler/mlir/op_or_arg_name_mapper.h | 4 +- tensorflow/compiler/mlir/python/mlir.cc | 6 +-- .../analysis/side_effect_analysis.cc | 20 +++---- .../analysis/side_effect_analysis.h | 8 +-- .../mlir/tensorflow/ir/control_flow_ops.cc | 6 +-- .../mlir/tensorflow/ir/control_flow_ops.h | 8 +-- .../compiler/mlir/tensorflow/ir/tf_device.cc | 32 ++++++------ .../compiler/mlir/tensorflow/ir/tf_device.h | 8 +-- .../mlir/tensorflow/ir/tf_executor.cc | 36 ++++++------- .../compiler/mlir/tensorflow/ir/tf_executor.h | 14 ++--- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 44 ++++++++-------- .../compiler/mlir/tensorflow/ir/tf_ops.h | 24 ++++----- .../mlir/tensorflow/ir/tf_saved_model.cc | 20 +++---- .../mlir/tensorflow/ir/tf_saved_model.h | 8 +-- .../compiler/mlir/tensorflow/ir/tf_structs.h | 10 ++-- .../compiler/mlir/tensorflow/ir/tf_traits.h | 8 +-- .../compiler/mlir/tensorflow/ir/tf_types.cc | 4 +- .../compiler/mlir/tensorflow/ir/tf_types.h | 10 ++-- .../mlir/tensorflow/ir/tf_verifiers.cc | 2 +- .../mlir/tensorflow/ir/tf_verifiers.h | 2 +- .../annotate_parameter_replication.cc | 16 +++--- .../transforms/batchmatmul_to_einsum.cc | 20 +++---- .../transforms/batchmatmul_to_einsum.h | 6 +-- .../mlir/tensorflow/transforms/bridge.cc | 6 +-- .../mlir/tensorflow/transforms/bridge.h | 2 +- .../mlir/tensorflow/transforms/bridge_pass.cc | 6 +-- .../transforms/cluster_formation.cc | 14 ++--- .../transforms/cluster_outlining.cc | 18 +++---- .../transforms/collection_ops_util.cc | 18 +++---- .../transforms/collection_ops_util.h | 8 +-- .../tensorflow/transforms/constant_fold.cc | 2 +- .../tensorflow/transforms/constant_fold.h | 6 +-- .../tensorflow/transforms/decode_constant.cc | 8 +-- .../tensorflow/transforms/decode_constant.h | 2 +- .../transforms/decompose_resource_ops.cc | 2 +- .../transforms/decompose_resource_ops.h | 4 +- .../transforms/decompose_resource_ops_pass.cc | 4 +- .../tensorflow/transforms/dialect_hooks.cc | 14 ++--- .../mlir/tensorflow/transforms/einsum.cc | 20 +++---- .../mlir/tensorflow/transforms/einsum.h | 16 +++--- .../transforms/executor_island_coarsening.cc | 12 ++--- .../executor_tpuv1_inline_tpu_island.cc | 18 +++---- .../executor_tpuv1_island_coarsening.cc | 22 ++++---- .../executor_tpuv1_outline_tpu_island.cc | 18 +++---- .../mlir/tensorflow/transforms/fold_switch.cc | 36 ++++++------- .../transforms/freeze_global_tensors.cc | 10 ++-- .../functional_control_flow_to_cfg.cc | 16 +++--- .../mlir/tensorflow/transforms/gpu_fusion.cc | 16 +++--- .../transforms/graph_optimization_pass.cc | 8 +-- .../tensorflow/transforms/graph_pruning.cc | 8 +-- .../transforms/launch_to_device_attribute.cc | 14 ++--- .../transforms/layout_optimization.cc | 16 +++--- .../tensorflow/transforms/legalize_hlo.cc | 14 ++--- .../mlir/tensorflow/transforms/lower_tf.cc | 12 ++--- .../mlir/tensorflow/transforms/lower_tf.h | 4 +- .../tensorflow/transforms/lower_tf_pass.cc | 4 +- .../transforms/mark_function_visibility.cc | 4 +- .../materialize_mlir_passthrough_op.cc | 20 +++---- 
.../mlir/tensorflow/transforms/optimize.cc | 16 +++--- .../transforms/optimize_global_tensors.cc | 26 +++++----- .../transforms/parallel_execute_to_islands.cc | 12 ++--- .../mlir/tensorflow/transforms/passes.h | 2 +- .../transforms/promote_resources_to_args.cc | 16 +++--- .../transforms/raise_control_flow.cc | 6 +-- .../replicate_invariant_op_hoisting.cc | 12 ++--- .../transforms/replicate_to_island.cc | 18 +++---- .../transforms/resource_device_inference.cc | 20 +++---- .../transforms/resource_op_lifting.cc | 38 +++++++------- .../tensorflow/transforms/shape_inference.cc | 30 +++++------ .../tensorflow/transforms/shape_inference.h | 8 +-- .../transforms/shape_inference_pass.cc | 18 +++---- .../tensorflow/transforms/sink_constant.cc | 10 ++-- .../transforms/stack_ops_decomposition.cc | 30 +++++------ .../tensor_list_ops_decomposition.cc | 18 +++---- .../transforms/test_side_effect_analysis.cc | 10 ++-- .../transforms/tf_graph_optimization_pass.cc | 8 +-- .../transforms/tf_graph_optimization_pass.h | 2 +- .../transforms/tpu_cluster_formation.cc | 22 ++++---- .../transforms/tpu_dynamic_layout_pass.cc | 30 +++++------ .../transforms/tpu_dynamic_padding_mapper.cc | 18 +++---- .../tpu_merge_variables_with_execute.cc | 24 ++++----- .../tensorflow/transforms/tpu_rewrite_pass.cc | 18 +++---- .../tpu_sharding_identification_pass.cc | 14 ++--- .../tpu_variable_runtime_reformatting.cc | 28 +++++----- .../transforms/unroll_batch_matmul.cc | 20 +++---- .../transforms/unroll_batch_matmul.h | 6 +-- .../tensorflow/translate/breakup-islands.cc | 16 +++--- .../translate/control_to_executor_dialect.cc | 14 ++--- .../translate/derived_attr_populator_gen.cc | 2 +- .../translate/executor_to_control_dialect.cc | 14 ++--- .../tensorflow/translate/export_graphdef.cc | 26 +++++----- .../tensorflow/translate/export_graphdef.h | 6 +-- .../translate/export_tf_dialect_op.h | 2 +- .../mlir/tensorflow/translate/import_model.cc | 26 +++++----- .../mlir/tensorflow/translate/import_model.h | 4 +- .../translate/mlir_roundtrip_pass.cc | 6 +-- .../translate/mlir_roundtrip_pass.h | 2 +- .../translate/tf_functional_to_executor.cc | 10 ++-- .../tensorflow/translate/tf_mlir_translate.cc | 16 +++--- .../tensorflow/translate/tf_mlir_translate.h | 4 +- .../tf_mlir_translate_registration.cc | 4 +- .../translate/translate_tf_dialect_op.cc | 10 ++-- .../mlir/tensorflow/utils/bridge_logger.cc | 4 +- .../mlir/tensorflow/utils/bridge_logger.h | 6 +-- .../tensorflow/utils/compile_mlir_util.cc | 20 +++---- .../mlir/tensorflow/utils/compile_mlir_util.h | 2 +- .../mlir/tensorflow/utils/convert_tensor.cc | 8 +-- .../mlir/tensorflow/utils/convert_tensor.h | 4 +- .../tensorflow/utils/convert_tensor_test.cc | 6 +-- .../mlir/tensorflow/utils/convert_type.cc | 6 +-- .../mlir/tensorflow/utils/convert_type.h | 4 +- .../tensorflow/utils/convert_type_test.cc | 6 +-- .../mlir/tensorflow/utils/device_util.cc | 8 +-- .../mlir/tensorflow/utils/device_util.h | 4 +- .../mlir/tensorflow/utils/device_util_test.cc | 12 ++--- .../mlir/tensorflow/utils/dump_mlir_util.cc | 2 +- .../mlir/tensorflow/utils/dump_mlir_util.h | 2 +- .../tensorflow/utils/dump_mlir_util_test.cc | 6 +-- .../mlir/tensorflow/utils/error_util.h | 6 +-- .../mlir/tensorflow/utils/error_util_test.cc | 4 +- .../mlir/tensorflow/utils/eval_util.cc | 8 +-- .../mlir/tensorflow/utils/eval_util.h | 2 +- .../mlir/tensorflow/utils/export_utils.cc | 22 ++++---- .../mlir/tensorflow/utils/export_utils.h | 8 +-- .../mlir/tensorflow/utils/translate_utils.cc | 2 +- .../mlir/tensorflow/utils/translate_utils.h | 
6 +-- .../tensorflow/utils/xla_sharding_util.cc | 12 ++--- .../mlir/tensorflow/utils/xla_sharding_util.h | 8 +-- tensorflow/compiler/mlir/tf_mlir_opt_main.cc | 8 +-- .../compiler/mlir/tf_mlir_translate_main.cc | 10 ++-- .../compiler/mlir/xla/convert_op_folder.cc | 6 +-- .../compiler/mlir/xla/convert_op_folder.h | 4 +- .../mlir/xla/hlo_function_importer.cc | 16 +++--- .../compiler/mlir/xla/hlo_function_importer.h | 12 ++--- .../compiler/mlir/xla/hlo_module_importer.cc | 12 ++--- .../compiler/mlir/xla/hlo_module_importer.h | 8 +-- tensorflow/compiler/mlir/xla/hlo_utils.cc | 8 +-- tensorflow/compiler/mlir/xla/hlo_utils.h | 6 +-- .../compiler/mlir/xla/ir/hlo_client_ops.cc | 2 +- .../compiler/mlir/xla/ir/hlo_client_ops.h | 16 +++--- tensorflow/compiler/mlir/xla/ir/hlo_ops.cc | 38 +++++++------- tensorflow/compiler/mlir/xla/ir/hlo_ops.h | 24 ++++----- tensorflow/compiler/mlir/xla/ir/hlo_utils.cc | 2 +- tensorflow/compiler/mlir/xla/ir/hlo_utils.h | 10 ++-- tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc | 28 +++++----- tensorflow/compiler/mlir/xla/ir/lhlo_ops.h | 20 +++---- .../compiler/mlir/xla/ir/mlir_hlo_builder.cc | 4 +- .../compiler/mlir/xla/ir/mlir_hlo_builder.h | 10 ++-- .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 20 +++---- .../compiler/mlir/xla/mlir_hlo_to_hlo.h | 2 +- .../compiler/mlir/xla/operator_writer_gen.cc | 4 +- .../xla/transforms/hlo_legalize_to_lhlo.cc | 24 ++++----- .../xla/transforms/legalize_control_flow.cc | 22 ++++---- .../mlir/xla/transforms/legalize_tf.cc | 28 +++++----- .../transforms/legalize_tf_control_flow.cc | 26 +++++----- .../xla/transforms/legalize_tf_with_tf2xla.cc | 20 +++---- .../xla/transforms/legalize_to_standard.cc | 8 +-- .../mlir/xla/transforms/lhlo_copy_removal.cc | 6 +-- .../mlir/xla/transforms/lhlo_fuse_linalg.cc | 6 +-- .../xla/transforms/lhlo_legalize_to_affine.cc | 16 +++--- .../xla/transforms/lhlo_legalize_to_gpu.cc | 30 +++++------ .../lhlo_legalize_to_parallel_loops.cc | 12 ++--- .../mlir/xla/transforms/lower_complex.cc | 16 +++--- .../mlir/xla/transforms/lower_general_dot.cc | 18 +++---- .../xla/transforms/map_xla_to_scalar_op.h | 2 +- .../xla/transforms/materialize_broadcasts.cc | 10 ++-- .../transforms/materialize_broadcasts_pass.cc | 12 ++--- .../compiler/mlir/xla/transforms/passes.h | 4 +- .../compiler/mlir/xla/transforms/rewriters.h | 6 +-- .../mlir/xla/transforms/unfuse_batch_norm.cc | 16 +++--- .../xla/transforms/unfuse_batch_norm_pass.cc | 12 ++--- .../xla/transforms/xla_legalize_to_linalg.cc | 28 +++++----- tensorflow/compiler/mlir/xla/type_to_shape.cc | 10 ++-- tensorflow/compiler/mlir/xla/type_to_shape.h | 2 +- .../compiler/mlir/xla/type_to_shape_test.cc | 6 +-- .../compiler/mlir/xla/xla_mlir_translate.cc | 4 +- tensorflow/compiler/tf2xla/mlir_tf2xla.cc | 4 +- tensorflow/compiler/xla/python/dlpack.cc | 2 +- .../xla/service/mlir_gpu/emission_context.cc | 4 +- .../xla/service/mlir_gpu/emission_context.h | 2 +- .../experimental/conv_emitter/conv_emitter.cc | 14 ++--- .../experimental/conv_emitter/conv_emitter.h | 2 +- .../conv_emitter/conv_emitter_test.cc | 16 +++--- .../service/mlir_gpu/hlo_dialect_emitter.cc | 8 +-- .../service/mlir_gpu/hlo_dialect_emitter.h | 8 +-- .../xla/service/mlir_gpu/inject_errors_pass.h | 2 +- .../xla/service/mlir_gpu/kernel_lowering.cc | 52 +++++++++---------- .../xla/service/mlir_gpu/kernel_lowering.h | 2 +- .../service/mlir_gpu/lhlo_dialect_emitter.cc | 16 +++--- .../service/mlir_gpu/lhlo_dialect_emitter.h | 8 +-- .../xla/service/mlir_gpu/mlir_compiler.cc | 28 +++++----- .../xla/service/mlir_gpu/mlir_compiler.h | 
4 +- .../service/mlir_gpu/mlir_irgen_test_base.cc | 4 +- tensorflow/core/kernels/string_lower_op.cc | 2 +- tensorflow/core/kernels/string_upper_op.cc | 2 +- tensorflow/core/kernels/unicode_ops.cc | 22 ++++---- tensorflow/core/kernels/unicode_script_op.cc | 4 +- tensorflow/core/platform/default/port.cc | 2 +- .../core/platform/default/strong_hash.h | 4 +- tensorflow/core/platform/gif.h | 2 +- tensorflow/core/platform/jpeg.h | 4 +- tensorflow/core/platform/png.h | 2 +- tensorflow/lite/delegates/flex/kernel.cc | 2 +- tensorflow/lite/delegates/flex/test_util.cc | 2 +- .../lite/delegates/gpu/cl/program_cache.cc | 2 +- .../lite/delegates/gpu/gl/serialization.h | 2 +- .../gl/workgroups/calculator_from_metadata.cc | 2 +- tensorflow/lite/delegates/gpu/gl_delegate.cc | 2 +- .../lite/delegates/xnnpack/conv_2d_test.cc | 2 +- .../xnnpack/depthwise_conv_2d_test.cc | 2 +- .../kernels/ctc_beam_search_decoder.cc | 2 +- .../kernels/ctc_beam_search_decoder_test.cc | 2 +- .../lite/experimental/kernels/hashtable.cc | 2 +- .../kernels/hashtable_ops_test.cc | 2 +- .../microfrontend/audio_microfrontend.cc | 2 +- .../microfrontend/audio_microfrontend_test.cc | 2 +- .../flatbuffers_lib/flatbuffers_lib.cc | 4 +- .../writer/option_writer_generator.cc | 2 +- tensorflow/lite/kernels/audio_spectrogram.cc | 2 +- .../lite/kernels/audio_spectrogram_test.cc | 2 +- .../lite/kernels/detection_postprocess.cc | 2 +- .../kernels/detection_postprocess_test.cc | 2 +- tensorflow/lite/kernels/if_test.cc | 2 +- tensorflow/lite/kernels/mfcc.cc | 2 +- tensorflow/lite/kernels/mfcc_test.cc | 2 +- .../lite/kernels/non_max_suppression.cc | 2 +- tensorflow/lite/kernels/subgraph_test_util.cc | 2 +- tensorflow/lite/kernels/while_test.cc | 2 +- .../schema/flatbuffer_compatibility_test.cc | 2 +- tensorflow/lite/toco/tflite/export_test.cc | 2 +- .../tools/optimize/calibration/calibrator.h | 2 +- .../tools/optimize/quantize_model_test.cc | 4 +- .../tools/optimize/quantize_weights_test.cc | 4 +- .../lite/tools/versioning/runtime_version.h | 2 +- 308 files changed, 1682 insertions(+), 1682 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 6304c8fd8f4..fee2154c8dc 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/c/eager/dlpack.h" -#include "include/dlpack/dlpack.h" // TF:dlpack +#include "include/dlpack/dlpack.h" // from @dlpack #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index b1fa1675845..83c95c03c8b 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -27,10 +27,10 @@ limitations under the License. 
#include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Attribute.h" // TF:llvm-project -#include "mlir/TableGen/Format.h" // TF:llvm-project -#include "mlir/TableGen/Operator.h" // TF:llvm-project -#include "mlir/TableGen/Predicate.h" // TF:llvm-project +#include "mlir/TableGen/Attribute.h" // from @llvm-project +#include "mlir/TableGen/Format.h" // from @llvm-project +#include "mlir/TableGen/Operator.h" // from @llvm-project +#include "mlir/TableGen/Predicate.h" // from @llvm-project using llvm::DefInit; using llvm::dyn_cast; diff --git a/tensorflow/compiler/mlir/lite/emit_error_reporter.h b/tensorflow/compiler/mlir/lite/emit_error_reporter.h index 76cc1f612bb..7f0ed8cf3c4 100644 --- a/tensorflow/compiler/mlir/lite/emit_error_reporter.h +++ b/tensorflow/compiler/mlir/lite/emit_error_reporter.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h b/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h index 26f6b0f3428..7d58fc41ab3 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_ #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" namespace hardware { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 4b888764053..4f6d11394d4 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -44,24 +44,24 @@ limitations under the License. 
#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Translation.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.h b/tensorflow/compiler/mlir/lite/flatbuffer_import.h index 5dba9a0efc4..3cab45a5c15 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_ #include "absl/strings/string_view.h" -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project namespace tflite { // Converts a TFLite flatbuffer stored in `buffer` to a MLIR module diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 2b4ca354996..9734608b19b 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -18,12 +18,12 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h index 4e8e3f6424e..5b55b557aa0 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -23,12 +23,12 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/core/platform/status.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc index c53d2c4ae4d..f9a4d29fdb3 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc @@ -23,8 +23,8 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers -#include "flatbuffers/minireflect.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/minireflect.h" // from @flatbuffers #include "tensorflow/lite/schema/reflection/schema_generated.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index e8337d4a79f..4163d13c36c 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -31,8 +31,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/None.h" @@ -42,20 +42,20 @@ limitations under the License. 
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Translation.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.h b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h index 03f92ddbf03..9f39928a737 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 471e50e0a52..caed5396c1b 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -26,19 +26,19 @@ limitations under the License. 
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project -#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/FoldUtils.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index a9b89c2bb64..3755bf490b9 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -18,18 +18,18 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/Traits.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // TF:llvm-project -#include "mlir/Interfaces/LoopLikeInterface.h" // TF:llvm-project -#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index 6f8292308a4..e635885801e 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -30,10 +30,10 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Parser.h" // TF:llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 660f73e59e9..0a3f0eb3518 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -19,11 +19,11 @@ limitations under the License. 
#include #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/FileUtilities.h" // TF:llvm-project -#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index a546dba3ff3..f8435d17c8d 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/FileUtilities.h" // TF:llvm-project -#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index e0eb8004a01..ae342dd49ae 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/FileUtilities.h" // TF:llvm-project -#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index 41846d8e846..96c2096e469 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -18,7 +18,7 @@ limitations under the License. 
#include #include -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/lite/toco/model_flags.pb.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 26062b96de0..5a5012173e2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -23,18 +23,18 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/AffineExpr.h" // TF:llvm-project -#include "mlir/IR/AffineMap.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 2f677397109..f961b037a6c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -17,11 +17,11 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc index d680c889d2c..8ea1709b15f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h index 35d667f506c..5d2c59fd7c7 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_ -#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Function.h" // from @llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 531a442fd6b..2964a3e79f8 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -24,18 +24,18 @@ limitations under the License. 
#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h b/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h index 58e9538045b..2aa5f8e2d0d 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_ -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project namespace mlir { namespace quant { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h index 885831ad0ce..b59164b72e6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h @@ -18,8 +18,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace OpTrait { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index f5c7287631a..0bd914aa2e7 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -22,15 +22,15 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantizeUtils.h" // TF:llvm-project -#include "mlir/Dialect/Quant/UniformSupport.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantizeUtils.h" // from @llvm-project +#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace quant { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 6a54262363c..27ccc7d2b22 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -23,18 +23,18 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h index c345da01c54..178daf1b1e0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 1a310de8b01..2c0b435cc04 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index 15c615d3dfd..208fb4c8a56 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -20,7 +20,7 @@ limitations under the License. #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Operator.h" // TF:llvm-project +#include "mlir/TableGen/Operator.h" // from @llvm-project using llvm::LessRecord; using llvm::raw_ostream; diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc index 4ca5692584f..7bfeb241904 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc @@ -27,18 +27,18 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc index ab170def2b5..25a5f38bf0a 100644 --- 
a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc @@ -25,14 +25,14 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/passes.h b/tensorflow/compiler/mlir/lite/quantization/xla/passes.h index b0d95948797..c4f9d63cf68 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/passes.h +++ b/tensorflow/compiler/mlir/lite/quantization/xla/passes.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { namespace xla_hlo { diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc index dd185598a90..4087eeb3c09 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc @@ -21,10 +21,10 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc index a5ac34e1cc0..9df41bb660a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc @@ -14,14 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index c05337918f2..806c0353ed9 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -17,11 +17,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index a80a1612488..bb7a30e64f6 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -15,11 +15,11 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h index 651248b1059..ca153f54902 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 74e48cd6d91..fd7d95e1e33 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -20,12 +20,12 @@ limitations under the License. #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/FileUtilities.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index b05dcaadab2..b1c6cbc8d82 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -19,12 +19,12 @@ limitations under the License. 
#include #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Parser.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/FileUtilities.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index f670ac8e52b..c93f8a6d416 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ #include "llvm/Support/SourceMgr.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index bb48c392a5f..0bbacd48ade 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -25,9 +25,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index eb9b4edd3d5..2341c0306f1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -16,9 +16,9 @@ limitations under the License. // This transformation pass convert dense tensor to sparse format. 
#include "absl/memory/memory.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index d8a26154b2b..68a1c617e34 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -21,12 +21,12 @@ limitations under the License. #include #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index 5893d4f3779..40b9c54450e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -21,26 +21,26 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc index f3a15b7ebd3..0d9630a9793 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc @@ -15,23 +15,23 @@ limitations under the License. 
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringMap.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 3210ac7bc2b..4d40eec7a1b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -28,17 +28,17 @@ limitations under the License. 
#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/Quant/UniformSupport.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc index ea44a34eb2b..66173c3c5b5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc @@ -15,12 +15,12 @@ limitations under the License. // Converts TF While to TFL While with single call in body and cond. -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc index 59b1dcce35d..3d42f81a758 100644 --- a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc +++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc @@ -19,11 +19,11 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 9df205d908c..8c6a2970397 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -31,28 +31,28 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include 
"mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 6137abfee4f..e324f614ca4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -31,14 +31,14 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 1c598fec08e..cf12e036360 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -17,15 +17,15 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 267901f69f3..86d23a2b0b2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -16,8 +16,8 @@ limitations under the License. // This transformation pass applies some clean up steps after quantization. #include "llvm/Support/Casting.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index b2cc58b863a..e8a2a2e75d8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -23,21 +23,21 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/Interfaces/CallInterfaces.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 287b9ca911c..cdbf4c41539 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -22,11 +22,11 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/Quant/QuantOps.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index efcc950cae0..f79543e6db6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -38,17 +38,17 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/Quant/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/Quant/UniformSupport.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project +#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index c78b04df247..3be335e8c7b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -19,17 +19,17 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc index 2a35701f0e6..92eb7023438 100644 --- a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc +++ b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc index c8aa67084ce..7f745727c49 100644 --- a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc +++ b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc @@ -18,24 +18,24 @@ limitations under the License. 
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc index 13afa1bf9b8..a81f2147059 100644 --- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc @@ -20,13 +20,13 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/transforms/passes.h" // The cmd line flag to specify the whitelist of functions. 
Rest are trimmed diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index be024eccd45..d4c359b6178 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -17,15 +17,15 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc index 3d4bbdfa13c..bc1924e1da0 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h index 7c0ff910db1..1171efa6bcf 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h @@ -19,7 +19,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 7158d634a89..00206373872 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.h b/tensorflow/compiler/mlir/lite/utils/convert_type.h index 90600c423bd..c4d9f98a02c 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.h +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index a138812e54d..1988dff048c 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -21,20 +21,20 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h index d8830d5e48c..b5063a33cd0 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -20,12 +20,12 @@ limitations under 
the License. #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 0593bd150c7..5df57de6f71 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -24,17 +24,17 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/platform/test.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc index 4067cfb04b9..a9c9da039c1 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h index 635922d5cbb..fa591f66473 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc index f8e3dd12c8b..f863eeed0d6 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.cc +++ b/tensorflow/compiler/mlir/lite/utils/validators.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/validators.h" -#include "mlir/Dialect/Traits.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index fa1304c68e0..247947c3adc 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -19,8 +19,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h index aed5307d39d..b405bcd6913 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/core/common_runtime/function_optimization_registry.h" #include "tensorflow/core/common_runtime/optimization_registry.h" diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index 63f558bc9c5..272fab9cd1c 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -25,9 +25,9 @@ limitations under the License. #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project static inline absl::string_view StringRefToView(llvm::StringRef ref) { return absl::string_view(ref.data(), ref.size()); diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h index 9445cc1374e..108496e2283 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h @@ -23,8 +23,8 @@ limitations under the License. 
#include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project namespace tensorflow { diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index e6ac78be711..d0f6e015922 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include "llvm/Support/raw_ostream.h" -#include "mlir/Parser.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 931f24b9606..ff1620347f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -26,16 +26,16 @@ limitations under the License. #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index c491503917e..5989494f9aa 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -22,10 +22,10 @@ limitations under the License. 
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Region.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc index e4b797d349a..96f2b62ceb8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" -#include "mlir/IR/DialectImplementation.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project namespace mlir { namespace TFControlFlow { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h index 0156d7e7e9d..15a4ecfc537 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h @@ -23,10 +23,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_ -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/SideEffects.h" // from @llvm-project namespace mlir { namespace TFControlFlow { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 163f4562d49..f757a1fe638 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -26,22 +26,22 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/SMLoc.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/UseDefLists.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Support/STLExtras.h" // TF:llvm-project -#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/STLExtras.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h index 1b20120cc2e..6600edf35a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -19,10 +19,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project namespace mlir { namespace tf_device { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 36b747b7fb7..8d670d96748 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -26,24 +26,24 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/Dialect/Traits.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/DialectImplementation.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Support/STLExtras.h" // TF:llvm-project -#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project -#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/STLExtras.h" // from @llvm-project +#include "mlir/Transforms/FoldUtils.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h index b7d8549ece7..3a8f3d14550 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h @@ -21,13 +21,13 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ -#include "mlir/Dialect/Traits.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index c2a94a8efe5..3622a636c3b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -35,28 +35,28 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/Dialect/Traits.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/DialectImplementation.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Parser.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Support/STLExtras.h" // TF:llvm-project -#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // 
from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/STLExtras.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index fbd1a335be1..8dc8fb351f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -19,18 +19,18 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ -#include "mlir/Dialect/Traits.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Interfaces/CallInterfaces.h" // TF:llvm-project -#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // TF:llvm-project -#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffects.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index ea46662bace..2deed928ba3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -22,16 +22,16 @@ limitations under the License. 
#include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h index e93293741f4..47ebb1a1be5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_ -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project namespace mlir { namespace tf_saved_model { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h index 65887a0c960..b1f39ad1d28 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h @@ -19,11 +19,11 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_STRUCTS_H_ #include "llvm/ADT/StringMap.h" -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/core/util/device_name_utils.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 18beb23663c..85c6819a8b4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -18,10 +18,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index ef97b234ef7..188bc67f70e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "llvm/Support/ErrorHandling.h" -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project namespace { // Returns the shape of the given value if it's ranked; returns llvm::None diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 2898338f8eb..c5225a34fb4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -18,11 +18,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_ -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.cc index 247df44a90a..772769eebc3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h index 5289328e73f..f7d38f2b371 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_VERIFIERS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_VERIFIERS_H_ -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index cdbcd194ae6..d96f4c18a10 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -16,14 +16,14 @@ limitations under the License. #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index 0663ad8c52e..bcf08c6b3ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -25,16 +25,16 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/matmul_bcast.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h index b0a1b59fb94..d39f3575b4a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/matmul_bcast.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 73110a724ea..2e1201c10c5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h index f837df534e9..0b831917a07 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_ -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/core/lib/core/status.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc index 936f7bf3359..080aea2521e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc index 7ced69fd32c..48f25b50ef6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -20,13 +20,13 @@ limitations under the License. #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index 0fef58ebb8a..1c9ace21efb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -17,15 +17,15 @@ limitations under the License. // `tf_device.launch` with equivalent `tf_device.launch_func` operations. 
#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 71426b04d99..428da70a4c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -19,15 +19,15 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h index 6b86cafed3f..a6afe618cfb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h @@ -17,10 +17,10 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index c1a87c289bf..2269b4c55c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project +#include "mlir/Interfaces/SideEffects.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h index 3718d4bd765..69e39080965 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h @@ -18,9 +18,9 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc index 51c37b038d3..53129dbf703 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h index ae8b4eace4d..1acbb2e3a55 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECODE_CONSTANT_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECODE_CONSTANT_H_ -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc index c2fd8a152f3..59bed7e404e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h" -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h index 6697a2181ad..a32b5887f79 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc index 8b5f93be4c8..a439d7dcc45 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc b/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc index 05b0fb20b62..109ceea47e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc @@ -16,13 +16,13 @@ limitations under the License. 
#include #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/DialectHooks.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/DialectHooks.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 833b52e3e89..115b8938975 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -25,16 +25,16 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Regex.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/matmul_bcast.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h index 734d22432a1..490fe1ee887 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h @@ -23,14 +23,14 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/matmul_bcast.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index 837944ce0e7..eb2aa16e25f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -27,12 +27,12 @@ limitations under the License. #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index ad844883453..85d9d994b30 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -17,15 +17,15 @@ limitations under the License. 
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Visitors.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc index cc87bd31486..54782116094 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc @@ -29,17 +29,17 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/UseDefLists.h" // TF:llvm-project -#include "mlir/IR/Visitors.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index 01901d8b5a4..b25cc23aac8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ 
b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -16,15 +16,15 @@ limitations under the License. #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index 4d5ad5ad423..7d0e7e20e5d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -30,24 +30,24 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/IR/Visitors.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index 82c198ac82f..9ae3ffdaa7d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -16,11 +16,11 @@ limitations under the License. 
#include #include -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/UseDefLists.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index 8cfa69c396e..b502b0ceb01 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -16,14 +16,14 @@ limitations under the License. // This transformation pass transforms functional control flow operations in the // standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form. -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index de830d879dd..a88ea2f387d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -14,14 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index c563a98d8c8..e2090803c00 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h" -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index d52c49e4436..6e022a64262 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -18,10 +18,10 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index 9a196aef54b..9319e91064d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -42,13 +42,13 @@ limitations under the License. // tensor, tensor, tensor, tensor // } -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Visitors.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index cb84be5748c..237f08c6c41 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -14,14 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 281efe98d2e..c3a0b1e303a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -17,13 +17,13 @@ limitations under the License. #include -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 9268881cb71..f934e2ac169 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -19,12 +19,12 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h index b72b0f25938..8cba39abe24 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_ -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc index be9e0f4aef4..ecd59442bf4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc index 55fdce7a8b6..dd884fd09fd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #define DEBUG_TYPE "tf-shape-inference" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc index ae208cbf686..c62b2a539b5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc @@ -17,16 +17,16 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Parser.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #define DEBUG_TYPE "tf-materialize-passthrough-op" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index acaf7974280..df8d1aeed16 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -14,14 +14,14 @@ limitations under the License. 
==============================================================================*/ #include -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 68617e36f0c..713bcff1a71 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -20,19 +20,19 @@ limitations under the License. #include #include "llvm/ADT/DenseMap.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Interfaces/CallInterfaces.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc index 5caf08c672e..b5ecd5bd32b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc +++ 
b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc @@ -69,12 +69,12 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 92d15e13621..e1abe8465d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index a0a521745cf..4d83c647f40 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -39,14 +39,14 @@ limitations under the License. #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc index 55cb1e2c3df..e71b4a530b3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc @@ -22,9 +22,9 @@ limitations under the License. // eliminating control dependencies, and results in the code being in the // canonical TensorFlow dialect. 
-#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 03e0b99a6ef..4d836cd056a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -20,12 +20,12 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/IR/Visitors.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 839bb1d1269..5cec4c0ed66 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -25,15 +25,15 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Visitors.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index c92ce1f01ad..d0abb8d844f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -26,16 +26,16 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/IR/Visitors.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 32dbb6f5d34..ed380c7b8bc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -22,25 +22,25 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/IR/Visitors.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 6d2ce76eca8..84c527d18ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -25,21 +25,21 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index 73993a07292..0524ec678ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -18,10 +18,10 @@ limitations under the License. #include -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Region.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index cffb14892c8..e45504ce819 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -20,15 +20,15 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc index 9d872fb3d1a..ff3dae278b3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc @@ -19,11 +19,11 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 4033d522091..54d23134c8c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -23,21 +23,21 @@ limitations under the License. 
#include "llvm/ADT/StringExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/SymbolTable.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 8b1ba7d1d30..e1010f3b9bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -18,15 +18,15 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc index eb754cc3bbd..38960eef411 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc @@ -22,11 +22,11 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 5606428bb19..0ba01738532 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -16,10 +16,10 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" #include "llvm/Support/CommandLine.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h index 49d92bf3151..bea23f8face 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index ea3a4efac74..fe11fee9f08 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -35,17 +35,17 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 7fe65b888d9..45fd3a5751d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -18,21 +18,21 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/STLExtras.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/STLExtras.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc index 38a01e168f7..6013dfdf4ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc @@ -24,15 +24,15 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc index c1419873dba..46d22844457 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc @@ -28,18 +28,18 @@ limitations under the License. #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index e20e78a243c..f9e24e4373d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ 
b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -26,15 +26,15 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index a88675f1557..af01cf329b0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -21,13 +21,13 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 6e698c3ca5c..8a7f0c55c3e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -26,20 +26,20 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/STLExtras.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/STLExtras.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index c6223ed13f7..d5603416d54 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -24,16 +24,16 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/matmul_bcast.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h index c725930a484..18c41f1ffdc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/matmul_bcast.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 32cb2e02930..d33dfba50ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/STLExtras.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/STLExtras.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc index b89b3d8e6b2..6d3e35ac19b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc @@ -22,13 +22,13 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc b/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc index f78307a0282..e4c965b6cb1 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Operator.h" // TF:llvm-project +#include "mlir/TableGen/Operator.h" // from @llvm-project using llvm::LessRecord; using llvm::raw_ostream; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc index 7755f5f2259..7410074e300 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc @@ -21,13 +21,13 @@ limitations under the License. #include "llvm/ADT/SmallString.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 5bdba99ac6e..b3511503454 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -28,19 +28,19 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Support/DebugStringHelper.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h index e962ec174f5..2d522f6031e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h @@ -18,9 +18,9 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index df1f4859ded..e6657ebc8dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -17,7 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_TF_DIALECT_OP_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 366403e0654..9aa62efc61a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -42,19 +42,19 @@ limitations under the License. #include "llvm/ADT/Twine.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Analysis/Verifier.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Analysis/Verifier.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index efc316483fe..8603eadb487 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc index 004293410b3..2a4d059f21e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h" -#include "mlir/Analysis/Verifier.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Analysis/Verifier.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h index 4a67b7fae76..e80b21b3f4d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc index 2ee3893eac9..4a625b62857 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc @@ -14,11 +14,11 @@ limitations under the License. ==============================================================================*/ #include "llvm/Support/Debug.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #define DEBUG_TYPE "tf-functional-to-executor" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index d5fcf86cc93..12e38da987e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -17,14 +17,14 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Parser.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index 76bada96845..ef72000b4d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -21,8 +21,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project namespace tensorflow { // TODO(antiagainst): Directly manipulating files in library functions is not diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index e194289b120..b4c279c367d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -21,8 +21,8 @@ limitations under the License. #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Translation.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc index a9b5021559c..bd3fe9876ff 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc @@ -14,11 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Translation.h" // TF:llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index 7d449b8775f..8212c0b50a4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -17,8 +17,8 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h index 4f6d49b77e9..b5b2ad33b31 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index b13eec71de9..5394dbfb21a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -17,16 +17,16 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Parser.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 41fa8b90e4f..19423adfe17 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -18,7 +18,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 0361b91c9e4..bdb4ebc5058 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -24,10 +24,10 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index b2646c265ad..fdaf7ef0d45 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -18,8 +18,8 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index bcd37e39de9..5d039176bb0 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index 84a8969a486..0caceb69510 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -17,9 +17,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Support/DebugStringHelper.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h index 24c4273ad0e..da0bb9f6cb7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc index e7206096d2c..07f6b129a41 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc @@ -16,9 +16,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index 9561d0a2f93..bf0b3b75ace 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -23,10 +23,10 @@ limitations under the License. #include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Regex.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/util/device_name_utils.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h index 1cbf0517554..893e118024c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/util/device_name_utils.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index 25e55e23c1a..bc849e1d116 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -20,12 +20,12 @@ limitations under the License. 
#include #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/device_attributes.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index 36a59d12060..d21ffe6be35 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -24,7 +24,7 @@ limitations under the License. #include "llvm/ADT/Twine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/path.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 79c4961273a..726eed8974e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index 69e90de3cb6..b7e63661e17 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index 7eb30ee2c46..abef0de4585 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -18,9 +18,9 @@ limitations under the License. 
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/core/lib/core/status.h" // Error utilities for MLIR when interacting with code using Status returns. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc index 61214108957..b174ad40a3b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "llvm/ADT/Twine.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc index e4b7b854a4e..cca6981ae41 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc @@ -20,10 +20,10 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h index 39fd91afe40..4130e724232 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h @@ -19,7 +19,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index f8c118ac9d9..075014319df 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -23,17 +23,17 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Support/DebugStringHelper.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index a8c91c0b494..32ed528bd0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -23,10 +23,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc index 6aeead516e8..f32485e070b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" -#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h index 395ba5c06ac..a398560e0e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h @@ -16,9 +16,9 @@ limitations under the License. 
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TRANSLATE_UTILS_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TRANSLATE_UTILS_H_
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/Module.h"  // TF:llvm-project
-#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
index ede8130c953..7c1f69f4d92 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
@@ -22,12 +22,12 @@ limitations under the License.
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "mlir/IR/Attributes.h"  // TF:llvm-project
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/Location.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "mlir/IR/Types.h"  // TF:llvm-project
-#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h
index f7a9dbf2c81..1df4e1fbc37 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h
@@ -18,10 +18,10 @@ limitations under the License.
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/Operation.h"  // TF:llvm-project
-#include "mlir/IR/Value.h"  // TF:llvm-project
-#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
index 736e954278e..fbb775e061f 100644
--- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Support/FileUtilities.h" // TF:llvm-project -#include "mlir/Support/MlirOptMain.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 29f9ec7eb46..ebb1462cad8 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -21,11 +21,11 @@ limitations under the License. #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/Support/FileUtilities.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Support/ToolUtilities.h" // TF:llvm-project -#include "mlir/Support/TranslateClParser.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/ToolUtilities.h" // from @llvm-project +#include "mlir/Support/TranslateClParser.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" diff --git a/tensorflow/compiler/mlir/xla/convert_op_folder.cc b/tensorflow/compiler/mlir/xla/convert_op_folder.cc index dfd7cb39bf9..42124ddf9b8 100644 --- a/tensorflow/compiler/mlir/xla/convert_op_folder.cc +++ b/tensorflow/compiler/mlir/xla/convert_op_folder.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project namespace mlir { namespace xla { diff --git a/tensorflow/compiler/mlir/xla/convert_op_folder.h b/tensorflow/compiler/mlir/xla/convert_op_folder.h index 37a4db0227f..5fe2f80561f 100644 --- a/tensorflow/compiler/mlir/xla/convert_op_folder.h +++ b/tensorflow/compiler/mlir/xla/convert_op_folder.h @@ -16,8 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project namespace mlir { namespace xla { diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 95421d95504..6238e8175c4 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -20,14 +20,14 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Region.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index 6b027e9b2e4..5dfa0adac82 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -19,12 +19,12 @@ limitations under the License. #include #include "absl/types/optional.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/status.h" diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc index 82304f95e33..906dcba0083 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h index c3e8c04cdcd..2fd7102c5a6 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -18,10 +18,10 @@ limitations under the License. #include -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/status.h" diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 3caa4f58725..7526248baca 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -17,10 +17,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_utils.h" -#include "mlir/IR/AffineMap.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/xla/literal.h" namespace xla { diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.h b/tensorflow/compiler/mlir/xla/hlo_utils.h index 764c40ed93b..f372cbf69bb 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/hlo_utils.h @@ -18,9 +18,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_ -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc index 9056f532715..921c4f069ec 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h" -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project namespace mlir { namespace xla_hlo_client { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h index 541ab0ebe3f..405b1ffb12e 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h @@ -17,14 +17,14 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_CLIENT_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/DialectImplementation.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/SideEffects.h" // from @llvm-project namespace mlir { namespace xla_hlo_client { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 17d0b958084..abaad272acd 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -33,25 +33,25 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc" #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h index 1a864507253..02f36836f5e 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h @@ -19,18 +19,18 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/DialectImplementation.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Interfaces/InferTypeOpInterface.h" // TF:llvm-project -#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc index 0143e781549..18bae4dec7a 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project namespace mlir { namespace xla { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h index 84ea3a1e1a8..079169e9c5c 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc index 0fbe5915fe8..7fb0e1c0831 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc @@ -28,20 +28,20 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/OpImplementation.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc" namespace mlir { diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h index f9cb2284526..8a3f833c7f4 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h @@ -19,16 +19,16 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/OpDefinition.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project -#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Support/Functional.h" // from @llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index dfb9ec4e837..0bdf8eb7f27 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 232d1fa84e9..604db6060af 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -20,11 +20,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape.h" diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 550d151d968..670f34b4318 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -27,16 +27,16 @@ limitations under the License. #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 6f91213b31a..983d61a8af2 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_ -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 34c4c2221ca..d8b54c1acb9 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -26,8 +26,8 @@ limitations under the License. #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/Support/STLExtras.h" // TF:llvm-project -#include "mlir/TableGen/Operator.h" // TF:llvm-project +#include "mlir/Support/STLExtras.h" // from @llvm-project +#include "mlir/TableGen/Operator.h" // from @llvm-project using llvm::raw_ostream; using llvm::RecordKeeper; diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 4dad8c5a996..51edaaf53bd 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -16,18 +16,18 @@ limitations under the License. // This file implements logic for lowering HLO dialect to LHLO dialect. #include "absl/memory/memory.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index 72ea2e18ec0..3633b32f847 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -18,17 +18,17 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Block.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index f6358d6cde7..817dfb55ec9 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -24,20 +24,20 @@ limitations under the License. 
#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/Dialect/Traits.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Matchers.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 265466ef3a4..8d57599d397 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -28,19 +28,19 @@ limitations under the License. 
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/iterator_range.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 913fc678558..327040a087f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -22,16 +22,16 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "llvm/ADT/Optional.h" -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc" diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 0769e92b8ce..0e3c59e06cd 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -16,10 +16,10 @@ limitations under the License. // This file implements logic for lowering XLA dialect to Standard dialect. #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc index 86125126390..97341879759 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc @@ -16,9 +16,9 @@ limitations under the License. // This file implements a pass to remove redundant LHLO copy operations. 
#include "absl/memory/memory.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index a27a27b3760..fbe8d800306 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -19,9 +19,9 @@ limitations under the License. #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project +#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/transforms/passes.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index 43f0116ef0d..b01f573a9a5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -16,14 +16,14 @@ limitations under the License. // This file implements logic for lowering LHLO dialect to Affine dialect. #include "absl/memory/memory.h" -#include "mlir/Dialect/AffineOps/AffineOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/AffineOps/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index 537703302c0..0ea29393744 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -19,21 +19,21 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project +#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index 894e4d039b8..f2ae7227a23 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -17,12 +17,12 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project +#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc b/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc index 672398672de..479f865626b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc @@ -23,14 +23,14 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc index 2e901094348..28cbbf9f6e3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc @@ -17,15 +17,15 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/TypeUtilities.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index 40add223156..781e7136123 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -18,7 +18,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index 157029a04dc..237cac64ffd 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -15,11 +15,11 @@ limitations under the License. #include -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc index 644fffcc7ea..5622892c684 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index b1afd543c2e..4f1b6b081f8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -19,8 +19,8 @@ limitations under the License. #include #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h index cb7f93edc25..7656c89facb 100644 --- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h @@ -18,9 +18,9 @@ limitations under the License. #include -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { namespace xla_hlo { diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc index 2b785c4ba06..af2e58cb016 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc @@ -14,14 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc index ccec4d73b6e..caa51f27359 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 5ef3b445db4..d4f90ade2a2 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -17,20 +17,20 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "llvm/ADT/APInt.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/AffineExpr.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index d82b2d33779..8250976eb00 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -17,11 +17,11 @@ limitations under the License. #include -#include "mlir/IR/AffineMap.h" // TF:llvm-project -#include "mlir/IR/Diagnostics.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Support/DebugStringHelper.h" // TF:llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.h b/tensorflow/compiler/mlir/xla/type_to_shape.h index c9989def939..647fb56bb26 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.h +++ b/tensorflow/compiler/mlir/xla/type_to_shape.h @@ -17,7 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_ #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index 98f9b36c84b..b2a7cb85686 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index ada0f9da3b8..211470bd41e 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -17,8 +17,8 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/MemoryBuffer.h" -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Translation.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/xla/debug_options_flags.h" diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 01b1fed9cac..43404bc2267 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index ca34fb504bd..4ac992011f1 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "include/dlpack/dlpack.h" // TF:dlpack +#include "include/dlpack/dlpack.h" // from @dlpack #include "tensorflow/compiler/xla/python/shared_device_buffer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc index 3c27dc662fe..5704f140eba 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" #include "absl/strings/substitute.h" -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h index db702dbc014..9550914a26f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc index b3e4002a898..79e90b74208 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc @@ -30,13 +30,13 @@ limitations under the License. 
#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/AffineOps/AffineOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/AffineExpr.h" // TF:llvm-project -#include "mlir/IR/AffineMap.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/Transforms/LoopUtils.h" // TF:llvm-project -#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "mlir/Dialect/AffineOps/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Transforms/LoopUtils.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/window_util.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h index 5f01dffb756..e7d5fc4d276 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ -#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index c4ec4ea73ab..56684b1f726 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -18,14 +18,14 @@ limitations under the License. 
#include #include "llvm/Support/raw_ostream.h" -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // TF:llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index 0914e5ef820..1c2fc1962cf 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -16,10 +16,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h" #include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/comparison_util.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h index a1ec6d88644..20d2d1418ca 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h @@ -20,10 +20,10 @@ limitations under the License. 
#include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h index 1e0e41868ca..dd19fbe35cb 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_INJECT_ERRORS_PASS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_INJECT_ERRORS_PASS_H_ -#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 748306561d4..6f7ff5461d2 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -18,32 +18,32 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // TF:llvm-project -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // TF:llvm-project -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // TF:llvm-project -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // TF:llvm-project -#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // TF:llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:llvm-project -#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project -#include "mlir/Dialect/GPU/Passes.h" // TF:llvm-project -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:llvm-project -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // TF:llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project -#include "mlir/Dialect/Linalg/Passes.h" // TF:llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/PatternMatch.h" // TF:llvm-project -#include "mlir/IR/Region.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project -#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from 
@llvm-project +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // from @llvm-project +#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project +#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h index 027c3c93dca..8a8882cab30 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_KERNEL_LOWERING_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_KERNEL_LOWERING_H_ -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 3d4e1078ca2..3f17694af1d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -15,14 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Identifier.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h index ee0dbd6f320..f39d20efe2f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -19,10 +19,10 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index e471ba192e1..dc33be5341c 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -18,20 +18,20 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // TF:llvm-project -#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/OperationSupport.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Target/NVVMIR.h" // TF:llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Target/NVVMIR.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h index bb852b47f22..9aeef12ac28 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ #include "absl/container/flat_hash_map.h" -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc index fa2167a4bd9..c8e01b967e7 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc @@ -22,8 +22,8 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" #include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h" diff --git a/tensorflow/core/kernels/string_lower_op.cc b/tensorflow/core/kernels/string_lower_op.cc index 07065d2777e..b262ed9dc56 100644 --- a/tensorflow/core/kernels/string_lower_op.cc +++ b/tensorflow/core/kernels/string_lower_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/strings/ascii.h" -#include "unicode/unistr.h" // TF:icu +#include "unicode/unistr.h" // from @icu #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/string_upper_op.cc b/tensorflow/core/kernels/string_upper_op.cc index d9f088a7b78..baf5aa582d6 100644 --- a/tensorflow/core/kernels/string_upper_op.cc +++ b/tensorflow/core/kernels/string_upper_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/strings/ascii.h" -#include "unicode/unistr.h" // TF:icu +#include "unicode/unistr.h" // from @icu #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/unicode_ops.cc b/tensorflow/core/kernels/unicode_ops.cc index 14ae49c837f..d3a7ad7b286 100644 --- a/tensorflow/core/kernels/unicode_ops.cc +++ b/tensorflow/core/kernels/unicode_ops.cc @@ -22,17 +22,17 @@ limitations under the License. #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "unicode/appendable.h" // TF:icu -#include "unicode/schriter.h" // TF:icu -#include "unicode/uchar.h" // TF:icu -#include "unicode/ucnv.h" // TF:icu -#include "unicode/ucnv_err.h" // TF:icu -#include "unicode/umachine.h" // TF:icu -#include "unicode/uniset.h" // TF:icu -#include "unicode/unistr.h" // TF:icu -#include "unicode/uset.h" // TF:icu -#include "unicode/utf.h" // TF:icu -#include "unicode/utypes.h" // TF:icu +#include "unicode/appendable.h" // from @icu +#include "unicode/schriter.h" // from @icu +#include "unicode/uchar.h" // from @icu +#include "unicode/ucnv.h" // from @icu +#include "unicode/ucnv_err.h" // from @icu +#include "unicode/umachine.h" // from @icu +#include "unicode/uniset.h" // from @icu +#include "unicode/unistr.h" // from @icu +#include "unicode/uset.h" // from @icu +#include "unicode/utf.h" // from @icu +#include "unicode/utypes.h" // from @icu #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/core/kernels/unicode_script_op.cc b/tensorflow/core/kernels/unicode_script_op.cc index 085e397eba5..70ab6ef39bf 100644 --- a/tensorflow/core/kernels/unicode_script_op.cc +++ b/tensorflow/core/kernels/unicode_script_op.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "unicode/errorcode.h" // TF:icu -#include "unicode/uscript.h" // TF:icu +#include "unicode/errorcode.h" // from @icu +#include "unicode/uscript.h" // from @icu #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/port.cc b/tensorflow/core/platform/default/port.cc index 756e7e8a93a..b3a4bbbecbd 100644 --- a/tensorflow/core/platform/default/port.cc +++ b/tensorflow/core/platform/default/port.cc @@ -46,7 +46,7 @@ limitations under the License. #endif #if TENSORFLOW_USE_NUMA -#include "hwloc.h" // TF:hwloc +#include "hwloc.h" // from @hwloc #endif namespace tensorflow { diff --git a/tensorflow/core/platform/default/strong_hash.h b/tensorflow/core/platform/default/strong_hash.h index 8c25bf6c79b..f04f1b7b6ae 100644 --- a/tensorflow/core/platform/default/strong_hash.h +++ b/tensorflow/core/platform/default/strong_hash.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_ #define TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_ -#include "highwayhash/sip_hash.h" // TF:highwayhash -#include "highwayhash/state_helpers.h" // TF:highwayhash +#include "highwayhash/sip_hash.h" // from @highwayhash +#include "highwayhash/state_helpers.h" // from @highwayhash namespace tensorflow { diff --git a/tensorflow/core/platform/gif.h b/tensorflow/core/platform/gif.h index 2e69055c365..79af3822d29 100644 --- a/tensorflow/core/platform/gif.h +++ b/tensorflow/core/platform/gif.h @@ -16,6 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_GIF_H_ #define TENSORFLOW_CORE_PLATFORM_GIF_H_ -#include "gif_lib.h" // TF:gif +#include "gif_lib.h" // from @gif #endif // TENSORFLOW_CORE_PLATFORM_GIF_H_ diff --git a/tensorflow/core/platform/jpeg.h b/tensorflow/core/platform/jpeg.h index b6c7d9692cd..8b67e52165b 100644 --- a/tensorflow/core/platform/jpeg.h +++ b/tensorflow/core/platform/jpeg.h @@ -22,8 +22,8 @@ limitations under the License. #include extern "C" { -#include "jerror.h" // TF:libjpeg_turbo -#include "jpeglib.h" // TF:libjpeg_turbo +#include "jerror.h" // from @libjpeg_turbo +#include "jpeglib.h" // from @libjpeg_turbo } #endif // TENSORFLOW_CORE_PLATFORM_JPEG_H_ diff --git a/tensorflow/core/platform/png.h b/tensorflow/core/platform/png.h index 43fb3927c48..401e62f5da5 100644 --- a/tensorflow/core/platform/png.h +++ b/tensorflow/core/platform/png.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/platform.h" #if defined(PLATFORM_GOOGLE) && !defined(IS_MOBILE_PLATFORM) -#include "png.h" // TF:png +#include "png.h" // from @png #elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ defined(PLATFORM_POSIX_ANDROID) || defined(IS_MOBILE_PLATFORM) #include diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index 853087101d0..9a664b28246 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/flex/kernel.h" -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" diff --git a/tensorflow/lite/delegates/flex/test_util.cc b/tensorflow/lite/delegates/flex/test_util.cc index a686a54d097..8c0e40b58dd 100644 --- a/tensorflow/lite/delegates/flex/test_util.cc +++ b/tensorflow/lite/delegates/flex/test_util.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/lite/delegates/flex/test_util.h" #include "absl/memory/memory.h" -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/string_type.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/program_cache.cc b/tensorflow/lite/delegates/gpu/cl/program_cache.cc index 00ab1b791c4..e6735b448de 100644 --- a/tensorflow/lite/delegates/gpu/cl/program_cache.cc +++ b/tensorflow/lite/delegates/gpu/cl/program_cache.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/gpu/cl/cl_program.h" #include "tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" diff --git a/tensorflow/lite/delegates/gpu/gl/serialization.h b/tensorflow/lite/delegates/gpu/gl/serialization.h index 96c0a0b1073..c3c88b4c462 100644 --- a/tensorflow/lite/delegates/gpu/gl/serialization.h +++ b/tensorflow/lite/delegates/gpu/gl/serialization.h @@ -22,7 +22,7 @@ limitations under the License. #include #include "absl/types/span.h" -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiled_model_generated.h" diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc index b258f2c4424..7976fd54ed0 100644 --- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/metadata_generated.h" diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc index df22efdffa9..16aaafa5c94 100644 --- a/tensorflow/lite/delegates/gpu/gl_delegate.cc +++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc @@ -48,7 +48,7 @@ limitations under the License. 
#include "tensorflow/lite/minimal_logging.h" #ifndef TFLITE_GPU_BINARY_RELEASE -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/gpu/gl/metadata_generated.h" #include "tensorflow/lite/schema/schema_generated.h" #endif // TFLITE_GPU_BINARY_RELEASE diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc index bd17dff8192..95a358d1b9c 100644 --- a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc index 3fb520466e0..18cf55eb91c 100644 --- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" diff --git a/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder.cc index 101900f2682..c5e019fc2ee 100644 --- a/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder.cc +++ b/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/experimental/kernels/ctc_beam_search.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder_test.cc index 572b56f1225..b173f1a086c 100644 --- a/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder_test.cc +++ b/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/test_util.h" diff --git a/tensorflow/lite/experimental/kernels/hashtable.cc b/tensorflow/lite/experimental/kernels/hashtable.cc index 664262b4d5c..d1f9551ddf0 100644 --- a/tensorflow/lite/experimental/kernels/hashtable.cc +++ b/tensorflow/lite/experimental/kernels/hashtable.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/flatbuffer_conversions.h" diff --git a/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc b/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc index 4e4100b3734..797b7b36b27 100644 --- a/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc +++ b/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/memory/memory.h" #include "absl/strings/match.h" -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/lite/experimental/resource/lookup_interfaces.h" #include "tensorflow/lite/interpreter.h" diff --git a/tensorflow/lite/experimental/microfrontend/audio_microfrontend.cc b/tensorflow/lite/experimental/microfrontend/audio_microfrontend.cc index 84ab164d2c0..b8d89f69b26 100644 --- a/tensorflow/lite/experimental/microfrontend/audio_microfrontend.cc +++ b/tensorflow/lite/experimental/microfrontend/audio_microfrontend.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/context.h" #include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" #include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" diff --git a/tensorflow/lite/experimental/microfrontend/audio_microfrontend_test.cc b/tensorflow/lite/experimental/microfrontend/audio_microfrontend_test.cc index a03fabb2348..40f36faa218 100644 --- a/tensorflow/lite/experimental/microfrontend/audio_microfrontend_test.cc +++ b/tensorflow/lite/experimental/microfrontend/audio_microfrontend_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/model.h" diff --git a/tensorflow/lite/experimental/support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/tensorflow/lite/experimental/support/metadata/flatbuffers_lib/flatbuffers_lib.cc index 2b1402daf7e..3f5d4e221c7 100644 --- a/tensorflow/lite/experimental/support/metadata/flatbuffers_lib/flatbuffers_lib.cc +++ b/tensorflow/lite/experimental/support/metadata/flatbuffers_lib/flatbuffers_lib.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers -#include "flatbuffers/idl.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/idl.h" // from @flatbuffers #include "include/pybind11/pybind11.h" #include "include/pybind11/pytypes.h" #include "include/pybind11/stl.h" diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc index 8161860afb1..6c5315cdbda 100644 --- a/tensorflow/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include #include -#include "flatbuffers/minireflect.h" // TF:flatbuffers +#include "flatbuffers/minireflect.h" // from @flatbuffers #include "tensorflow/lite/schema/reflection/schema_generated.h" namespace tflite { diff --git a/tensorflow/lite/kernels/audio_spectrogram.cc b/tensorflow/lite/kernels/audio_spectrogram.cc index f217b140b65..99457ea11b1 100644 --- a/tensorflow/lite/kernels/audio_spectrogram.cc +++ b/tensorflow/lite/kernels/audio_spectrogram.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/lite/kernels/audio_spectrogram_test.cc b/tensorflow/lite/kernels/audio_spectrogram_test.cc index d19877e17a5..0f4182ea728 100644 --- a/tensorflow/lite/kernels/audio_spectrogram_test.cc +++ b/tensorflow/lite/kernels/audio_spectrogram_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/test_util.h" diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc index 22dbf79914a..5d848bc9eab 100644 --- a/tensorflow/lite/kernels/detection_postprocess.cc +++ b/tensorflow/lite/kernels/detection_postprocess.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/lite/kernels/detection_postprocess_test.cc b/tensorflow/lite/kernels/detection_postprocess_test.cc index 3d3cb2afb10..348ea45a515 100644 --- a/tensorflow/lite/kernels/detection_postprocess_test.cc +++ b/tensorflow/lite/kernels/detection_postprocess_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/test_util.h" diff --git a/tensorflow/lite/kernels/if_test.cc b/tensorflow/lite/kernels/if_test.cc index cd7372c978a..c81300e5d1d 100644 --- a/tensorflow/lite/kernels/if_test.cc +++ b/tensorflow/lite/kernels/if_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/register.h" diff --git a/tensorflow/lite/kernels/mfcc.cc b/tensorflow/lite/kernels/mfcc.cc index 477de88e749..cba7cb132eb 100644 --- a/tensorflow/lite/kernels/mfcc.cc +++ b/tensorflow/lite/kernels/mfcc.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/kernels/internal/mfcc.h" -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/mfcc_dct.h" diff --git a/tensorflow/lite/kernels/mfcc_test.cc b/tensorflow/lite/kernels/mfcc_test.cc index 99dcc3c8a72..a6b769ccd37 100644 --- a/tensorflow/lite/kernels/mfcc_test.cc +++ b/tensorflow/lite/kernels/mfcc_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/test_util.h" diff --git a/tensorflow/lite/kernels/non_max_suppression.cc b/tensorflow/lite/kernels/non_max_suppression.cc index a21c4f62a57..ee8e407066d 100644 --- a/tensorflow/lite/kernels/non_max_suppression.cc +++ b/tensorflow/lite/kernels/non_max_suppression.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/tensor.h" diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc index d767dd6e008..00f947a9e38 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.cc +++ b/tensorflow/lite/kernels/subgraph_test_util.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/lite/kernels/subgraph_test_util.h" -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/kernels/while_test.cc b/tensorflow/lite/kernels/while_test.cc index dc69e496533..324519e32a0 100644 --- a/tensorflow/lite/kernels/while_test.cc +++ b/tensorflow/lite/kernels/while_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/register.h" diff --git a/tensorflow/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/lite/schema/flatbuffer_compatibility_test.cc index 86177aeb127..8f88b6204a1 100644 --- a/tensorflow/lite/schema/flatbuffer_compatibility_test.cc +++ b/tensorflow/lite/schema/flatbuffer_compatibility_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include #include -#include "flatbuffers/flatc.h" // TF:flatbuffers +#include "flatbuffers/flatc.h" // from @flatbuffers #include "tensorflow/core/platform/platform.h" #ifdef PLATFORM_GOOGLE diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index c3f378f2d78..19b77543c66 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator.h b/tensorflow/lite/tools/optimize/calibration/calibrator.h index f2e85e1420a..ef7cea528d9 100644 --- a/tensorflow/lite/tools/optimize/calibration/calibrator.h +++ b/tensorflow/lite/tools/optimize/calibration/calibrator.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h" diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index da1b293c84b..f8f1a9d4113 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc index c35259ef437..76f2815ef0b 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -18,8 +18,8 @@ limitations under the License. 
#include #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/lite/tools/versioning/runtime_version.h b/tensorflow/lite/tools/versioning/runtime_version.h index f4889172746..e4c25221310 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.h +++ b/tensorflow/lite/tools/versioning/runtime_version.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers namespace tflite { // Update minimum runtime version of the given TFL flatbuffer model. From ac4bcb72fae5b6485203131c22b8af4ad0bd6456 Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Mon, 23 Mar 2020 10:44:05 -0700 Subject: [PATCH 420/492] Small Softmax cleanups: - Remove OpData. Use SoftmaxParams directly. - Only call CalculateSoftmaxOpData for quantized case, rename to CalculateQuantizedSoftmaxParams. - Add stricter type checks to CalculateQuantizedSoftmaxParams. PiperOrigin-RevId: 302467600 Change-Id: I0746b63c562484adf8077a173e41503271763888 --- tensorflow/lite/micro/kernels/softmax.cc | 95 ++++++++++++++---------- 1 file changed, 55 insertions(+), 40 deletions(-) diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc index 1f30ddc5949..85952de9d50 100644 --- a/tensorflow/lite/micro/kernels/softmax.cc +++ b/tensorflow/lite/micro/kernels/softmax.cc @@ -29,37 +29,41 @@ namespace micro { namespace activations { namespace { -TfLiteStatus CalculateQuantizedSoftmaxParams(TfLiteContext* context, - const TfLiteTensor* input, - TfLiteTensor* output, - const TfLiteSoftmaxParams* params, - SoftmaxParams* data) { - if (input->type == kTfLiteUInt8) { - TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8); - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); - } else { - TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8); - if (output->type == kTfLiteInt16) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768); - // NOTE: Current int16 softmax output does not require symmetric scaling - // - so no need to verify scale here. +struct OpData { + int32_t input_multiplier = 0; + int input_left_shift = 0; + int32_t input_range_radius = 0; + int diff_min = 0; +}; + +TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, + const TfLiteTensor* input, + TfLiteTensor* output, + const TfLiteSoftmaxParams* params, + OpData* data) { + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); } else { - TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); - TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); - TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); + if (output->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768); + // NOTE: Current int16 softmax output does not require symmetric scaling + // - so no need to verify scale here. 
+ } else { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); + TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); + } } + + static const int kScaledDiffIntegerBits = 5; + + tflite::PreprocessSoftmaxScaling( + static_cast(params->beta), + static_cast(input->params.scale), kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift); + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); } - - static const int kScaledDiffIntegerBits = 5; - - int input_left_shift; - tflite::PreprocessSoftmaxScaling(static_cast(params->beta), - static_cast(input->params.scale), - kScaledDiffIntegerBits, - &data->input_multiplier, &input_left_shift); - data->input_left_shift = input_left_shift; - data->diff_min = -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, - data->input_left_shift); return kTfLiteOk; } @@ -93,8 +97,7 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, } void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, - const SoftmaxParams& op_params) { + TfLiteSoftmaxParams* params, OpData* data) { // TODO(ahentz): this is arguably a dirty trick. Since the implementation // always traverses the last dimension of a 4D tensor, we will pretend our 1D // tensor is 4D in a special way. We will convert a (Y) shape into a (1, @@ -102,6 +105,10 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, const int input_size = input->dims->data[0]; const int32_t shape_data[4] = {1, 1, 1, input_size}; RuntimeShape shape(4, shape_data); + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax(op_params, shape, GetTensorData(input), shape, @@ -120,8 +127,7 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, } void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, - const SoftmaxParams& op_params) { + TfLiteSoftmaxParams* params, OpData* data) { // TODO(ahentz): this is arguably a dirty trick. Since the implementation // always traverses the last dimension of a 4D tensor, we will pretend our 2D // tensor is 4D in a special way. 
We will convert a (X, Y) shape into a (X, @@ -130,6 +136,10 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, const int input_size = input->dims->data[1]; const int32_t shape_data[4] = {batch_size, 1, 1, input_size}; RuntimeShape shape(4, shape_data); + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax(op_params, shape, GetTensorData(input), shape, @@ -158,8 +168,11 @@ void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, } void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, - const SoftmaxParams& op_params) { + TfLiteSoftmaxParams* params, OpData* data) { + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax( op_params, GetTensorShape(input), GetTensorData(input), @@ -183,6 +196,11 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS( + CalculateSoftmaxOpData(context, input, output, params, data)); + // TODO(ahentz): consider an implementation that works for many (all?) // dimensions. switch (input->type) { @@ -206,19 +224,16 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } case kTfLiteInt8: case kTfLiteUInt8: { - SoftmaxParams op_params; - TF_LITE_ENSURE_STATUS(CalculateQuantizedSoftmaxParams( - context, input, output, params, &op_params)); if (NumDimensions(input) == 1) { - Softmax1DQuantized(input, output, params, op_params); + Softmax1DQuantized(input, output, params, data); return kTfLiteOk; } if (NumDimensions(input) == 2) { - Softmax2DQuantized(input, output, params, op_params); + Softmax2DQuantized(input, output, params, data); return kTfLiteOk; } if (NumDimensions(input) == 4) { - Softmax4DQuantized(input, output, params, op_params); + Softmax4DQuantized(input, output, params, data); return kTfLiteOk; } TF_LITE_KERNEL_LOG( From d9d7827dd29118c85ce53bd11bc17d9bce2f13c9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 10:44:39 -0700 Subject: [PATCH 421/492] Use -O3 for cuda compiles in opt, as significant cuda optimizations in cuda-clang only get enabled at -O3. PiperOrigin-RevId: 302467718 Change-Id: I7f6448c83762a9fb38d0ce47dd9b9850f89ba22d --- third_party/gpus/cuda/build_defs.bzl.tpl | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index 845866b86f2..3280d6b041f 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -24,9 +24,28 @@ def if_cuda_clang(if_true, if_false = []): "//conditions:default": if_false }) +def if_cuda_clang_opt(if_true, if_false = []): + """Shorthand for select()'ing on wheteher we're building with cuda-clang + in opt mode. + + Returns a select statement which evaluates to if_true if we're building + with cuda-clang in opt mode. Otherwise, the select statement evaluates to + if_false. 
+ + """ + return select({ + "@local_config_cuda//cuda:using_clang_opt": if_true, + "//conditions:default": if_false + }) + def cuda_default_copts(): """Default options for all CUDA compilations.""" - return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"]) + %{cuda_extra_copts} + return if_cuda( + ["-x", "cuda", "-DGOOGLE_CUDA=1"] + ) + if_cuda_clang_opt( + # Some important CUDA optimizations are only enabled at O3. + ["-O3"] + ) + %{cuda_extra_copts} def cuda_is_configured(): """Returns true if CUDA was enabled during the configure process.""" From 06ca7fc73ca96d9a468d771273be377ed4cc6ed2 Mon Sep 17 00:00:00 2001 From: Marat Dukhan Date: Mon, 23 Mar 2020 10:53:56 -0700 Subject: [PATCH 422/492] Update XNNPACK and cpuinfo dependencies PiperOrigin-RevId: 302469932 Change-Id: I3b40f89c61654431f30366387f858d2511b0c30a --- tensorflow/workspace.bzl | 8 +- third_party/cpuinfo/cpuinfo.patch | 3016 ----------------------------- third_party/cpuinfo/workspace.bzl | 14 +- 3 files changed, 8 insertions(+), 3030 deletions(-) delete mode 100644 third_party/cpuinfo/cpuinfo.patch diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 6f1feead83e..d89b96d3fc8 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -148,11 +148,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "XNNPACK", - sha256 = "190e61e50af3497bb46b8d936bd2d2d551a9aeedb02ff66388918408a54e216a", - strip_prefix = "XNNPACK-b18783570f0643560be641b193367d3906955141", + sha256 = "77a4cea07169b4d67df456d50deffaa100e587192657c68ee4f2b7c12ba133d1", + strip_prefix = "XNNPACK-479e78c7f93a5764ffb221bdead3f290c7fd8ea3", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/b18783570f0643560be641b193367d3906955141.zip", - "https://github.com/google/XNNPACK/archive/b18783570f0643560be641b193367d3906955141.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/479e78c7f93a5764ffb221bdead3f290c7fd8ea3.zip", + "https://github.com/google/XNNPACK/archive/479e78c7f93a5764ffb221bdead3f290c7fd8ea3.zip", ], ) diff --git a/third_party/cpuinfo/cpuinfo.patch b/third_party/cpuinfo/cpuinfo.patch deleted file mode 100644 index a9fa0dde0eb..00000000000 --- a/third_party/cpuinfo/cpuinfo.patch +++ /dev/null @@ -1,3016 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index de319ef..fefb60b 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -179,7 +179,6 @@ IF(CPUINFO_SUPPORTED_PLATFORM) - LIST(APPEND CPUINFO_SRCS - src/linux/smallfile.c - src/linux/multiline.c -- src/linux/current.c - src/linux/cpulist.c - src/linux/processors.c) - ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") -diff --git a/CMakeLists.txt.orig b/CMakeLists.txt.orig -deleted file mode 100644 -index a71aede..0000000 ---- a/CMakeLists.txt.orig -+++ /dev/null -@@ -1,819 +0,0 @@ --CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) -- --INCLUDE(GNUInstallDirs) -- --# ---[ Project and semantic versioning. --PROJECT(cpuinfo C CXX) -- --# ---[ Options. 
--SET(CPUINFO_LIBRARY_TYPE "default" CACHE STRING "Type of cpuinfo library (shared, static, or default) to build") --SET_PROPERTY(CACHE CPUINFO_LIBRARY_TYPE PROPERTY STRINGS default static shared) --SET(CPUINFO_RUNTIME_TYPE "default" CACHE STRING "Type of runtime library (shared, static, or default) to use") --SET_PROPERTY(CACHE CPUINFO_RUNTIME_TYPE PROPERTY STRINGS default static shared) --SET(CPUINFO_LOG_LEVEL "default" CACHE STRING "Minimum logging level (info with lower severity will be ignored)") --SET_PROPERTY(CACHE CPUINFO_LOG_LEVEL PROPERTY STRINGS default debug info warning error fatal none) --OPTION(CPUINFO_BUILD_TOOLS "Build command-line tools" ON) --OPTION(CPUINFO_BUILD_UNIT_TESTS "Build cpuinfo unit tests" ON) --OPTION(CPUINFO_BUILD_MOCK_TESTS "Build cpuinfo mock tests" ON) --OPTION(CPUINFO_BUILD_BENCHMARKS "Build cpuinfo micro-benchmarks" ON) -- --# ---[ CMake options --IF(CPUINFO_BUILD_UNIT_TESTS OR CPUINFO_BUILD_MOCK_TESTS) -- ENABLE_TESTING() --ENDIF() -- --MACRO(CPUINFO_TARGET_ENABLE_C99 target) -- IF(${CMAKE_VERSION} VERSION_LESS "3.1") -- IF(NOT MSVC) -- TARGET_COMPILE_OPTIONS(${target} PRIVATE -std=c99) -- ENDIF() -- ELSE() -- SET_TARGET_PROPERTIES(${target} PROPERTIES -- C_STANDARD 99 -- C_EXTENSIONS NO) -- ENDIF() --ENDMACRO() -- --MACRO(CPUINFO_TARGET_ENABLE_CXX11 target) -- IF(${CMAKE_VERSION} VERSION_LESS "3.1") -- IF(NOT MSVC) -- TARGET_COMPILE_OPTIONS(${target} PRIVATE -std=c++11) -- ENDIF() -- ELSE() -- SET_TARGET_PROPERTIES(${target} PROPERTIES -- CXX_STANDARD 11 -- CXX_EXTENSIONS NO) -- ENDIF() --ENDMACRO() -- --MACRO(CPUINFO_TARGET_RUNTIME_LIBRARY target) -- IF(MSVC AND NOT CPUINFO_RUNTIME_TYPE STREQUAL "default") -- IF(CPUINFO_RUNTIME_TYPE STREQUAL "shared") -- TARGET_COMPILE_OPTIONS(${target} PRIVATE -- "/MD$<$:d>") -- ELSEIF(CPUINFO_RUNTIME_TYPE STREQUAL "static") -- TARGET_COMPILE_OPTIONS(${target} PRIVATE -- "/MT$<$:d>") -- ENDIF() -- ENDIF() --ENDMACRO() -- --# ---[ Build flags --SET(CPUINFO_SUPPORTED_PLATFORM TRUE) --IF(NOT CMAKE_SYSTEM_PROCESSOR) -- IF(NOT IOS) -- MESSAGE(WARNING -- "Target processor architecture is not specified. " -- "cpuinfo will compile, but cpuinfo_initialize() will always fail.") -- SET(CPUINFO_SUPPORTED_PLATFORM FALSE) -- ENDIF() --ELSEIF(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64)$") -- MESSAGE(WARNING -- "Target processor architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in cpuinfo. " -- "cpuinfo will compile, but cpuinfo_initialize() will always fail.") -- SET(CPUINFO_SUPPORTED_PLATFORM FALSE) --ENDIF() -- --IF(NOT CMAKE_SYSTEM_NAME) -- MESSAGE(WARNING -- "Target operating system is not specified. " -- "cpuinfo will compile, but cpuinfo_initialize() will always fail.") -- SET(CPUINFO_SUPPORTED_PLATFORM FALSE) --ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Windows|Darwin|Linux|Android)$") -- IF(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.14" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") -- MESSAGE(WARNING -- "Target operating system \"${CMAKE_SYSTEM_NAME}\" is not supported in cpuinfo. 
" -- "cpuinfo will compile, but cpuinfo_initialize() will always fail.") -- SET(CPUINFO_SUPPORTED_PLATFORM FALSE) -- ENDIF() --ENDIF() -- --# ---[ Download deps --SET(CONFU_DEPENDENCIES_SOURCE_DIR ${CMAKE_SOURCE_DIR}/deps -- CACHE PATH "Confu-style dependencies source directory") --SET(CONFU_DEPENDENCIES_BINARY_DIR ${CMAKE_BINARY_DIR}/deps -- CACHE PATH "Confu-style dependencies binary directory") -- --IF(CPUINFO_BUILD_MOCK_TESTS OR CPUINFO_BUILD_UNIT_TESTS) -- IF(CPUINFO_SUPPORTED_PLATFORM AND NOT DEFINED GOOGLETEST_SOURCE_DIR) -- MESSAGE(STATUS "Downloading Google Test to ${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest (define GOOGLETEST_SOURCE_DIR to avoid it)") -- CONFIGURE_FILE(cmake/DownloadGoogleTest.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download/CMakeLists.txt") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . -- WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . -- WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download") -- SET(GOOGLETEST_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest" CACHE STRING "Google Test source directory") -- ENDIF() --ENDIF() -- --IF(CPUINFO_BUILD_BENCHMARKS) -- IF(CPUINFO_SUPPORTED_PLATFORM AND NOT DEFINED GOOGLEBENCHMARK_SOURCE_DIR) -- MESSAGE(STATUS "Downloading Google Benchmark to ${CONFU_DEPENDENCIES_SOURCE_DIR}/googlebenchmark (define GOOGLEBENCHMARK_SOURCE_DIR to avoid it)") -- CONFIGURE_FILE(cmake/DownloadGoogleBenchmark.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download/CMakeLists.txt") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . -- WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . 
-- WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download") -- SET(GOOGLEBENCHMARK_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googlebenchmark" CACHE STRING "Google Benchmark source directory") -- ENDIF() --ENDIF() -- --# ---[ cpuinfo library --SET(CPUINFO_SRCS -- src/init.c -- src/api.c) -- --IF(CPUINFO_SUPPORTED_PLATFORM) -- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?)$" OR IOS_ARCH MATCHES "^(i386|x86_64)$") -- LIST(APPEND CPUINFO_SRCS -- src/x86/init.c -- src/x86/info.c -- src/x86/vendor.c -- src/x86/uarch.c -- src/x86/name.c -- src/x86/topology.c -- src/x86/isa.c -- src/x86/cache/init.c -- src/x86/cache/descriptor.c -- src/x86/cache/deterministic.c) -- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") -- LIST(APPEND CPUINFO_SRCS -- src/x86/linux/init.c -- src/x86/linux/cpuinfo.c) -- ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") -- LIST(APPEND CPUINFO_SRCS src/x86/mach/init.c) -- ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Windows") -- LIST(APPEND CPUINFO_SRCS src/x86/windows/init.c) -- ENDIF() -- ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$" OR IOS_ARCH MATCHES "^(armv7.*|arm64.*)$") -- LIST(APPEND CPUINFO_SRCS -- src/arm/uarch.c -- src/arm/cache.c) -- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") -- LIST(APPEND CPUINFO_SRCS -- src/arm/linux/init.c -- src/arm/linux/cpuinfo.c -- src/arm/linux/clusters.c -- src/arm/linux/chipset.c -- src/arm/linux/midr.c -- src/arm/linux/hwcap.c) -- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") -- LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch32-isa.c) -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND ANDROID_ABI STREQUAL "armeabi") -- SET_SOURCE_FILES_PROPERTIES(src/arm/linux/aarch32-isa.c PROPERTIES COMPILE_FLAGS -marm) -- ENDIF() -- ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") -- LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch64-isa.c) -- ENDIF() -- ELSEIF(IOS) -- LIST(APPEND CPUINFO_SRCS src/arm/mach/init.c) -- ENDIF() -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android") -- LIST(APPEND CPUINFO_SRCS -- src/arm/android/properties.c) -- ENDIF() -- ENDIF() -- -- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") -- LIST(APPEND CPUINFO_SRCS -- src/linux/smallfile.c -- src/linux/multiline.c -- src/linux/current.c -- src/linux/cpulist.c -- src/linux/processors.c) -- ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") -- LIST(APPEND CPUINFO_SRCS src/mach/topology.c) -- ENDIF() -- -- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") -- SET(CMAKE_THREAD_PREFER_PTHREAD TRUE) -- SET(THREADS_PREFER_PTHREAD_FLAG TRUE) -- FIND_PACKAGE(Threads REQUIRED) -- ENDIF() --ENDIF() -- --IF(CPUINFO_LIBRARY_TYPE STREQUAL "default") -- ADD_LIBRARY(cpuinfo ${CPUINFO_SRCS}) --ELSEIF(CPUINFO_LIBRARY_TYPE STREQUAL "shared") -- ADD_LIBRARY(cpuinfo SHARED ${CPUINFO_SRCS}) --ELSEIF(CPUINFO_LIBRARY_TYPE STREQUAL "static") -- ADD_LIBRARY(cpuinfo STATIC ${CPUINFO_SRCS}) --ELSE() -- MESSAGE(FATAL_ERROR "Unsupported library type ${CPUINFO_LIBRARY_TYPE}") --ENDIF() --ADD_LIBRARY(cpuinfo_internals STATIC ${CPUINFO_SRCS}) --CPUINFO_TARGET_ENABLE_C99(cpuinfo) --CPUINFO_TARGET_ENABLE_C99(cpuinfo_internals) --CPUINFO_TARGET_RUNTIME_LIBRARY(cpuinfo) --SET_TARGET_PROPERTIES(cpuinfo PROPERTIES PUBLIC_HEADER include/cpuinfo.h) --TARGET_INCLUDE_DIRECTORIES(cpuinfo BEFORE PUBLIC include) --TARGET_INCLUDE_DIRECTORIES(cpuinfo BEFORE PRIVATE src) 
--TARGET_INCLUDE_DIRECTORIES(cpuinfo_internals BEFORE PUBLIC include src) --IF(CPUINFO_LOG_LEVEL STREQUAL "default") -- # default logging level: error (subject to change) -- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=2) --ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "debug") -- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=5) --ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "info") -- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=4) --ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "warning") -- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=3) --ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "error") -- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=2) --ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "fatal") -- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=1) --ELSEIF(CPUINFO_LOG_LEVEL STREQUAL "none") -- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE CPUINFO_LOG_LEVEL=0) --ELSE() -- MESSAGE(FATAL_ERROR "Unsupported logging level ${CPUINFO_LOG_LEVEL}") --ENDIF() --TARGET_COMPILE_DEFINITIONS(cpuinfo_internals PRIVATE CPUINFO_LOG_LEVEL=0) -- --IF(CPUINFO_SUPPORTED_PLATFORM) -- TARGET_COMPILE_DEFINITIONS(cpuinfo INTERFACE CPUINFO_SUPPORTED_PLATFORM=1) -- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") -- TARGET_LINK_LIBRARIES(cpuinfo PUBLIC ${CMAKE_THREAD_LIBS_INIT}) -- TARGET_LINK_LIBRARIES(cpuinfo_internals PUBLIC ${CMAKE_THREAD_LIBS_INIT}) -- TARGET_COMPILE_DEFINITIONS(cpuinfo PRIVATE _GNU_SOURCE=1) -- TARGET_COMPILE_DEFINITIONS(cpuinfo_internals PRIVATE _GNU_SOURCE=1) -- ENDIF() --ELSE() -- TARGET_COMPILE_DEFINITIONS(cpuinfo INTERFACE CPUINFO_SUPPORTED_PLATFORM=0) --ENDIF() -- --# ---[ cpuinfo dependencies: clog --IF(NOT DEFINED CLOG_SOURCE_DIR) -- SET(CLOG_SOURCE_DIR "${PROJECT_SOURCE_DIR}/deps/clog") --ENDIF() --IF(NOT TARGET clog) -- SET(CLOG_BUILD_TESTS OFF CACHE BOOL "") -- SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "") -- ADD_SUBDIRECTORY( -- "${CLOG_SOURCE_DIR}") -- # We build static version of clog but a dynamic library may indirectly depend on it -- SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON) --ENDIF() --TARGET_LINK_LIBRARIES(cpuinfo PRIVATE clog) --TARGET_LINK_LIBRARIES(cpuinfo_internals PRIVATE clog) -- --INSTALL(TARGETS cpuinfo -- LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} -- ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} -- PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) -- --# ---[ cpuinfo micro-benchmarks --IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_BENCHMARKS) -- # ---[ Build google benchmark -- IF(NOT TARGET benchmark) -- SET(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "") -- ADD_SUBDIRECTORY( -- "${GOOGLEBENCHMARK_SOURCE_DIR}" -- "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark") -- ENDIF() -- -- IF(CMAKE_SYSTEM_NAME MATCHES "^(Linux|Android)$") -- ADD_EXECUTABLE(get-current-bench bench/get-current.cc) -- TARGET_LINK_LIBRARIES(get-current-bench cpuinfo benchmark) -- ENDIF() -- -- ADD_EXECUTABLE(init-bench bench/init.cc) -- TARGET_LINK_LIBRARIES(init-bench cpuinfo benchmark) --ENDIF() -- --IF(CPUINFO_SUPPORTED_PLATFORM) -- IF(CPUINFO_BUILD_MOCK_TESTS OR CPUINFO_BUILD_UNIT_TESTS) -- # ---[ Build google test -- IF(NOT TARGET gtest) -- IF(MSVC AND NOT CPUINFO_RUNTIME_TYPE STREQUAL "static") -- SET(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -- ENDIF() -- ADD_SUBDIRECTORY( -- "${GOOGLETEST_SOURCE_DIR}" -- "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest") -- ENDIF() -- ENDIF() --ENDIF() -- --# ---[ cpuinfo mock library and mock tests --IF(CPUINFO_SUPPORTED_PLATFORM AND 
CPUINFO_BUILD_MOCK_TESTS) -- SET(CPUINFO_MOCK_SRCS "${CPUINFO_SRCS}") -- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86_64)$") -- LIST(APPEND CPUINFO_MOCK_SRCS src/x86/mockcpuid.c) -- ENDIF() -- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") -- LIST(APPEND CPUINFO_MOCK_SRCS src/linux/mockfile.c) -- ENDIF() -- -- ADD_LIBRARY(cpuinfo_mock STATIC ${CPUINFO_MOCK_SRCS}) -- CPUINFO_TARGET_ENABLE_C99(cpuinfo_mock) -- CPUINFO_TARGET_RUNTIME_LIBRARY(cpuinfo_mock) -- SET_TARGET_PROPERTIES(cpuinfo_mock PROPERTIES PUBLIC_HEADER include/cpuinfo.h) -- TARGET_INCLUDE_DIRECTORIES(cpuinfo_mock BEFORE PUBLIC include) -- TARGET_INCLUDE_DIRECTORIES(cpuinfo_mock BEFORE PRIVATE src) -- TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PUBLIC CPUINFO_MOCK=1) -- TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE CLOG_LOG_TO_STDIO=1) -- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") -- TARGET_LINK_LIBRARIES(cpuinfo_mock PUBLIC ${CMAKE_THREAD_LIBS_INIT}) -- TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE _GNU_SOURCE=1) -- ENDIF() -- TARGET_LINK_LIBRARIES(cpuinfo_mock PRIVATE clog) -- -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a)$") -- ADD_EXECUTABLE(atm7029b-tablet-test test/mock/atm7029b-tablet.cc) -- TARGET_INCLUDE_DIRECTORIES(atm7029b-tablet-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(atm7029b-tablet-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(atm7029b-tablet-test atm7029b-tablet-test) -- -- ADD_EXECUTABLE(blu-r1-hd-test test/mock/blu-r1-hd.cc) -- TARGET_INCLUDE_DIRECTORIES(blu-r1-hd-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(blu-r1-hd-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(blu-r1-hd-test blu-r1-hd-test) -- -- ADD_EXECUTABLE(galaxy-a3-2016-eu-test test/mock/galaxy-a3-2016-eu.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-a3-2016-eu-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-a3-2016-eu-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-a3-2016-eu-test galaxy-a3-2016-eu-test) -- -- ADD_EXECUTABLE(galaxy-a8-2016-duos-test test/mock/galaxy-a8-2016-duos.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-a8-2016-duos-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-a8-2016-duos-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-a8-2016-duos-test galaxy-a8-2016-duos-test) -- -- ADD_EXECUTABLE(galaxy-grand-prime-value-edition-test test/mock/galaxy-grand-prime-value-edition.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-grand-prime-value-edition-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-grand-prime-value-edition-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-grand-prime-value-edition-test galaxy-grand-prime-value-edition-test) -- -- ADD_EXECUTABLE(galaxy-j1-2016-test test/mock/galaxy-j1-2016.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-j1-2016-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-j1-2016-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-j1-2016-test galaxy-j1-2016-test) -- -- ADD_EXECUTABLE(galaxy-j5-test test/mock/galaxy-j5.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-j5-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-j5-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-j5-test galaxy-j5-test) -- -- ADD_EXECUTABLE(galaxy-j7-prime-test test/mock/galaxy-j7-prime.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-j7-prime-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-j7-prime-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-j7-prime-test galaxy-j7-prime-test) -- -- 
ADD_EXECUTABLE(galaxy-j7-tmobile-test test/mock/galaxy-j7-tmobile.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-j7-tmobile-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-j7-tmobile-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-j7-tmobile-test galaxy-j7-tmobile-test) -- -- ADD_EXECUTABLE(galaxy-j7-uae-test test/mock/galaxy-j7-uae.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-j7-uae-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-j7-uae-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-j7-uae-test galaxy-j7-uae-test) -- -- ADD_EXECUTABLE(galaxy-s3-us-test test/mock/galaxy-s3-us.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s3-us-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s3-us-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s3-us-test galaxy-s3-us-test) -- -- ADD_EXECUTABLE(galaxy-s4-us-test test/mock/galaxy-s4-us.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s4-us-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s4-us-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s4-us-test galaxy-s4-us-test) -- -- ADD_EXECUTABLE(galaxy-s5-global-test test/mock/galaxy-s5-global.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s5-global-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s5-global-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s5-global-test galaxy-s5-global-test) -- -- ADD_EXECUTABLE(galaxy-s5-us-test test/mock/galaxy-s5-us.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s5-us-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s5-us-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s5-us-test galaxy-s5-us-test) -- -- ADD_EXECUTABLE(galaxy-tab-3-7.0-test test/mock/galaxy-tab-3-7.0.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-tab-3-7.0-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-tab-3-7.0-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-tab-3-7.0-test galaxy-tab-3-7.0-test) -- -- ADD_EXECUTABLE(galaxy-tab-3-lite-test test/mock/galaxy-tab-3-lite.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-tab-3-lite-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-tab-3-lite-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-tab-3-lite-test galaxy-tab-3-lite-test) -- -- ADD_EXECUTABLE(galaxy-win-duos-test test/mock/galaxy-win-duos.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-win-duos-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-win-duos-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-win-duos-test galaxy-win-duos-test) -- -- ADD_EXECUTABLE(huawei-ascend-p7-test test/mock/huawei-ascend-p7.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-ascend-p7-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-ascend-p7-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-ascend-p7-test huawei-ascend-p7-test) -- -- ADD_EXECUTABLE(huawei-honor-6-test test/mock/huawei-honor-6.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-honor-6-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-honor-6-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-honor-6-test huawei-honor-6-test) -- -- ADD_EXECUTABLE(lenovo-a6600-plus-test test/mock/lenovo-a6600-plus.cc) -- TARGET_INCLUDE_DIRECTORIES(lenovo-a6600-plus-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(lenovo-a6600-plus-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(lenovo-a6600-plus-test lenovo-a6600-plus-test) -- -- ADD_EXECUTABLE(lenovo-vibe-x2-test test/mock/lenovo-vibe-x2.cc) -- TARGET_INCLUDE_DIRECTORIES(lenovo-vibe-x2-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(lenovo-vibe-x2-test PRIVATE cpuinfo_mock gtest) -- 
ADD_TEST(lenovo-vibe-x2-test lenovo-vibe-x2-test) -- -- ADD_EXECUTABLE(lg-k10-eu-test test/mock/lg-k10-eu.cc) -- TARGET_INCLUDE_DIRECTORIES(lg-k10-eu-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(lg-k10-eu-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(lg-k10-eu-test lg-k10-eu-test) -- -- ADD_EXECUTABLE(lg-optimus-g-pro-test test/mock/lg-optimus-g-pro.cc) -- TARGET_INCLUDE_DIRECTORIES(lg-optimus-g-pro-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(lg-optimus-g-pro-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(lg-optimus-g-pro-test lg-optimus-g-pro-test) -- -- ADD_EXECUTABLE(moto-e-gen1-test test/mock/moto-e-gen1.cc) -- TARGET_INCLUDE_DIRECTORIES(moto-e-gen1-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(moto-e-gen1-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(moto-e-gen1-test moto-e-gen1-test) -- -- ADD_EXECUTABLE(moto-g-gen1-test test/mock/moto-g-gen1.cc) -- TARGET_INCLUDE_DIRECTORIES(moto-g-gen1-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(moto-g-gen1-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(moto-g-gen1-test moto-g-gen1-test) -- -- ADD_EXECUTABLE(moto-g-gen2-test test/mock/moto-g-gen2.cc) -- TARGET_INCLUDE_DIRECTORIES(moto-g-gen2-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(moto-g-gen2-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(moto-g-gen2-test moto-g-gen2-test) -- -- ADD_EXECUTABLE(moto-g-gen3-test test/mock/moto-g-gen3.cc) -- TARGET_INCLUDE_DIRECTORIES(moto-g-gen3-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(moto-g-gen3-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(moto-g-gen3-test moto-g-gen3-test) -- -- ADD_EXECUTABLE(moto-g-gen4-test test/mock/moto-g-gen4.cc) -- TARGET_INCLUDE_DIRECTORIES(moto-g-gen4-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(moto-g-gen4-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(moto-g-gen4-test moto-g-gen4-test) -- -- ADD_EXECUTABLE(moto-g-gen5-test test/mock/moto-g-gen5.cc) -- TARGET_INCLUDE_DIRECTORIES(moto-g-gen5-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(moto-g-gen5-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(moto-g-gen5-test moto-g-gen5-test) -- -- ADD_EXECUTABLE(nexus-s-test test/mock/nexus-s.cc) -- TARGET_INCLUDE_DIRECTORIES(nexus-s-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(nexus-s-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(nexus-s-test nexus-s-test) -- -- ADD_EXECUTABLE(nexus4-test test/mock/nexus4.cc) -- TARGET_INCLUDE_DIRECTORIES(nexus4-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(nexus4-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(nexus4-test nexus4-test) -- -- ADD_EXECUTABLE(nexus6-test test/mock/nexus6.cc) -- TARGET_INCLUDE_DIRECTORIES(nexus6-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(nexus6-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(nexus6-test nexus6-test) -- -- ADD_EXECUTABLE(nexus10-test test/mock/nexus10.cc) -- TARGET_INCLUDE_DIRECTORIES(nexus10-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(nexus10-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(nexus10-test nexus10-test) -- -- ADD_EXECUTABLE(padcod-10.1-test test/mock/padcod-10.1.cc) -- TARGET_INCLUDE_DIRECTORIES(padcod-10.1-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(padcod-10.1-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(padcod-10.1-test padcod-10.1-test) -- -- ADD_EXECUTABLE(xiaomi-redmi-2a-test test/mock/xiaomi-redmi-2a.cc) -- TARGET_INCLUDE_DIRECTORIES(xiaomi-redmi-2a-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(xiaomi-redmi-2a-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(xiaomi-redmi-2a-test xiaomi-redmi-2a-test) -- -- 
ADD_EXECUTABLE(xperia-sl-test test/mock/xperia-sl.cc) -- TARGET_INCLUDE_DIRECTORIES(xperia-sl-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(xperia-sl-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(xperia-sl-test xperia-sl-test) -- ENDIF() -- -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") -- ADD_EXECUTABLE(alcatel-revvl-test test/mock/alcatel-revvl.cc) -- TARGET_INCLUDE_DIRECTORIES(alcatel-revvl-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(alcatel-revvl-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(alcatel-revvl-test alcatel-revvl-test) -- -- ADD_EXECUTABLE(galaxy-a8-2018-test test/mock/galaxy-a8-2018.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-a8-2018-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-a8-2018-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-a8-2018-test galaxy-a8-2018-test) -- -- ADD_EXECUTABLE(galaxy-c9-pro-test test/mock/galaxy-c9-pro.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-c9-pro-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-c9-pro-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-c9-pro-test galaxy-c9-pro-test) -- -- ADD_EXECUTABLE(galaxy-s6-test test/mock/galaxy-s6.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s6-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s6-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s6-test galaxy-s6-test) -- -- ADD_EXECUTABLE(galaxy-s7-us-test test/mock/galaxy-s7-us.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s7-us-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s7-us-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s7-us-test galaxy-s7-us-test) -- -- ADD_EXECUTABLE(galaxy-s7-global-test test/mock/galaxy-s7-global.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s7-global-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s7-global-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s7-global-test galaxy-s7-global-test) -- -- ADD_EXECUTABLE(galaxy-s8-us-test test/mock/galaxy-s8-us.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s8-us-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s8-us-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s8-us-test galaxy-s8-us-test) -- -- ADD_EXECUTABLE(galaxy-s8-global-test test/mock/galaxy-s8-global.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s8-global-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s8-global-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s8-global-test galaxy-s8-global-test) -- -- ADD_EXECUTABLE(galaxy-s9-us-test test/mock/galaxy-s9-us.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s9-us-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s9-us-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s9-us-test galaxy-s9-us-test) -- -- ADD_EXECUTABLE(galaxy-s9-global-test test/mock/galaxy-s9-global.cc) -- TARGET_INCLUDE_DIRECTORIES(galaxy-s9-global-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(galaxy-s9-global-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(galaxy-s9-global-test galaxy-s9-global-test) -- -- ADD_EXECUTABLE(huawei-mate-8-test test/mock/huawei-mate-8.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-mate-8-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-mate-8-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-mate-8-test huawei-mate-8-test) -- -- ADD_EXECUTABLE(huawei-mate-9-test test/mock/huawei-mate-9.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-mate-9-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-mate-9-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-mate-9-test 
huawei-mate-9-test) -- -- ADD_EXECUTABLE(huawei-mate-10-test test/mock/huawei-mate-10.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-mate-10-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-mate-10-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-mate-10-test huawei-mate-10-test) -- -- ADD_EXECUTABLE(huawei-mate-20-test test/mock/huawei-mate-20.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-mate-20-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-mate-20-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-mate-20-test huawei-mate-20-test) -- -- ADD_EXECUTABLE(huawei-p8-lite-test test/mock/huawei-p8-lite.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-p8-lite-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-p8-lite-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-p8-lite-test huawei-p8-lite-test) -- -- ADD_EXECUTABLE(huawei-p9-lite-test test/mock/huawei-p9-lite.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-p9-lite-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-p9-lite-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-p9-lite-test huawei-p9-lite-test) -- -- ADD_EXECUTABLE(huawei-p20-pro-test test/mock/huawei-p20-pro.cc) -- TARGET_INCLUDE_DIRECTORIES(huawei-p20-pro-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(huawei-p20-pro-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(huawei-p20-pro-test huawei-p20-pro-test) -- -- ADD_EXECUTABLE(iconia-one-10-test test/mock/iconia-one-10.cc) -- TARGET_INCLUDE_DIRECTORIES(iconia-one-10-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(iconia-one-10-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(iconia-one-10-test iconia-one-10-test) -- -- ADD_EXECUTABLE(meizu-pro-6-test test/mock/meizu-pro-6.cc) -- TARGET_INCLUDE_DIRECTORIES(meizu-pro-6-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(meizu-pro-6-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(meizu-pro-6-test meizu-pro-6-test) -- -- ADD_EXECUTABLE(meizu-pro-6s-test test/mock/meizu-pro-6s.cc) -- TARGET_INCLUDE_DIRECTORIES(meizu-pro-6s-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(meizu-pro-6s-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(meizu-pro-6s-test meizu-pro-6s-test) -- -- ADD_EXECUTABLE(meizu-pro-7-plus-test test/mock/meizu-pro-7-plus.cc) -- TARGET_INCLUDE_DIRECTORIES(meizu-pro-7-plus-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(meizu-pro-7-plus-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(meizu-pro-7-plus-test meizu-pro-7-plus-test) -- -- ADD_EXECUTABLE(nexus5x-test test/mock/nexus5x.cc) -- TARGET_INCLUDE_DIRECTORIES(nexus5x-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(nexus5x-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(nexus5x-test nexus5x-test) -- -- ADD_EXECUTABLE(nexus6p-test test/mock/nexus6p.cc) -- TARGET_INCLUDE_DIRECTORIES(nexus6p-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(nexus6p-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(nexus6p-test nexus6p-test) -- -- ADD_EXECUTABLE(nexus9-test test/mock/nexus9.cc) -- TARGET_INCLUDE_DIRECTORIES(nexus9-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(nexus9-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(nexus9-test nexus9-test) -- -- ADD_EXECUTABLE(oneplus-3t-test test/mock/oneplus-3t.cc) -- TARGET_INCLUDE_DIRECTORIES(oneplus-3t-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(oneplus-3t-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(oneplus-3t-test oneplus-3t-test) -- -- ADD_EXECUTABLE(oneplus-5-test test/mock/oneplus-5.cc) -- TARGET_INCLUDE_DIRECTORIES(oneplus-5-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(oneplus-5-test PRIVATE 
cpuinfo_mock gtest) -- ADD_TEST(oneplus-5-test oneplus-5-test) -- -- ADD_EXECUTABLE(oneplus-5t-test test/mock/oneplus-5t.cc) -- TARGET_INCLUDE_DIRECTORIES(oneplus-5t-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(oneplus-5t-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(oneplus-5t-test oneplus-5t-test) -- -- ADD_EXECUTABLE(oppo-a37-test test/mock/oppo-a37.cc) -- TARGET_INCLUDE_DIRECTORIES(oppo-a37-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(oppo-a37-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(oppo-a37-test oppo-a37-test) -- -- ADD_EXECUTABLE(oppo-r9-test test/mock/oppo-r9.cc) -- TARGET_INCLUDE_DIRECTORIES(oppo-r9-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(oppo-r9-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(oppo-r9-test oppo-r9-test) -- -- ADD_EXECUTABLE(oppo-r15-test test/mock/oppo-r15.cc) -- TARGET_INCLUDE_DIRECTORIES(oppo-r15-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(oppo-r15-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(oppo-r15-test oppo-r15-test) -- -- ADD_EXECUTABLE(pixel-test test/mock/pixel.cc) -- TARGET_INCLUDE_DIRECTORIES(pixel-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(pixel-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(pixel-test pixel-test) -- -- ADD_EXECUTABLE(pixel-c-test test/mock/pixel-c.cc) -- TARGET_INCLUDE_DIRECTORIES(pixel-c-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(pixel-c-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(pixel-c-test pixel-c-test) -- -- ADD_EXECUTABLE(pixel-xl-test test/mock/pixel-xl.cc) -- TARGET_INCLUDE_DIRECTORIES(pixel-xl-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(pixel-xl-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(pixel-xl-test pixel-xl-test) -- -- ADD_EXECUTABLE(pixel-2-xl-test test/mock/pixel-2-xl.cc) -- TARGET_INCLUDE_DIRECTORIES(pixel-2-xl-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(pixel-2-xl-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(pixel-2-xl-test pixel-2-xl-test) -- -- ADD_EXECUTABLE(xiaomi-mi-5c-test test/mock/xiaomi-mi-5c.cc) -- TARGET_INCLUDE_DIRECTORIES(xiaomi-mi-5c-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(xiaomi-mi-5c-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(xiaomi-mi-5c-test xiaomi-mi-5c-test) -- -- ADD_EXECUTABLE(xiaomi-redmi-note-3-test test/mock/xiaomi-redmi-note-3.cc) -- TARGET_INCLUDE_DIRECTORIES(xiaomi-redmi-note-3-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(xiaomi-redmi-note-3-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(xiaomi-redmi-note-3-test xiaomi-redmi-note-3-test) -- -- ADD_EXECUTABLE(xiaomi-redmi-note-4-test test/mock/xiaomi-redmi-note-4.cc) -- TARGET_INCLUDE_DIRECTORIES(xiaomi-redmi-note-4-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(xiaomi-redmi-note-4-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(xiaomi-redmi-note-4-test xiaomi-redmi-note-4-test) -- -- ADD_EXECUTABLE(xperia-c4-dual-test test/mock/xperia-c4-dual.cc) -- TARGET_INCLUDE_DIRECTORIES(xperia-c4-dual-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(xperia-c4-dual-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(xperia-c4-dual-test xperia-c4-dual-test) -- ENDIF() -- -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64)$") -- ADD_EXECUTABLE(alldocube-iwork8-test test/mock/alldocube-iwork8.cc) -- TARGET_INCLUDE_DIRECTORIES(alldocube-iwork8-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(alldocube-iwork8-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(alldocube-iwork8-test alldocube-iwork8-test) -- -- ADD_EXECUTABLE(leagoo-t5c-test test/mock/leagoo-t5c.cc) -- 
TARGET_INCLUDE_DIRECTORIES(leagoo-t5c-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(leagoo-t5c-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(leagoo-t5c-test leagoo-t5c-test) -- -- ADD_EXECUTABLE(memo-pad-7-test test/mock/memo-pad-7.cc) -- TARGET_INCLUDE_DIRECTORIES(memo-pad-7-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(memo-pad-7-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(memo-pad-7-test memo-pad-7-test) -- -- ADD_EXECUTABLE(zenfone-c-test test/mock/zenfone-c.cc) -- TARGET_INCLUDE_DIRECTORIES(zenfone-c-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(zenfone-c-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(zenfone-c-test zenfone-c-test) -- -- ADD_EXECUTABLE(zenfone-2-test test/mock/zenfone-2.cc) -- TARGET_INCLUDE_DIRECTORIES(zenfone-2-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(zenfone-2-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(zenfone-2-test zenfone-2-test) -- -- ADD_EXECUTABLE(zenfone-2e-test test/mock/zenfone-2e.cc) -- TARGET_INCLUDE_DIRECTORIES(zenfone-2e-test BEFORE PRIVATE test/mock) -- TARGET_LINK_LIBRARIES(zenfone-2e-test PRIVATE cpuinfo_mock gtest) -- ADD_TEST(zenfone-2e-test zenfone-2e-test) -- ENDIF() --ENDIF() -- --# ---[ cpuinfo unit tests --IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_UNIT_TESTS) -- ADD_EXECUTABLE(init-test test/init.cc) -- CPUINFO_TARGET_ENABLE_CXX11(init-test) -- CPUINFO_TARGET_RUNTIME_LIBRARY(init-test) -- TARGET_LINK_LIBRARIES(init-test PRIVATE cpuinfo gtest gtest_main) -- ADD_TEST(init-test init-test) -- -- IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") -- ADD_EXECUTABLE(get-current-test test/get-current.cc) -- CPUINFO_TARGET_ENABLE_CXX11(get-current-test) -- CPUINFO_TARGET_RUNTIME_LIBRARY(get-current-test) -- TARGET_LINK_LIBRARIES(get-current-test PRIVATE cpuinfo gtest gtest_main) -- ADD_TEST(get-current-test get-current-test) -- ENDIF() -- -- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86_64)$") -- ADD_EXECUTABLE(brand-string-test test/name/brand-string.cc) -- CPUINFO_TARGET_ENABLE_CXX11(brand-string-test) -- CPUINFO_TARGET_RUNTIME_LIBRARY(brand-string-test) -- TARGET_LINK_LIBRARIES(brand-string-test PRIVATE cpuinfo_internals gtest gtest_main) -- ADD_TEST(brand-string-test brand-string-test) -- ENDIF() -- -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") -- ADD_LIBRARY(android_properties_interface STATIC test/name/android-properties-interface.c) -- CPUINFO_TARGET_ENABLE_C99(android_properties_interface) -- CPUINFO_TARGET_RUNTIME_LIBRARY(android_properties_interface) -- TARGET_LINK_LIBRARIES(android_properties_interface PRIVATE cpuinfo_internals) -- -- ADD_EXECUTABLE(chipset-test -- test/name/proc-cpuinfo-hardware.cc -- test/name/ro-product-board.cc -- test/name/ro-board-platform.cc -- test/name/ro-mediatek-platform.cc -- test/name/ro-arch.cc -- test/name/ro-chipname.cc -- test/name/android-properties.cc) -- CPUINFO_TARGET_ENABLE_CXX11(chipset-test) -- CPUINFO_TARGET_RUNTIME_LIBRARY(chipset-test) -- TARGET_LINK_LIBRARIES(chipset-test PRIVATE android_properties_interface gtest gtest_main) -- ADD_TEST(chipset-test chipset-test) -- -- ADD_EXECUTABLE(cache-test test/arm-cache.cc) -- CPUINFO_TARGET_ENABLE_CXX11(cache-test) -- CPUINFO_TARGET_RUNTIME_LIBRARY(cache-test) -- TARGET_COMPILE_DEFINITIONS(cache-test PRIVATE __STDC_LIMIT_MACROS=1 __STDC_CONSTANT_MACROS=1) -- TARGET_LINK_LIBRARIES(cache-test PRIVATE cpuinfo_internals gtest gtest_main) -- ADD_TEST(cache-test, cache-test) -- ENDIF() --ENDIF() -- --# ---[ Helper 
and debug tools --IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_TOOLS) -- ADD_EXECUTABLE(isa-info tools/isa-info.c) -- CPUINFO_TARGET_ENABLE_C99(isa-info) -- CPUINFO_TARGET_RUNTIME_LIBRARY(isa-info) -- TARGET_LINK_LIBRARIES(isa-info PRIVATE cpuinfo) -- INSTALL(TARGETS isa-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) -- -- ADD_EXECUTABLE(cpu-info tools/cpu-info.c) -- CPUINFO_TARGET_ENABLE_C99(cpu-info) -- CPUINFO_TARGET_RUNTIME_LIBRARY(cpu-info) -- TARGET_LINK_LIBRARIES(cpu-info PRIVATE cpuinfo) -- INSTALL(TARGETS cpu-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) -- -- ADD_EXECUTABLE(cache-info tools/cache-info.c) -- CPUINFO_TARGET_ENABLE_C99(cache-info) -- CPUINFO_TARGET_RUNTIME_LIBRARY(cache-info) -- TARGET_LINK_LIBRARIES(cache-info PRIVATE cpuinfo) -- INSTALL(TARGETS cache-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) -- -- IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") -- ADD_EXECUTABLE(auxv-dump tools/auxv-dump.c) -- CPUINFO_TARGET_ENABLE_C99(auxv-dump) -- CPUINFO_TARGET_RUNTIME_LIBRARY(auxv-dump) -- TARGET_LINK_LIBRARIES(auxv-dump PRIVATE ${CMAKE_DL_LIBS} cpuinfo) -- -- ADD_EXECUTABLE(cpuinfo-dump tools/cpuinfo-dump.c) -- CPUINFO_TARGET_ENABLE_C99(cpuinfo-dump) -- CPUINFO_TARGET_RUNTIME_LIBRARY(cpuinfo-dump) -- ENDIF() -- -- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86_64)$") -- ADD_EXECUTABLE(cpuid-dump tools/cpuid-dump.c) -- CPUINFO_TARGET_ENABLE_C99(cpuid-dump) -- CPUINFO_TARGET_RUNTIME_LIBRARY(cpuid-dump) -- TARGET_INCLUDE_DIRECTORIES(cpuid-dump BEFORE PRIVATE src) -- TARGET_INCLUDE_DIRECTORIES(cpuid-dump BEFORE PRIVATE include) -- INSTALL(TARGETS cpuid-dump RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) -- ENDIF() --ENDIF() -diff --git a/README.md b/README.md -index 7d383ff..ee5fb82 100644 ---- a/README.md -+++ b/README.md -@@ -152,21 +152,20 @@ pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpu_set); - - [x] Using `ro.chipname`, `ro.board.platform`, `ro.product.board`, `ro.mediatek.platform`, `ro.arch` properties (Android) - - [ ] Using kernel log (`dmesg`) on ARM Linux - - Vendor and microarchitecture detection -- - [x] Intel-designed x86/x86-64 cores (up to Kaby Lake, Airmont, and Knights Mill) -- - [x] AMD-designed x86/x86-64 cores (up to Puma/Jaguar and Zen) -+ - [x] Intel-designed x86/x86-64 cores (up to Sunny Cove, Goldmont Plus, and Knights Mill) -+ - [x] AMD-designed x86/x86-64 cores (up to Puma/Jaguar and Zen 2) - - [ ] VIA-designed x86/x86-64 cores - - [ ] Other x86 cores (DM&P, RDC, Transmeta, Cyrix, Rise) -- - [x] ARM-designed ARM cores (up to Cortex-A55 and Cortex-A75) -- - [x] Qualcomm-designed ARM cores (up to Kryo, Kryo-280, and Kryo-385) -- - [x] Nvidia-designed ARM cores (Denver) -+ - [x] ARM-designed ARM cores (up to Cortex-A55, Cortex-A77, and Neoverse E1/N1) -+ - [x] Qualcomm-designed ARM cores (Scorpion, Krait, and Kryo) -+ - [x] Nvidia-designed ARM cores (Denver and Carmel) - - [x] Samsung-designed ARM cores (Exynos) - - [x] Intel-designed ARM cores (XScale up to 3rd-gen) -- - [x] Apple-designed ARM cores (up to Hurricane) -+ - [x] Apple-designed ARM cores (up to Lightning and Thunder) - - [x] Cavium-designed ARM cores (ThunderX) - - [x] AppliedMicro-designed ARM cores (X-Gene) - - Instruction set detection - - [x] Using CPUID (x86/x86-64) -- - [x] Using dynamic code generation validator (Native Client/x86-64) - - [x] Using `/proc/cpuinfo` on 32-bit ARM EABI (Linux) - - [x] Using microarchitecture heuristics on (32-bit ARM) - - [x] Using `FPSID` and `WCID` 
registers (32-bit ARM) -diff --git a/bench/get-current.cc b/bench/get-current.cc -index 91b35a0..b547df0 100644 ---- a/bench/get-current.cc -+++ b/bench/get-current.cc -@@ -21,4 +21,13 @@ static void cpuinfo_get_current_core(benchmark::State& state) { - } - BENCHMARK(cpuinfo_get_current_core)->Unit(benchmark::kNanosecond); - -+static void cpuinfo_get_current_uarch_index(benchmark::State& state) { -+ cpuinfo_initialize(); -+ while (state.KeepRunning()) { -+ const uint32_t uarch_index = cpuinfo_get_current_uarch_index(); -+ benchmark::DoNotOptimize(uarch_index); -+ } -+} -+BENCHMARK(cpuinfo_get_current_uarch_index)->Unit(benchmark::kNanosecond); -+ - BENCHMARK_MAIN(); -diff --git a/cmake/DownloadGoogleTest.cmake b/cmake/DownloadGoogleTest.cmake -index d69d19a..dc86c9c 100644 ---- a/cmake/DownloadGoogleTest.cmake -+++ b/cmake/DownloadGoogleTest.cmake -@@ -4,8 +4,8 @@ PROJECT(googletest-download NONE) - - INCLUDE(ExternalProject) - ExternalProject_Add(googletest -- URL https://github.com/google/googletest/archive/release-1.8.0.zip -- URL_HASH SHA256=f3ed3b58511efd272eb074a3a6d6fb79d7c2e6a0e374323d1e6bcbcc1ef141bf -+ URL https://github.com/google/googletest/archive/release-1.10.0.zip -+ URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91 - SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest" - BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest" - CONFIGURE_COMMAND "" -diff --git a/configure.py b/configure.py -index a340c4c..0e58dba 100755 ---- a/configure.py -+++ b/configure.py -@@ -26,8 +26,8 @@ def main(args): - sources = ["init.c", "api.c"] - if build.target.is_x86 or build.target.is_x86_64: - sources += [ -- "x86/init.c", "x86/info.c", "x86/vendor.c", "x86/uarch.c", "x86/name.c", -- "x86/topology.c", -+ "x86/init.c", "x86/info.c", "x86/isa.c", "x86/vendor.c", -+ "x86/uarch.c", "x86/name.c", "x86/topology.c", - "x86/cache/init.c", "x86/cache/descriptor.c", "x86/cache/deterministic.c", - ] - if build.target.is_macos: -@@ -37,7 +37,6 @@ def main(args): - "x86/linux/init.c", - "x86/linux/cpuinfo.c", - ] -- sources.append("x86/isa.c" if not build.target.is_nacl else "x86/nacl/isa.c") - if build.target.is_arm or build.target.is_arm64: - sources += ["arm/uarch.c", "arm/cache.c"] - if build.target.is_linux or build.target.is_android: -diff --git a/include/cpuinfo.h b/include/cpuinfo.h -index 9938d2b..e4d2d0c 100644 ---- a/include/cpuinfo.h -+++ b/include/cpuinfo.h -@@ -34,10 +34,6 @@ - #define CPUINFO_ARCH_PPC64 1 - #endif - --#if defined(__pnacl__) -- #define CPUINFO_ARCH_PNACL 1 --#endif -- - #if defined(__asmjs__) - #define CPUINFO_ARCH_ASMJS 1 - #endif -@@ -80,10 +76,6 @@ - #define CPUINFO_ARCH_PPC64 0 - #endif - --#ifndef CPUINFO_ARCH_PNACL -- #define CPUINFO_ARCH_PNACL 0 --#endif -- - #ifndef CPUINFO_ARCH_ASMJS - #define CPUINFO_ARCH_ASMJS 0 - #endif -@@ -190,6 +182,12 @@ enum cpuinfo_vendor { - * Processors are designed by HiSilicon, a subsidiary of Huawei. - */ - cpuinfo_vendor_huawei = 15, -+ /** -+ * Hygon (Chengdu Haiguang Integrated Circuit Design Co., Ltd), Vendor of x86-64 processor microarchitectures. -+ * -+ * Processors are variants of AMD cores. -+ */ -+ cpuinfo_vendor_hygon = 16, - - /* Active vendors of embedded CPUs */ - -@@ -401,6 +399,8 @@ enum cpuinfo_uarch { - cpuinfo_uarch_cortex_a35 = 0x00300335, - /** ARM Cortex-A53. */ - cpuinfo_uarch_cortex_a53 = 0x00300353, -+ /** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities compared to revision 1+). */ -+ cpuinfo_uarch_cortex_a55r0 = 0x00300354, - /** ARM Cortex-A55. 
*/ - cpuinfo_uarch_cortex_a55 = 0x00300355, - /** ARM Cortex-A57. */ -@@ -478,6 +478,10 @@ enum cpuinfo_uarch { - cpuinfo_uarch_vortex = 0x00700107, - /** Apple A12 processor (little cores). */ - cpuinfo_uarch_tempest = 0x00700108, -+ /** Apple A13 processor (big cores). */ -+ cpuinfo_uarch_lightning = 0x00700109, -+ /** Apple A13 processor (little cores). */ -+ cpuinfo_uarch_thunder = 0x0070010A, - - /** Cavium ThunderX. */ - cpuinfo_uarch_thunderx = 0x00800100, -@@ -494,6 +498,9 @@ enum cpuinfo_uarch { - - /** Applied Micro X-Gene. */ - cpuinfo_uarch_xgene = 0x00B00100, -+ -+ /* Hygon Dhyana (a modification of AMD Zen for Chinese market). */ -+ cpuinfo_uarch_dhyana = 0x01000100, - }; - - struct cpuinfo_processor { -@@ -613,6 +620,22 @@ struct cpuinfo_package { - uint32_t cluster_count; - }; - -+struct cpuinfo_uarch_info { -+ /** Type of CPU microarchitecture */ -+ enum cpuinfo_uarch uarch; -+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 -+ /** Value of CPUID leaf 1 EAX register for the microarchitecture */ -+ uint32_t cpuid; -+#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 -+ /** Value of Main ID Register (MIDR) for the microarchitecture */ -+ uint32_t midr; -+#endif -+ /** Number of logical processors with the microarchitecture */ -+ uint32_t processor_count; -+ /** Number of cores with the microarchitecture */ -+ uint32_t core_count; -+}; -+ - #ifdef __cplusplus - extern "C" { - #endif -@@ -1721,6 +1744,7 @@ const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processors(void); - const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_cores(void); - const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_clusters(void); - const struct cpuinfo_package* CPUINFO_ABI cpuinfo_get_packages(void); -+const struct cpuinfo_uarch_info* CPUINFO_ABI cpuinfo_get_uarchs(void); - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_caches(void); - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_caches(void); - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_caches(void); -@@ -1731,6 +1755,7 @@ const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processor(uint32_t index - const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_core(uint32_t index); - const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_cluster(uint32_t index); - const struct cpuinfo_package* CPUINFO_ABI cpuinfo_get_package(uint32_t index); -+const struct cpuinfo_uarch_info* CPUINFO_ABI cpuinfo_get_uarch(uint32_t index); - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_cache(uint32_t index); - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_cache(uint32_t index); - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_cache(uint32_t index); -@@ -1741,6 +1766,7 @@ uint32_t CPUINFO_ABI cpuinfo_get_processors_count(void); - uint32_t CPUINFO_ABI cpuinfo_get_cores_count(void); - uint32_t CPUINFO_ABI cpuinfo_get_clusters_count(void); - uint32_t CPUINFO_ABI cpuinfo_get_packages_count(void); -+uint32_t CPUINFO_ABI cpuinfo_get_uarchs_count(void); - uint32_t CPUINFO_ABI cpuinfo_get_l1i_caches_count(void); - uint32_t CPUINFO_ABI cpuinfo_get_l1d_caches_count(void); - uint32_t CPUINFO_ABI cpuinfo_get_l2_caches_count(void); -@@ -1752,9 +1778,31 @@ uint32_t CPUINFO_ABI cpuinfo_get_l4_caches_count(void); - */ - uint32_t CPUINFO_ABI cpuinfo_get_max_cache_size(void); - -+/** -+ * Identify the logical processor that executes the current thread. -+ * -+ * There is no guarantee that the thread will stay on the same logical processor for any time. 
-+ * Callers should treat the result as only a hint, and be prepared to handle NULL return value. -+ */ - const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_current_processor(void); -+ -+/** -+ * Identify the core that executes the current thread. -+ * -+ * There is no guarantee that the thread will stay on the same core for any time. -+ * Callers should treat the result as only a hint, and be prepared to handle NULL return value. -+ */ - const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_current_core(void); - -+/** -+ * Identify the microarchitecture index of the core that executes the current thread. -+ * If the system does not support such identification, the function return 0. -+ * -+ * There is no guarantee that the thread will stay on the same type of core for any time. -+ * Callers should treat the result as only a hint. -+ */ -+uint32_t CPUINFO_ABI cpuinfo_get_current_uarch_index(void); -+ - #ifdef __cplusplus - } /* extern "C" */ - #endif -diff --git a/src/api.c b/src/api.c -index b180d80..0cc5d4e 100644 ---- a/src/api.c -+++ b/src/api.c -@@ -1,9 +1,16 @@ -+#include - #include - - #include - #include - #include - -+#ifdef __linux__ -+ #include -+ -+ #include -+ #include -+#endif - - bool cpuinfo_is_initialized = false; - -@@ -20,235 +27,347 @@ uint32_t cpuinfo_packages_count = 0; - uint32_t cpuinfo_cache_count[cpuinfo_cache_level_max] = { 0 }; - uint32_t cpuinfo_max_cache_size = 0; - -+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 -+ struct cpuinfo_uarch_info* cpuinfo_uarchs = NULL; -+ uint32_t cpuinfo_uarchs_count = 0; -+#else -+ struct cpuinfo_uarch_info cpuinfo_global_uarch = { cpuinfo_uarch_unknown }; -+#endif -+ -+#ifdef __linux__ -+ uint32_t cpuinfo_linux_cpu_max = 0; -+ const struct cpuinfo_processor** cpuinfo_linux_cpu_to_processor_map = NULL; -+ const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map = NULL; -+ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 -+ const uint32_t* cpuinfo_linux_cpu_to_uarch_index_map = NULL; -+ #endif -+#endif -+ - - const struct cpuinfo_processor* cpuinfo_get_processors(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "processors"); - } - return cpuinfo_processors; - } - - const struct cpuinfo_core* cpuinfo_get_cores(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "core"); - } - return cpuinfo_cores; - } - - const struct cpuinfo_cluster* cpuinfo_get_clusters(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "clusters"); - } - return cpuinfo_clusters; - } - - const struct cpuinfo_package* cpuinfo_get_packages(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "packages"); - } - return cpuinfo_packages; - } - --const struct cpuinfo_processor* cpuinfo_get_processor(uint32_t index) { -+const struct cpuinfo_uarch_info* cpuinfo_get_uarchs() { - if (!cpuinfo_is_initialized) { -+ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "uarchs"); -+ } -+ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 -+ return cpuinfo_uarchs; -+ #else -+ return &cpuinfo_global_uarch; -+ #endif -+} -+ -+const struct cpuinfo_processor* cpuinfo_get_processor(uint32_t index) { -+ if 
CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "processor"); - } -- if (index < cpuinfo_processors_count) { -- return cpuinfo_processors + index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_processors_count) { - return NULL; - } -+ return &cpuinfo_processors[index]; - } - - const struct cpuinfo_core* cpuinfo_get_core(uint32_t index) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "core"); - } -- if (index < cpuinfo_cores_count) { -- return cpuinfo_cores + index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_cores_count) { - return NULL; - } -+ return &cpuinfo_cores[index]; - } - - const struct cpuinfo_cluster* cpuinfo_get_cluster(uint32_t index) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "cluster"); - } -- if (index < cpuinfo_clusters_count) { -- return cpuinfo_clusters + index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_clusters_count) { - return NULL; - } -+ return &cpuinfo_clusters[index]; - } - - const struct cpuinfo_package* cpuinfo_get_package(uint32_t index) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "package"); - } -- if (index < cpuinfo_packages_count) { -- return cpuinfo_packages + index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_packages_count) { - return NULL; - } -+ return &cpuinfo_packages[index]; - } - --uint32_t cpuinfo_get_processors_count(void) { -+const struct cpuinfo_uarch_info* cpuinfo_get_uarch(uint32_t index) { - if (!cpuinfo_is_initialized) { -+ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "uarch"); -+ } -+ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 -+ if CPUINFO_UNLIKELY(index >= cpuinfo_uarchs_count) { -+ return NULL; -+ } -+ return &cpuinfo_uarchs[index]; -+ #else -+ if CPUINFO_UNLIKELY(index != 0) { -+ return NULL; -+ } -+ return &cpuinfo_global_uarch; -+ #endif -+} -+ -+uint32_t cpuinfo_get_processors_count(void) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "processors_count"); - } - return cpuinfo_processors_count; - } - - uint32_t cpuinfo_get_cores_count(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "cores_count"); - } - return cpuinfo_cores_count; - } - - uint32_t cpuinfo_get_clusters_count(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "clusters_count"); - } - return cpuinfo_clusters_count; - } - - uint32_t cpuinfo_get_packages_count(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "packages_count"); - } - return cpuinfo_packages_count; - } - --const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_caches(void) { -+uint32_t cpuinfo_get_uarchs_count(void) { - if (!cpuinfo_is_initialized) { -+ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "uarchs_count"); -+ } -+ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 -+ return 
cpuinfo_uarchs_count; -+ #else -+ return 1; -+ #endif -+} -+ -+const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_caches(void) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1i_caches"); - } - return cpuinfo_cache[cpuinfo_cache_level_1i]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_caches(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1d_caches"); - } - return cpuinfo_cache[cpuinfo_cache_level_1d]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_caches(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l2_caches"); - } - return cpuinfo_cache[cpuinfo_cache_level_2]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l3_caches(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l3_caches"); - } - return cpuinfo_cache[cpuinfo_cache_level_3]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l4_caches(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l4_caches"); - } - return cpuinfo_cache[cpuinfo_cache_level_4]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_cache(uint32_t index) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1i_cache"); - } -- if (index < cpuinfo_cache_count[cpuinfo_cache_level_1i]) { -- return cpuinfo_cache[cpuinfo_cache_level_1i] + index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_1i]) { - return NULL; - } -+ return &cpuinfo_cache[cpuinfo_cache_level_1i][index]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_cache(uint32_t index) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1d_cache"); - } -- if (index < cpuinfo_cache_count[cpuinfo_cache_level_1d]) { -- return cpuinfo_cache[cpuinfo_cache_level_1d] + index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_1d]) { - return NULL; - } -+ return &cpuinfo_cache[cpuinfo_cache_level_1d][index]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_cache(uint32_t index) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l2_cache"); - } -- if (index < cpuinfo_cache_count[cpuinfo_cache_level_2]) { -- return cpuinfo_cache[cpuinfo_cache_level_2] + index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_2]) { - return NULL; - } -+ return &cpuinfo_cache[cpuinfo_cache_level_2][index]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l3_cache(uint32_t index) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l3_cache"); - } -- if (index < cpuinfo_cache_count[cpuinfo_cache_level_3]) { -- return cpuinfo_cache[cpuinfo_cache_level_3] 
+ index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_3]) { - return NULL; - } -+ return &cpuinfo_cache[cpuinfo_cache_level_3][index]; - } - - const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l4_cache(uint32_t index) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l4_cache"); - } -- if (index < cpuinfo_cache_count[cpuinfo_cache_level_4]) { -- return cpuinfo_cache[cpuinfo_cache_level_4] + index; -- } else { -+ if CPUINFO_UNLIKELY(index >= cpuinfo_cache_count[cpuinfo_cache_level_4]) { - return NULL; - } -+ return &cpuinfo_cache[cpuinfo_cache_level_4][index]; - } - - uint32_t CPUINFO_ABI cpuinfo_get_l1i_caches_count(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1i_caches_count"); - } - return cpuinfo_cache_count[cpuinfo_cache_level_1i]; - } - - uint32_t CPUINFO_ABI cpuinfo_get_l1d_caches_count(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l1d_caches_count"); - } - return cpuinfo_cache_count[cpuinfo_cache_level_1d]; - } - - uint32_t CPUINFO_ABI cpuinfo_get_l2_caches_count(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l2_caches_count"); - } - return cpuinfo_cache_count[cpuinfo_cache_level_2]; - } - - uint32_t CPUINFO_ABI cpuinfo_get_l3_caches_count(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l3_caches_count"); - } - return cpuinfo_cache_count[cpuinfo_cache_level_3]; - } - - uint32_t CPUINFO_ABI cpuinfo_get_l4_caches_count(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "l4_caches_count"); - } - return cpuinfo_cache_count[cpuinfo_cache_level_4]; - } - - uint32_t CPUINFO_ABI cpuinfo_get_max_cache_size(void) { -- if (!cpuinfo_is_initialized) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { - cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "max_cache_size"); - } - return cpuinfo_max_cache_size; - } -+ -+const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_current_processor(void) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { -+ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_processor"); -+ } -+ #ifdef __linux__ -+ unsigned cpu; -+ if CPUINFO_UNLIKELY(syscall(__NR_getcpu, &cpu, NULL, NULL) != 0) { -+ return 0; -+ } -+ if CPUINFO_UNLIKELY((uint32_t) cpu >= cpuinfo_linux_cpu_max) { -+ return 0; -+ } -+ return cpuinfo_linux_cpu_to_processor_map[cpu]; -+ #else -+ return NULL; -+ #endif -+} -+ -+const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_current_core(void) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { -+ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_core"); -+ } -+ #ifdef __linux__ -+ unsigned cpu; -+ if CPUINFO_UNLIKELY(syscall(__NR_getcpu, &cpu, NULL, NULL) != 0) { -+ return 0; -+ } -+ if CPUINFO_UNLIKELY((uint32_t) cpu >= cpuinfo_linux_cpu_max) { -+ return 0; -+ } -+ return cpuinfo_linux_cpu_to_core_map[cpu]; -+ #else -+ return 
NULL; -+ #endif -+} -+ -+uint32_t CPUINFO_ABI cpuinfo_get_current_uarch_index(void) { -+ if CPUINFO_UNLIKELY(!cpuinfo_is_initialized) { -+ cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_uarch_index"); -+ } -+ #if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 -+ #ifdef __linux__ -+ if (cpuinfo_linux_cpu_to_uarch_index_map == NULL) { -+ /* Special case: avoid syscall on systems with only a single type of cores */ -+ return 0; -+ } -+ -+ /* General case */ -+ unsigned cpu; -+ if CPUINFO_UNLIKELY(syscall(__NR_getcpu, &cpu, NULL, NULL) != 0) { -+ return 0; -+ } -+ if CPUINFO_UNLIKELY((uint32_t) cpu >= cpuinfo_linux_cpu_max) { -+ return 0; -+ } -+ return cpuinfo_linux_cpu_to_uarch_index_map[cpu]; -+ #else -+ /* Fallback: pretend to be on the big core. */ -+ return 0; -+ #endif -+ #else -+ /* Only ARM/ARM64 processors may include cores of different types in the same package. */ -+ return 0; -+ #endif -+} -diff --git a/src/arm/cache.c b/src/arm/cache.c -index ccadeb4..c2bc7d2 100644 ---- a/src/arm/cache.c -+++ b/src/arm/cache.c -@@ -659,6 +659,7 @@ void cpuinfo_arm_decode_cache( - }; - } - break; -+ case cpuinfo_uarch_cortex_a55r0: - case cpuinfo_uarch_cortex_a55: - /* - * ARM Cortex-A55 Core Technical Reference Manual -diff --git a/src/arm/linux/api.h b/src/arm/linux/api.h -index 275d072..f99da66 100644 ---- a/src/arm/linux/api.h -+++ b/src/arm/linux/api.h -@@ -153,6 +153,7 @@ struct cpuinfo_arm_linux_processor { - uint32_t midr; - enum cpuinfo_vendor vendor; - enum cpuinfo_uarch uarch; -+ uint32_t uarch_index; - /** - * ID of the physical package which includes this logical processor. - * The value is parsed from /sys/devices/system/cpu/cpu/topology/physical_package_id -@@ -346,3 +347,6 @@ CPUINFO_INTERNAL uint32_t cpuinfo_arm_linux_detect_cluster_midr( - uint32_t max_processors, - uint32_t usable_processors, - struct cpuinfo_arm_linux_processor processors[restrict static max_processors]); -+ -+extern CPUINFO_INTERNAL const uint32_t* cpuinfo_linux_cpu_to_uarch_index_map; -+extern CPUINFO_INTERNAL uint32_t cpuinfo_linux_cpu_to_uarch_index_map_entries; -diff --git a/src/arm/linux/init.c b/src/arm/linux/init.c -index f0c432c..6272abf 100644 ---- a/src/arm/linux/init.c -+++ b/src/arm/linux/init.c -@@ -106,12 +106,14 @@ void cpuinfo_arm_linux_init(void) { - struct cpuinfo_processor* processors = NULL; - struct cpuinfo_core* cores = NULL; - struct cpuinfo_cluster* clusters = NULL; -- const struct cpuinfo_processor** linux_cpu_to_processor_map = NULL; -- const struct cpuinfo_core** linux_cpu_to_core_map = NULL; -+ struct cpuinfo_uarch_info* uarchs = NULL; - struct cpuinfo_cache* l1i = NULL; - struct cpuinfo_cache* l1d = NULL; - struct cpuinfo_cache* l2 = NULL; - struct cpuinfo_cache* l3 = NULL; -+ const struct cpuinfo_processor** linux_cpu_to_processor_map = NULL; -+ const struct cpuinfo_core** linux_cpu_to_core_map = NULL; -+ uint32_t* linux_cpu_to_uarch_index_map = NULL; - - const uint32_t max_processors_count = cpuinfo_linux_get_max_processors_count(); - cpuinfo_log_debug("system maximum processors count: %"PRIu32, max_processors_count); -@@ -400,6 +402,18 @@ void cpuinfo_arm_linux_init(void) { - } - } - -+ uint32_t uarchs_count = 0; -+ enum cpuinfo_uarch last_uarch; -+ for (uint32_t i = 0; i < arm_linux_processors_count; i++) { -+ if (bitmask_all(arm_linux_processors[i].flags, CPUINFO_LINUX_FLAG_VALID)) { -+ if (uarchs_count == 0 || arm_linux_processors[i].uarch != last_uarch) { -+ last_uarch = arm_linux_processors[i].uarch; -+ uarchs_count += 1; -+ } -+ 
arm_linux_processors[i].uarch_index = uarchs_count - 1; -+ } -+ } -+ - /* - * Assumptions: - * - No SMP (i.e. each core supports only one hardware thread). -@@ -432,6 +446,13 @@ void cpuinfo_arm_linux_init(void) { - goto cleanup; - } - -+ uarchs = calloc(uarchs_count, sizeof(struct cpuinfo_uarch_info)); -+ if (uarchs == NULL) { -+ cpuinfo_log_error("failed to allocate %zu bytes for descriptions of %"PRIu32" microarchitectures", -+ uarchs_count * sizeof(struct cpuinfo_uarch_info), uarchs_count); -+ goto cleanup; -+ } -+ - linux_cpu_to_processor_map = calloc(arm_linux_processors_count, sizeof(struct cpuinfo_processor*)); - if (linux_cpu_to_processor_map == NULL) { - cpuinfo_log_error("failed to allocate %zu bytes for %"PRIu32" logical processor mapping entries", -@@ -446,6 +467,15 @@ void cpuinfo_arm_linux_init(void) { - goto cleanup; - } - -+ if (uarchs_count > 1) { -+ linux_cpu_to_uarch_index_map = calloc(arm_linux_processors_count, sizeof(uint32_t)); -+ if (linux_cpu_to_uarch_index_map == NULL) { -+ cpuinfo_log_error("failed to allocate %zu bytes for %"PRIu32" uarch index mapping entries", -+ arm_linux_processors_count * sizeof(uint32_t), arm_linux_processors_count); -+ goto cleanup; -+ } -+ } -+ - l1i = calloc(valid_processors, sizeof(struct cpuinfo_cache)); - if (l1i == NULL) { - cpuinfo_log_error("failed to allocate %zu bytes for descriptions of %"PRIu32" L1I caches", -@@ -460,6 +490,22 @@ void cpuinfo_arm_linux_init(void) { - goto cleanup; - } - -+ uint32_t uarchs_index = 0; -+ for (uint32_t i = 0; i < arm_linux_processors_count; i++) { -+ if (bitmask_all(arm_linux_processors[i].flags, CPUINFO_LINUX_FLAG_VALID)) { -+ if (uarchs_index == 0 || arm_linux_processors[i].uarch != last_uarch) { -+ last_uarch = arm_linux_processors[i].uarch; -+ uarchs[uarchs_index] = (struct cpuinfo_uarch_info) { -+ .uarch = arm_linux_processors[i].uarch, -+ .midr = arm_linux_processors[i].midr, -+ }; -+ uarchs_index += 1; -+ } -+ uarchs[uarchs_index - 1].processor_count += 1; -+ uarchs[uarchs_index - 1].core_count += 1; -+ } -+ } -+ - uint32_t l2_count = 0, l3_count = 0, big_l3_size = 0, cluster_id = UINT32_MAX; - /* Indication whether L3 (if it exists) is shared between all cores */ - bool shared_l3 = true; -@@ -499,6 +545,11 @@ void cpuinfo_arm_linux_init(void) { - cores[i].midr = arm_linux_processors[i].midr; - linux_cpu_to_core_map[arm_linux_processors[i].system_processor_id] = &cores[i]; - -+ if (linux_cpu_to_uarch_index_map != NULL) { -+ linux_cpu_to_uarch_index_map[arm_linux_processors[i].system_processor_id] = -+ arm_linux_processors[i].uarch_index; -+ } -+ - struct cpuinfo_cache temp_l2 = { 0 }, temp_l3 = { 0 }; - cpuinfo_arm_decode_cache( - arm_linux_processors[i].uarch, -@@ -658,12 +709,11 @@ void cpuinfo_arm_linux_init(void) { - } - - /* Commit */ -- cpuinfo_linux_cpu_to_processor_map = linux_cpu_to_processor_map; -- cpuinfo_linux_cpu_to_core_map = linux_cpu_to_core_map; - cpuinfo_processors = processors; - cpuinfo_cores = cores; - cpuinfo_clusters = clusters; - cpuinfo_packages = &package; -+ cpuinfo_uarchs = uarchs; - cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; - cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; - cpuinfo_cache[cpuinfo_cache_level_2] = l2; -@@ -673,33 +723,42 @@ void cpuinfo_arm_linux_init(void) { - cpuinfo_cores_count = valid_processors; - cpuinfo_clusters_count = cluster_count; - cpuinfo_packages_count = 1; -+ cpuinfo_uarchs_count = uarchs_count; - cpuinfo_cache_count[cpuinfo_cache_level_1i] = valid_processors; - cpuinfo_cache_count[cpuinfo_cache_level_1d] = valid_processors; - 
cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; - cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; -- - cpuinfo_max_cache_size = cpuinfo_arm_compute_max_cache_size(&processors[0]); - -+ cpuinfo_linux_cpu_max = arm_linux_processors_count; -+ cpuinfo_linux_cpu_to_processor_map = linux_cpu_to_processor_map; -+ cpuinfo_linux_cpu_to_core_map = linux_cpu_to_core_map; -+ cpuinfo_linux_cpu_to_uarch_index_map = linux_cpu_to_uarch_index_map; -+ - __sync_synchronize(); - - cpuinfo_is_initialized = true; - -- linux_cpu_to_processor_map = NULL; -- linux_cpu_to_core_map = NULL; - processors = NULL; - cores = NULL; - clusters = NULL; -+ uarchs = NULL; - l1i = l1d = l2 = l3 = NULL; -+ linux_cpu_to_processor_map = NULL; -+ linux_cpu_to_core_map = NULL; -+ linux_cpu_to_uarch_index_map = NULL; - - cleanup: - free(arm_linux_processors); -- free(linux_cpu_to_processor_map); -- free(linux_cpu_to_core_map); - free(processors); - free(cores); - free(clusters); -+ free(uarchs); - free(l1i); - free(l1d); - free(l2); - free(l3); -+ free(linux_cpu_to_processor_map); -+ free(linux_cpu_to_core_map); -+ free(linux_cpu_to_uarch_index_map); - } -diff --git a/src/arm/mach/init.c b/src/arm/mach/init.c -index e64cc18..bd27259 100644 ---- a/src/arm/mach/init.c -+++ b/src/arm/mach/init.c -@@ -14,6 +14,16 @@ - #include - #include - -+/* Polyfill recent CPUFAMILY_ARM_* values for older SDKs */ -+#ifndef CPUFAMILY_ARM_MONSOON_MISTRAL -+ #define CPUFAMILY_ARM_MONSOON_MISTRAL 0xE81E7EF6 -+#endif -+#ifndef CPUFAMILY_ARM_VORTEX_TEMPEST -+ #define CPUFAMILY_ARM_VORTEX_TEMPEST 0x07D34B9F -+#endif -+#ifndef CPUFAMILY_ARM_LIGHTNING_THUNDER -+ #define CPUFAMILY_ARM_LIGHTNING_THUNDER 0x462504D2 -+#endif - - struct cpuinfo_arm_isa cpuinfo_isa = { - #if CPUINFO_ARCH_ARM -@@ -82,37 +92,34 @@ static enum cpuinfo_uarch decode_uarch(uint32_t cpu_family, uint32_t cpu_subtype - return cpuinfo_uarch_twister; - case CPUFAMILY_ARM_HURRICANE: - return cpuinfo_uarch_hurricane; --#ifdef CPUFAMILY_ARM_MONSOON_MISTRAL - case CPUFAMILY_ARM_MONSOON_MISTRAL: --#else -- case 0xe81e7ef6: -- /* Hard-coded value for older SDKs which do not define CPUFAMILY_ARM_MONSOON_MISTRAL */ --#endif - /* 2x Monsoon + 4x Mistral cores */ - return core_index < 2 ? cpuinfo_uarch_monsoon : cpuinfo_uarch_mistral; --#ifdef CPUFAMILY_ARM_VORTEX_TEMPEST - case CPUFAMILY_ARM_VORTEX_TEMPEST: --#else -- case 0x07d34b9f: -- /* Hard-coded value for older SDKs which do not define CPUFAMILY_ARM_VORTEX_TEMPEST */ --#endif - /* Hexa-core: 2x Vortex + 4x Tempest; Octa-core: 4x Cortex + 4x Tempest */ - return core_index + 4 < core_count ? cpuinfo_uarch_vortex : cpuinfo_uarch_tempest; -+ case CPUFAMILY_ARM_LIGHTNING_THUNDER: -+ /* Hexa-core: 2x Lightning + 4x Thunder; Octa-core (presumed): 4x Lightning + 4x Thunder */ -+ return core_index + 4 < core_count ? 
cpuinfo_uarch_lightning : cpuinfo_uarch_thunder; - default: - /* Use hw.cpusubtype for detection */ - break; - } - -- switch (cpu_subtype) { -- case CPU_SUBTYPE_ARM_V7: -- return cpuinfo_uarch_cortex_a8; -- case CPU_SUBTYPE_ARM_V7F: -- return cpuinfo_uarch_cortex_a9; -- case CPU_SUBTYPE_ARM_V7K: -- return cpuinfo_uarch_cortex_a7; -- default: -- return cpuinfo_uarch_unknown; -- } -+ #if CPUINFO_ARCH_ARM -+ switch (cpu_subtype) { -+ case CPU_SUBTYPE_ARM_V7: -+ return cpuinfo_uarch_cortex_a8; -+ case CPU_SUBTYPE_ARM_V7F: -+ return cpuinfo_uarch_cortex_a9; -+ case CPU_SUBTYPE_ARM_V7K: -+ return cpuinfo_uarch_cortex_a7; -+ default: -+ return cpuinfo_uarch_unknown; -+ } -+ #else -+ return cpuinfo_uarch_unknown; -+ #endif - } - - static void decode_package_name(char* package_name) { -@@ -244,6 +251,7 @@ void cpuinfo_arm_mach_init(void) { - struct cpuinfo_core* cores = NULL; - struct cpuinfo_cluster* clusters = NULL; - struct cpuinfo_package* packages = NULL; -+ struct cpuinfo_uarch_info* uarchs = NULL; - struct cpuinfo_cache* l1i = NULL; - struct cpuinfo_cache* l1d = NULL; - struct cpuinfo_cache* l2 = NULL; -@@ -330,21 +338,12 @@ void cpuinfo_arm_mach_init(void) { - * Thus, we whitelist CPUs known to support these instructions. - */ - switch (cpu_family) { --#ifdef CPUFAMILY_ARM_MONSOON_MISTRAL - case CPUFAMILY_ARM_MONSOON_MISTRAL: --#else -- case 0xe81e7ef6: -- /* Hard-coded value for older SDKs which do not define CPUFAMILY_ARM_MONSOON_MISTRAL */ --#endif --#ifdef CPUFAMILY_ARM_VORTEX_TEMPEST - case CPUFAMILY_ARM_VORTEX_TEMPEST: --#else -- case 0x07d34b9f: -- /* Hard-coded value for older SDKs which do not define CPUFAMILY_ARM_VORTEX_TEMPEST */ --#endif --#if CPUINFO_ARCH_ARM64 -- cpuinfo_isa.atomics = true; --#endif -+ case CPUFAMILY_ARM_LIGHTNING_THUNDER: -+ #if CPUINFO_ARCH_ARM64 -+ cpuinfo_isa.atomics = true; -+ #endif - cpuinfo_isa.fp16arith = true; - } - -@@ -379,10 +378,22 @@ void cpuinfo_arm_mach_init(void) { - num_clusters * sizeof(struct cpuinfo_cluster), num_clusters); - goto cleanup; - } -+ uarchs = calloc(num_clusters, sizeof(struct cpuinfo_uarch_info)); -+ if (uarchs == NULL) { -+ cpuinfo_log_error( -+ "failed to allocate %zu bytes for descriptions of %"PRIu32" uarchs", -+ num_clusters * sizeof(enum cpuinfo_uarch), num_clusters); -+ goto cleanup; -+ } - uint32_t cluster_idx = UINT32_MAX; - for (uint32_t i = 0; i < mach_topology.cores; i++) { - if (i == 0 || cores[i].uarch != cores[i - 1].uarch) { - cluster_idx++; -+ uarchs[cluster_idx] = (struct cpuinfo_uarch_info) { -+ .uarch = cores[i].uarch, -+ .processor_count = 1, -+ .core_count = 1, -+ }; - clusters[cluster_idx] = (struct cpuinfo_cluster) { - .processor_start = i * threads_per_core, - .processor_count = 1, -@@ -394,6 +405,8 @@ void cpuinfo_arm_mach_init(void) { - .uarch = cores[i].uarch, - }; - } else { -+ uarchs[cluster_idx].processor_count++; -+ uarchs[cluster_idx].core_count++; - clusters[cluster_idx].processor_count++; - clusters[cluster_idx].core_count++; - } -@@ -542,26 +555,25 @@ void cpuinfo_arm_mach_init(void) { - } - - /* Commit changes */ -- cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; -- cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; -- cpuinfo_cache[cpuinfo_cache_level_2] = l2; -- cpuinfo_cache[cpuinfo_cache_level_3] = l3; -- - cpuinfo_processors = processors; - cpuinfo_cores = cores; - cpuinfo_clusters = clusters; - cpuinfo_packages = packages; -- -- cpuinfo_cache_count[cpuinfo_cache_level_1i] = l1_count; -- cpuinfo_cache_count[cpuinfo_cache_level_1d] = l1_count; -- cpuinfo_cache_count[cpuinfo_cache_level_2] = 
l2_count; -- cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; -+ cpuinfo_uarchs = uarchs; -+ cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; -+ cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; -+ cpuinfo_cache[cpuinfo_cache_level_2] = l2; -+ cpuinfo_cache[cpuinfo_cache_level_3] = l3; - - cpuinfo_processors_count = mach_topology.threads; - cpuinfo_cores_count = mach_topology.cores; - cpuinfo_clusters_count = num_clusters; - cpuinfo_packages_count = mach_topology.packages; -- -+ cpuinfo_uarchs_count = num_clusters; -+ cpuinfo_cache_count[cpuinfo_cache_level_1i] = l1_count; -+ cpuinfo_cache_count[cpuinfo_cache_level_1d] = l1_count; -+ cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; -+ cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; - cpuinfo_max_cache_size = cpuinfo_compute_max_cache_size(&processors[0]); - - __sync_synchronize(); -@@ -572,6 +584,7 @@ void cpuinfo_arm_mach_init(void) { - cores = NULL; - clusters = NULL; - packages = NULL; -+ uarchs = NULL; - l1i = l1d = l2 = l3 = NULL; - - cleanup: -@@ -579,6 +592,7 @@ cleanup: - free(cores); - free(clusters); - free(packages); -+ free(uarchs); - free(l1i); - free(l1d); - free(l2); -diff --git a/src/arm/uarch.c b/src/arm/uarch.c -index a38250a..2aef9e7 100644 ---- a/src/arm/uarch.c -+++ b/src/arm/uarch.c -@@ -58,7 +58,9 @@ void cpuinfo_arm_decode_vendor_uarch( - *uarch = cpuinfo_uarch_cortex_a35; - break; - case 0xD05: -- *uarch = cpuinfo_uarch_cortex_a55; -+ // Note: use Variant, not Revision, field -+ *uarch = (midr & CPUINFO_ARM_MIDR_VARIANT_MASK) == 0 ? -+ cpuinfo_uarch_cortex_a55r0 : cpuinfo_uarch_cortex_a55; - break; - case 0xD06: - *uarch = cpuinfo_uarch_cortex_a65; -@@ -257,9 +259,9 @@ void cpuinfo_arm_decode_vendor_uarch( - *vendor = cpuinfo_vendor_arm; - *uarch = cpuinfo_uarch_cortex_a75; - break; -- case 0x803: /* Low-power Kryo 385 "Silver" -> Cortex-A55 */ -+ case 0x803: /* Low-power Kryo 385 "Silver" -> Cortex-A55r0 */ - *vendor = cpuinfo_vendor_arm; -- *uarch = cpuinfo_uarch_cortex_a55; -+ *uarch = cpuinfo_uarch_cortex_a55r0; - break; - case 0x804: /* High-performance Kryo 485 "Gold" / "Gold Prime" -> Cortex-A76 */ - *vendor = cpuinfo_vendor_arm; -diff --git a/src/cpuinfo/common.h b/src/cpuinfo/common.h -index 6ba746e..b2b404d 100644 ---- a/src/cpuinfo/common.h -+++ b/src/cpuinfo/common.h -@@ -12,29 +12,29 @@ - #define CPUINFO_COUNT_OF(array) (sizeof(array) / sizeof(0[array])) - - #if defined(__GNUC__) -- #define CPUINFO_LIKELY(condition) (__builtin_expect(!!(condition), 1)) -- #define CPUINFO_UNLIKELY(condition) (__builtin_expect(!!(condition), 0)) -+ #define CPUINFO_LIKELY(condition) (__builtin_expect(!!(condition), 1)) -+ #define CPUINFO_UNLIKELY(condition) (__builtin_expect(!!(condition), 0)) - #else -- #define CPUINFO_LIKELY(condition) (!!(condition)) -- #define CPUINFO_UNLIKELY(condition) (!!(condition)) -+ #define CPUINFO_LIKELY(condition) (!!(condition)) -+ #define CPUINFO_UNLIKELY(condition) (!!(condition)) - #endif - - #ifndef CPUINFO_INTERNAL -- #if defined(__ELF__) -- #define CPUINFO_INTERNAL __attribute__((__visibility__("internal"))) -- #elif defined(__MACH__) -- #define CPUINFO_INTERNAL __attribute__((__visibility__("hidden"))) -- #else -- #define CPUINFO_INTERNAL -- #endif -+ #if defined(__ELF__) -+ #define CPUINFO_INTERNAL __attribute__((__visibility__("internal"))) -+ #elif defined(__MACH__) -+ #define CPUINFO_INTERNAL __attribute__((__visibility__("hidden"))) -+ #else -+ #define CPUINFO_INTERNAL -+ #endif - #endif - - #ifndef CPUINFO_PRIVATE -- #if defined(__ELF__) -- #define CPUINFO_PRIVATE 
__attribute__((__visibility__("hidden"))) -- #elif defined(__MACH__) -- #define CPUINFO_PRIVATE __attribute__((__visibility__("hidden"))) -- #else -- #define CPUINFO_PRIVATE -- #endif -+ #if defined(__ELF__) -+ #define CPUINFO_PRIVATE __attribute__((__visibility__("hidden"))) -+ #elif defined(__MACH__) -+ #define CPUINFO_PRIVATE __attribute__((__visibility__("hidden"))) -+ #else -+ #define CPUINFO_PRIVATE -+ #endif - #endif -diff --git a/src/cpuinfo/internal-api.h b/src/cpuinfo/internal-api.h -index f12c48d..c6eed0b 100644 ---- a/src/cpuinfo/internal-api.h -+++ b/src/cpuinfo/internal-api.h -@@ -21,11 +21,13 @@ enum cpuinfo_cache_level { - }; - - extern CPUINFO_INTERNAL bool cpuinfo_is_initialized; -+ - extern CPUINFO_INTERNAL struct cpuinfo_processor* cpuinfo_processors; - extern CPUINFO_INTERNAL struct cpuinfo_core* cpuinfo_cores; - extern CPUINFO_INTERNAL struct cpuinfo_cluster* cpuinfo_clusters; - extern CPUINFO_INTERNAL struct cpuinfo_package* cpuinfo_packages; - extern CPUINFO_INTERNAL struct cpuinfo_cache* cpuinfo_cache[cpuinfo_cache_level_max]; -+ - extern CPUINFO_INTERNAL uint32_t cpuinfo_processors_count; - extern CPUINFO_INTERNAL uint32_t cpuinfo_cores_count; - extern CPUINFO_INTERNAL uint32_t cpuinfo_clusters_count; -@@ -33,6 +35,19 @@ extern CPUINFO_INTERNAL uint32_t cpuinfo_packages_count; - extern CPUINFO_INTERNAL uint32_t cpuinfo_cache_count[cpuinfo_cache_level_max]; - extern CPUINFO_INTERNAL uint32_t cpuinfo_max_cache_size; - -+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 -+ extern CPUINFO_INTERNAL struct cpuinfo_uarch_info* cpuinfo_uarchs; -+ extern CPUINFO_INTERNAL uint32_t cpuinfo_uarchs_count; -+#else -+ extern CPUINFO_INTERNAL struct cpuinfo_uarch_info cpuinfo_global_uarch; -+#endif -+ -+#ifdef __linux__ -+ extern CPUINFO_INTERNAL uint32_t cpuinfo_linux_cpu_max; -+ extern CPUINFO_INTERNAL const struct cpuinfo_processor** cpuinfo_linux_cpu_to_processor_map; -+ extern CPUINFO_INTERNAL const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map; -+#endif -+ - CPUINFO_PRIVATE void cpuinfo_x86_mach_init(void); - CPUINFO_PRIVATE void cpuinfo_x86_linux_init(void); - #ifdef _WIN32 -diff --git a/src/linux/current.c b/src/linux/current.c -deleted file mode 100644 -index 472a4c9..0000000 ---- a/src/linux/current.c -+++ /dev/null -@@ -1,41 +0,0 @@ --#include --#include --#include --#include --#include -- --#include -- --#include --#include --#include --#include -- -- --const struct cpuinfo_processor** cpuinfo_linux_cpu_to_processor_map = NULL; --const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map = NULL; -- -- --const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_current_processor(void) { -- if (!cpuinfo_is_initialized) { -- cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_processor"); -- } -- const int cpu = sched_getcpu(); -- if (cpu >= 0) { -- return cpuinfo_linux_cpu_to_processor_map[cpu]; -- } else { -- return &cpuinfo_processors[0]; -- } --} -- --const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_current_core(void) { -- if (!cpuinfo_is_initialized) { -- cpuinfo_log_fatal("cpuinfo_get_%s called before cpuinfo is initialized", "current_core"); -- } -- const int cpu = sched_getcpu(); -- if (cpu >= 0) { -- return cpuinfo_linux_cpu_to_core_map[cpu]; -- } else { -- return &cpuinfo_cores[0]; -- } --} -diff --git a/src/x86/api.h b/src/x86/api.h -index 5f5e76d..213c2d8 100644 ---- a/src/x86/api.h -+++ b/src/x86/api.h -@@ -93,7 +93,6 @@ CPUINFO_INTERNAL struct cpuinfo_x86_isa cpuinfo_x86_detect_isa( - const struct cpuid_regs basic_info, const 
struct cpuid_regs extended_info, - uint32_t max_base_index, uint32_t max_extended_index, - enum cpuinfo_vendor vendor, enum cpuinfo_uarch uarch); --CPUINFO_INTERNAL struct cpuinfo_x86_isa cpuinfo_x86_nacl_detect_isa(void); - - CPUINFO_INTERNAL void cpuinfo_x86_detect_topology( - uint32_t max_base_index, -diff --git a/src/x86/cache/init.c b/src/x86/cache/init.c -index d581016..dd1f1ea 100644 ---- a/src/x86/cache/init.c -+++ b/src/x86/cache/init.c -@@ -65,7 +65,7 @@ iterate_descriptors: - } - } - -- if (vendor != cpuinfo_vendor_amd && max_base_index >= 4) { -+ if (vendor != cpuinfo_vendor_amd && vendor != cpuinfo_vendor_hygon && max_base_index >= 4) { - struct cpuid_regs leaf4; - uint32_t input_ecx = 0; - uint32_t package_cores_max = 0; -diff --git a/src/x86/cpuid.h b/src/x86/cpuid.h -index 829ec21..9e9e013 100644 ---- a/src/x86/cpuid.h -+++ b/src/x86/cpuid.h -@@ -67,18 +67,13 @@ - } - #endif - --/* -- * This instruction may be not supported by Native Client validator, -- * make sure it doesn't appear in the binary -- */ --#ifndef __native_client__ -- static inline uint64_t xgetbv(uint32_t ext_ctrl_reg) { -- #ifdef _MSC_VER -- return (uint64_t)_xgetbv((unsigned int)ext_ctrl_reg); -- #else -- uint32_t lo, hi; -- __asm__(".byte 0x0F, 0x01, 0xD0" : "=a" (lo), "=d" (hi) : "c" (ext_ctrl_reg)); -- return ((uint64_t) hi << 32) | (uint64_t) lo; -- #endif -- } --#endif -+static inline uint64_t xgetbv(uint32_t ext_ctrl_reg) { -+ #ifdef _MSC_VER -+ return (uint64_t)_xgetbv((unsigned int)ext_ctrl_reg); -+ #else -+ uint32_t lo, hi; -+ __asm__(".byte 0x0F, 0x01, 0xD0" : "=a" (lo), "=d" (hi) : "c" (ext_ctrl_reg)); -+ return ((uint64_t) hi << 32) | (uint64_t) lo; -+ #endif -+} -+ -diff --git a/src/x86/init.c b/src/x86/init.c -index d736578..244359c 100644 ---- a/src/x86/init.c -+++ b/src/x86/init.c -@@ -61,12 +61,8 @@ void cpuinfo_x86_init_processor(struct cpuinfo_x86_processor* processor) { - - cpuinfo_x86_detect_topology(max_base_index, max_extended_index, leaf1, &processor->topology); - -- #ifdef __native_client__ -- cpuinfo_isa = cpuinfo_x86_nacl_detect_isa(); -- #else -- cpuinfo_isa = cpuinfo_x86_detect_isa(leaf1, leaf0x80000001, -- max_base_index, max_extended_index, vendor, uarch); -- #endif -+ cpuinfo_isa = cpuinfo_x86_detect_isa(leaf1, leaf0x80000001, -+ max_base_index, max_extended_index, vendor, uarch); - } - if (max_extended_index >= UINT32_C(0x80000004)) { - struct cpuid_regs brand_string[3]; -diff --git a/src/x86/isa.c b/src/x86/isa.c -index d27dbca..f2e5a28 100644 ---- a/src/x86/isa.c -+++ b/src/x86/isa.c -@@ -244,6 +244,7 @@ struct cpuinfo_x86_isa cpuinfo_x86_detect_isa( - */ - break; - case cpuinfo_vendor_amd: -+ case cpuinfo_vendor_hygon: - isa.prefetch = !!((extended_info.ecx & UINT32_C(0x00000100)) | (extended_info.edx & UINT32_C(0xE0000000))); - break; - default: -@@ -265,6 +266,7 @@ struct cpuinfo_x86_isa cpuinfo_x86_detect_isa( - */ - switch (vendor) { - case cpuinfo_vendor_amd: -+ case cpuinfo_vendor_hygon: - isa.prefetchw = !!((extended_info.ecx & UINT32_C(0x00000100)) | (extended_info.edx & UINT32_C(0xE0000000))); - break; - default: -diff --git a/src/x86/linux/init.c b/src/x86/linux/init.c -index c096336..f565789 100644 ---- a/src/x86/linux/init.c -+++ b/src/x86/linux/init.c -@@ -569,9 +569,6 @@ void cpuinfo_x86_linux_init(void) { - } - - /* Commit changes */ -- cpuinfo_linux_cpu_to_processor_map = linux_cpu_to_processor_map; -- cpuinfo_linux_cpu_to_core_map = linux_cpu_to_core_map; -- - cpuinfo_processors = processors; - cpuinfo_cores = cores; - cpuinfo_clusters = clusters; -@@ 
-591,24 +588,32 @@ void cpuinfo_x86_linux_init(void) { - cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; - cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; - cpuinfo_cache_count[cpuinfo_cache_level_4] = l4_count; -- - cpuinfo_max_cache_size = cpuinfo_compute_max_cache_size(&processors[0]); - -+ cpuinfo_global_uarch = (struct cpuinfo_uarch_info) { -+ .uarch = x86_processor.uarch, -+ .cpuid = x86_processor.cpuid, -+ .processor_count = processors_count, -+ .core_count = cores_count, -+ }; -+ -+ cpuinfo_linux_cpu_max = x86_linux_processors_count; -+ cpuinfo_linux_cpu_to_processor_map = linux_cpu_to_processor_map; -+ cpuinfo_linux_cpu_to_core_map = linux_cpu_to_core_map; -+ - __sync_synchronize(); - - cpuinfo_is_initialized = true; - -- linux_cpu_to_processor_map = NULL; -- linux_cpu_to_core_map = NULL; - processors = NULL; - cores = NULL; - clusters = NULL; - packages = NULL; - l1i = l1d = l2 = l3 = l4 = NULL; -+ linux_cpu_to_processor_map = NULL; -+ linux_cpu_to_core_map = NULL; - - cleanup: -- free(linux_cpu_to_processor_map); -- free(linux_cpu_to_core_map); - free(x86_linux_processors); - free(processors); - free(cores); -@@ -619,4 +624,6 @@ cleanup: - free(l2); - free(l3); - free(l4); -+ free(linux_cpu_to_processor_map); -+ free(linux_cpu_to_core_map); - } -diff --git a/src/x86/mach/init.c b/src/x86/mach/init.c -index ae2be33..b44d3ad 100644 ---- a/src/x86/mach/init.c -+++ b/src/x86/mach/init.c -@@ -305,30 +305,34 @@ void cpuinfo_x86_mach_init(void) { - } - - /* Commit changes */ -+ cpuinfo_processors = processors; -+ cpuinfo_cores = cores; -+ cpuinfo_clusters = clusters; -+ cpuinfo_packages = packages; - cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; - cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; - cpuinfo_cache[cpuinfo_cache_level_2] = l2; - cpuinfo_cache[cpuinfo_cache_level_3] = l3; - cpuinfo_cache[cpuinfo_cache_level_4] = l4; - -- cpuinfo_processors = processors; -- cpuinfo_cores = cores; -- cpuinfo_clusters = clusters; -- cpuinfo_packages = packages; -- -+ cpuinfo_processors_count = mach_topology.threads; -+ cpuinfo_cores_count = mach_topology.cores; -+ cpuinfo_clusters_count = mach_topology.packages; -+ cpuinfo_packages_count = mach_topology.packages; - cpuinfo_cache_count[cpuinfo_cache_level_1i] = l1_count; - cpuinfo_cache_count[cpuinfo_cache_level_1d] = l1_count; - cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; - cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; - cpuinfo_cache_count[cpuinfo_cache_level_4] = l4_count; -- -- cpuinfo_processors_count = mach_topology.threads; -- cpuinfo_cores_count = mach_topology.cores; -- cpuinfo_clusters_count = mach_topology.packages; -- cpuinfo_packages_count = mach_topology.packages; -- - cpuinfo_max_cache_size = cpuinfo_compute_max_cache_size(&processors[0]); - -+ cpuinfo_global_uarch = (struct cpuinfo_uarch_info) { -+ .uarch = x86_processor.uarch, -+ .cpuid = x86_processor.cpuid, -+ .processor_count = mach_topology.threads, -+ .core_count = mach_topology.cores, -+ }; -+ - __sync_synchronize(); - - cpuinfo_is_initialized = true; -diff --git a/src/x86/nacl/isa.c b/src/x86/nacl/isa.c -deleted file mode 100644 -index 662be33..0000000 ---- a/src/x86/nacl/isa.c -+++ /dev/null -@@ -1,306 +0,0 @@ --#include --#include --#include -- --#include -- --#define NACL_CODE_BUNDLE_SIZE 32 --#include --#include -- --static const uint8_t cmpxchg16b_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* MOV edi, edi */ -- 0x89, 0xFF, -- /* CMPXCHG16B [r15 + rdi * 1] */ -- 0x49, 0x0F, 0xC7, 0x0C, 0x3F, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 
0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t lzcnt_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* LZCNT eax, ecx */ -- 0xF3, 0x0F, 0xBD, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t popcnt_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* POPCNT eax, ecx */ -- 0xF3, 0x0F, 0xB8, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t movbe_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* MOV ecx, ecx */ -- 0x89, 0xC9, -- /* MOVBE eax, [r15 + rcx * 1] */ -- 0x41, 0x0F, 0x38, 0xF0, 0x04, 0x0F, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t bmi_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* ANDN eax, ecx, edx */ -- 0xC4, 0xE2, 0x70, 0xF2, 0xC2, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t tbm_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* BLCS eax, ecx */ -- 0x8F, 0xE9, 0x78, 0x01, 0xD9, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t three_d_now_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* PFADD mm0, mm1 */ -- 0x0F, 0x0F, 0xC1, 0x9E, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t three_d_now_plus_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* PFNACC mm0, mm1 */ -- 0x0F, 0x0F, 0xC1, 0x8A, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t sse3_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* HADDPS xmm0, xmm1 */ -- 0xF2, 0x0F, 0x7C, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t ssse3_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* PSHUFB xmm0, xmm1 */ -- 0x66, 0x0F, 0x38, 0x00, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t sse4_1_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* PMULLD xmm0, xmm1 */ -- 0x66, 0x0F, 0x38, 0x40, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t 
sse4_2_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* PCMPGTQ xmm0, xmm1 */ -- 0x66, 0x0F, 0x38, 0x37, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t sse4a_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* EXTRQ xmm0, xmm1 */ -- 0x66, 0x0F, 0x79, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t aes_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* AESENC xmm0, xmm1 */ -- 0x66, 0x0F, 0x38, 0xDC, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t pclmulqdq_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* PCLMULQDQ xmm0, xmm1, 0 */ -- 0x66, 0x0F, 0x3A, 0x44, 0xC1, 0x00, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t avx_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* VPERMILPS ymm0, ymm1, 0xAA */ -- 0xC4, 0xE3, 0x7D, 0x04, 0xC1, 0xAA, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t fma3_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* VFMADDSUB213PS ymm0, ymm1, ymm2 */ -- 0xC4, 0xE2, 0x75, 0xA6, 0xC2, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t fma4_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* VFMADDPS ymm0, ymm1, ymm2, ymm3 */ -- 0xC4, 0xE3, 0xF5, 0x68, 0xC3, 0x20, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t xop_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* VPHADDBQ xmm0, xmm1 */ -- 0x8F, 0xE9, 0x78, 0xC3, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t f16c_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* VCVTPH2PS ymm0, xmm1 */ -- 0xC4, 0xE2, 0x7D, 0x13, 0xC1, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- --static const uint8_t avx2_bundle[NACL_CODE_BUNDLE_SIZE] = { -- /* VPERMPS ymm0, ymm1, ymm2 */ -- 0xC4, 0xE2, 0x75, 0x16, 0xC2, -- /* Fill remainder with HLTs */ -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, -- 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, 0xF4, --}; -- -- --struct cpuinfo_x86_isa cpuinfo_x86_nacl_detect_isa(void) { -- /* -- * Under Native Client sandbox we can't just ask the CPU: -- * - First, some instructions (XGETBV) necessary to query AVX 
support are not white-listed in the validator. -- * - Secondly, even if CPU supports some instruction, but validator doesn't know about it (e.g. due a bug in the -- * ISA detection in the validator), all instructions from the "unsupported" ISA extensions will be replaced by -- * HLTs when the module is loaded. -- * Thus, instead of quering the CPU about supported ISA extensions, we query the validator: we pass bundles with -- * instructions from ISA extensions to dynamic code generation APIs, and test if they are accepted. -- */ -- -- struct cpuinfo_x86_isa isa = { 0 }; -- -- struct nacl_irt_code_data_alloc nacl_irt_code_data_alloc = { 0 }; -- struct nacl_irt_dyncode nacl_irt_dyncode = { 0 }; -- if (sizeof(nacl_irt_code_data_alloc) != nacl_interface_query(NACL_IRT_CODE_DATA_ALLOC_v0_1, -- &nacl_irt_code_data_alloc, -- sizeof(nacl_irt_code_data_alloc))) -- { -- goto finish; -- } -- -- if (sizeof(nacl_irt_dyncode) != nacl_interface_query(NACL_IRT_DYNCODE_v0_1, -- &nacl_irt_dyncode, -- sizeof(nacl_irt_dyncode))) -- { -- goto finish; -- } -- -- const size_t allocation_size = 65536; -- uintptr_t code_segment = 0; -- if (0 != nacl_irt_code_data_alloc.allocate_code_data(0, allocation_size, 0, 0, &code_segment)) -- { -- goto finish; -- } -- -- isa.cmpxchg16b = !nacl_irt_dyncode.dyncode_create((void*) code_segment, cmpxchg16b_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.lzcnt = !nacl_irt_dyncode.dyncode_create((void*) code_segment, lzcnt_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.popcnt = !nacl_irt_dyncode.dyncode_create((void*) code_segment, popcnt_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.movbe = !nacl_irt_dyncode.dyncode_create((void*) code_segment, movbe_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.bmi = !nacl_irt_dyncode.dyncode_create((void*) code_segment, bmi_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.tbm = !nacl_irt_dyncode.dyncode_create((void*) code_segment, tbm_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.three_d_now = !nacl_irt_dyncode.dyncode_create((void*) code_segment, three_d_now_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.three_d_now_plus = -- !nacl_irt_dyncode.dyncode_create((void*) code_segment, three_d_now_plus_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.sse3 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, sse3_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.ssse3 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, ssse3_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.sse4_1 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, sse4_1_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.sse4_2 = 
!nacl_irt_dyncode.dyncode_create((void*) code_segment, sse4_2_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.sse4a = !nacl_irt_dyncode.dyncode_create((void*) code_segment, sse4a_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.aes = !nacl_irt_dyncode.dyncode_create((void*) code_segment, aes_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.pclmulqdq = !nacl_irt_dyncode.dyncode_create((void*) code_segment, pclmulqdq_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.avx = !nacl_irt_dyncode.dyncode_create((void*) code_segment, avx_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.fma3 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, fma3_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.fma4 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, fma4_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.xop = !nacl_irt_dyncode.dyncode_create((void*) code_segment, xop_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.f16c = !nacl_irt_dyncode.dyncode_create((void*) code_segment, f16c_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- code_segment += NACL_CODE_BUNDLE_SIZE; -- -- isa.avx2 = !nacl_irt_dyncode.dyncode_create((void*) code_segment, avx2_bundle, NACL_CODE_BUNDLE_SIZE) && -- (*((const uint8_t*) code_segment) != 0xF4); -- --finish: -- return isa; --} -diff --git a/src/x86/name.c b/src/x86/name.c -index 708be1d..e0d5a5b 100644 ---- a/src/x86/name.c -+++ b/src/x86/name.c -@@ -671,6 +671,7 @@ static const char* vendor_string_map[] = { - [cpuinfo_vendor_intel] = "Intel", - [cpuinfo_vendor_amd] = "AMD", - [cpuinfo_vendor_via] = "VIA", -+ [cpuinfo_vendor_hygon] = "Hygon", - [cpuinfo_vendor_rdc] = "RDC", - [cpuinfo_vendor_dmp] = "DM&P", - [cpuinfo_vendor_transmeta] = "Transmeta", -diff --git a/src/x86/uarch.c b/src/x86/uarch.c -index ba72d8a..ecaa762 100644 ---- a/src/x86/uarch.c -+++ b/src/x86/uarch.c -@@ -79,6 +79,8 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( - case 0x5E: // Sky Lake Client DT/H/S - case 0x8E: // Kaby/Whiskey/Amber/Comet Lake Y/U - case 0x9E: // Kaby/Coffee Lake DT/H/S -+ case 0xA5: // Comet Lake H/S -+ case 0xA6: // Comet Lake U/Y - return cpuinfo_uarch_sky_lake; - case 0x66: // Cannon Lake (Core i3-8121U) - return cpuinfo_uarch_palm_cove; -@@ -94,7 +96,7 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( - return cpuinfo_uarch_bonnell; - case 0x27: // Medfield - case 0x35: // Cloverview -- case 0x36: // Cedarview, Centerton -+ case 0x36: // Cedarview, Centerton - return cpuinfo_uarch_saltwell; - case 0x37: // Bay Trail - case 0x4A: // Merrifield -@@ -110,6 +112,7 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( - return cpuinfo_uarch_goldmont; - case 0x7A: // Gemini Lake - return cpuinfo_uarch_goldmont_plus; -+ - /* Knights-series cores */ - case 0x57: - return cpuinfo_uarch_knights_landing; -@@ -173,7 +176,7 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( - case 0x38: // 
Godavari - case 0x30: // Kaveri - return cpuinfo_uarch_steamroller; -- case 0x60: // Carrizo -+ case 0x60: // Carrizo - case 0x65: // Bristol Ridge - case 0x70: // Stoney Ridge - return cpuinfo_uarch_excavator; -@@ -201,14 +204,22 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( - switch (model_info->model) { - case 0x01: // 14 nm Naples, Whitehaven, Summit Ridge, Snowy Owl - case 0x08: // 12 nm Pinnacle Ridge -- case 0x11: // 14 nm Raven Ridge -+ case 0x11: // 14 nm Raven Ridge, Great Horned Owl - case 0x18: // 12 nm Picasso - return cpuinfo_uarch_zen; -+ case 0x31: // Rome, Castle Peak -+ case 0x60: // Renoir - case 0x71: // Matisse - return cpuinfo_uarch_zen2; - } - } - break; -+ case cpuinfo_vendor_hygon: -+ switch (model_info->family) { -+ case 0x00: -+ return cpuinfo_uarch_dhyana; -+ } -+ break; - default: - break; - } -diff --git a/src/x86/vendor.c b/src/x86/vendor.c -index 3f3c753..2bba90d 100644 ---- a/src/x86/vendor.c -+++ b/src/x86/vendor.c -@@ -26,6 +26,11 @@ - #define auls UINT32_C(0x736C7561) - #define VIA UINT32_C(0x20414956) - -+/* Hygon vendor string: "HygonGenuine" */ -+#define Hygo UINT32_C(0x6F677948) -+#define nGen UINT32_C(0x6E65476E) -+#define uine UINT32_C(0x656E6975) -+ - /* Transmeta vendor strings: "GenuineTMx86", "TransmetaCPU" */ - #define ineT UINT32_C(0x54656E69) - #define Mx86 UINT32_C(0x3638784D) -@@ -105,6 +110,12 @@ enum cpuinfo_vendor cpuinfo_x86_decode_vendor(uint32_t ebx, uint32_t ecx, uint32 - return cpuinfo_vendor_via; - } - break; -+ case Hygo: -+ if (edx == nGen && ecx == uine) { -+ /* "HygonGenuine" */ -+ return cpuinfo_vendor_hygon; -+ } -+ break; - #if CPUINFO_ARCH_X86 - case AMDi: - if (edx == sbet && ecx == ter) { -diff --git a/src/x86/windows/init.c b/src/x86/windows/init.c -index 7a2090e..2c7e3cd 100644 ---- a/src/x86/windows/init.c -+++ b/src/x86/windows/init.c -@@ -417,9 +417,6 @@ BOOL CALLBACK cpuinfo_x86_windows_init(PINIT_ONCE init_once, PVOID parameter, PV - for (uint32_t i = 0; i < processors_count; i++) { - const uint32_t apic_id = processors[i].apic_id; - -- //linux_cpu_to_processor_map[x86_linux_processors[i].linux_id] = processors + processor_index; -- //linux_cpu_to_core_map[x86_linux_processors[i].linux_id] = cores + core_index; -- - if (x86_processor.cache.l1i.size != 0) { - const uint32_t l1i_id = apic_id & ~bit_mask(x86_processor.cache.l1i.apic_bits); - processors[i].cache.l1i = &l1i[l1i_index]; -@@ -549,30 +546,34 @@ BOOL CALLBACK cpuinfo_x86_windows_init(PINIT_ONCE init_once, PVOID parameter, PV - - - /* Commit changes */ -+ cpuinfo_processors = processors; -+ cpuinfo_cores = cores; -+ cpuinfo_clusters = clusters; -+ cpuinfo_packages = packages; - cpuinfo_cache[cpuinfo_cache_level_1i] = l1i; - cpuinfo_cache[cpuinfo_cache_level_1d] = l1d; - cpuinfo_cache[cpuinfo_cache_level_2] = l2; - cpuinfo_cache[cpuinfo_cache_level_3] = l3; - cpuinfo_cache[cpuinfo_cache_level_4] = l4; - -- cpuinfo_processors = processors; -- cpuinfo_cores = cores; -- cpuinfo_clusters = clusters; -- cpuinfo_packages = packages; -- -+ cpuinfo_processors_count = processors_count; -+ cpuinfo_cores_count = cores_count; -+ cpuinfo_clusters_count = packages_count; -+ cpuinfo_packages_count = packages_count; - cpuinfo_cache_count[cpuinfo_cache_level_1i] = l1i_count; - cpuinfo_cache_count[cpuinfo_cache_level_1d] = l1d_count; - cpuinfo_cache_count[cpuinfo_cache_level_2] = l2_count; - cpuinfo_cache_count[cpuinfo_cache_level_3] = l3_count; - cpuinfo_cache_count[cpuinfo_cache_level_4] = l4_count; -- -- cpuinfo_processors_count = processors_count; -- cpuinfo_cores_count 
= cores_count; -- cpuinfo_clusters_count = packages_count; -- cpuinfo_packages_count = packages_count; -- - cpuinfo_max_cache_size = cpuinfo_compute_max_cache_size(&processors[0]); - -+ cpuinfo_global_uarch = (struct cpuinfo_uarch_info) { -+ .uarch = x86_processor.uarch, -+ .cpuid = x86_processor.cpuid, -+ .processor_count = processors_count, -+ .core_count = cores_count, -+ }; -+ - MemoryBarrier(); - - cpuinfo_is_initialized = true; -diff --git a/test/arm-cache.cc b/test/arm-cache.cc -index 8373f7c..7d2e4a4 100644 ---- a/test/arm-cache.cc -+++ b/test/arm-cache.cc -@@ -766,7 +766,7 @@ TEST(QUALCOMM, snapdragon_845) { - struct cpuinfo_cache little_l2 = { 0 }; - struct cpuinfo_cache little_l3 = { 0 }; - cpuinfo_arm_decode_cache( -- cpuinfo_uarch_cortex_a55, 4, UINT32_C(0x518F803C), -+ cpuinfo_uarch_cortex_a55r0, 4, UINT32_C(0x518F803C), - &chipset, 1, 8, - &little_l1i, &little_l1d, &little_l2, &little_l3); - -@@ -910,7 +910,7 @@ TEST(SAMSUNG, exynos_9810) { - struct cpuinfo_cache little_l2 = { 0 }; - struct cpuinfo_cache little_l3 = { 0 }; - cpuinfo_arm_decode_cache( -- cpuinfo_uarch_cortex_a55, 4, UINT32_C(0x410FD051), -+ cpuinfo_uarch_cortex_a55r0, 4, UINT32_C(0x410FD051), - &chipset, 1, 8, - &little_l1i, &little_l1d, &little_l2, &little_l3); - -diff --git a/test/get-current.cc b/test/get-current.cc -index 4a80cab..f410b12 100644 ---- a/test/get-current.cc -+++ b/test/get-current.cc -@@ -3,34 +3,36 @@ - #include - - --TEST(CURRENT_PROCESSOR, not_null) { -- ASSERT_TRUE(cpuinfo_initialize()); -- -- ASSERT_TRUE(cpuinfo_get_current_processor()); --} -- - TEST(CURRENT_PROCESSOR, within_bounds) { - ASSERT_TRUE(cpuinfo_initialize()); - - const struct cpuinfo_processor* current_processor = cpuinfo_get_current_processor(); -+ if (current_processor == nullptr) { -+ GTEST_SKIP(); -+ } -+ - const struct cpuinfo_processor* processors_begin = cpuinfo_get_processors(); - const struct cpuinfo_processor* processors_end = processors_begin + cpuinfo_get_processors_count(); - ASSERT_GE(current_processor, processors_begin); - ASSERT_LT(current_processor, processors_end); - } - --TEST(CURRENT_CORE, not_null) { -- ASSERT_TRUE(cpuinfo_initialize()); -- -- ASSERT_TRUE(cpuinfo_get_current_core()); --} -- - TEST(CURRENT_CORE, within_bounds) { - ASSERT_TRUE(cpuinfo_initialize()); - - const struct cpuinfo_core* current_core = cpuinfo_get_current_core(); -+ if (current_core == nullptr) { -+ GTEST_SKIP(); -+ } -+ - const struct cpuinfo_core* cores_begin = cpuinfo_get_cores(); - const struct cpuinfo_core* cores_end = cores_begin + cpuinfo_get_cores_count(); - ASSERT_GE(current_core, cores_begin); - ASSERT_LT(current_core, cores_end); - } -+ -+TEST(CURRENT_UARCH_INDEX, within_bounds) { -+ ASSERT_TRUE(cpuinfo_initialize()); -+ -+ ASSERT_LT(cpuinfo_get_current_uarch_index(), cpuinfo_get_uarchs_count()); -+} -diff --git a/test/init.cc b/test/init.cc -index 941cb97..718eb96 100644 ---- a/test/init.cc -+++ b/test/init.cc -@@ -678,6 +678,72 @@ TEST(PACKAGE, consistent_cluster) { - cpuinfo_deinitialize(); - } - -+TEST(UARCHS_COUNT, within_bounds) { -+ ASSERT_TRUE(cpuinfo_initialize()); -+ EXPECT_NE(0, cpuinfo_get_uarchs_count()); -+ EXPECT_LE(cpuinfo_get_packages_count(), cpuinfo_get_cores_count()); -+ EXPECT_LE(cpuinfo_get_packages_count(), cpuinfo_get_processors_count()); -+ cpuinfo_deinitialize(); -+} -+ -+TEST(UARCHS, non_null) { -+ ASSERT_TRUE(cpuinfo_initialize()); -+ EXPECT_TRUE(cpuinfo_get_uarchs()); -+ cpuinfo_deinitialize(); -+} -+ -+TEST(UARCH, non_null) { -+ ASSERT_TRUE(cpuinfo_initialize()); -+ for (uint32_t i = 0; 
i < cpuinfo_get_uarchs_count(); i++) { -+ EXPECT_TRUE(cpuinfo_get_uarch(i)); -+ } -+ cpuinfo_deinitialize(); -+} -+ -+TEST(UARCH, non_zero_processors) { -+ ASSERT_TRUE(cpuinfo_initialize()); -+ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { -+ const cpuinfo_uarch_info* uarch = cpuinfo_get_uarch(i); -+ ASSERT_TRUE(uarch); -+ -+ EXPECT_NE(0, uarch->processor_count); -+ } -+ cpuinfo_deinitialize(); -+} -+ -+TEST(UARCH, valid_processors) { -+ ASSERT_TRUE(cpuinfo_initialize()); -+ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { -+ const cpuinfo_uarch_info* uarch = cpuinfo_get_uarch(i); -+ ASSERT_TRUE(uarch); -+ -+ EXPECT_LE(uarch->processor_count, cpuinfo_get_processors_count()); -+ } -+ cpuinfo_deinitialize(); -+} -+ -+TEST(UARCH, non_zero_cores) { -+ ASSERT_TRUE(cpuinfo_initialize()); -+ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { -+ const cpuinfo_uarch_info* uarch = cpuinfo_get_uarch(i); -+ ASSERT_TRUE(uarch); -+ -+ EXPECT_NE(0, uarch->core_count); -+ } -+ cpuinfo_deinitialize(); -+} -+ -+TEST(UARCH, valid_cores) { -+ ASSERT_TRUE(cpuinfo_initialize()); -+ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { -+ const cpuinfo_uarch_info* uarch = cpuinfo_get_uarch(i); -+ ASSERT_TRUE(uarch); -+ -+ EXPECT_LE(uarch->core_count, cpuinfo_get_cores_count()); -+ } -+ cpuinfo_deinitialize(); -+} -+ - TEST(L1I_CACHES_COUNT, within_bounds) { - ASSERT_TRUE(cpuinfo_initialize()); - EXPECT_NE(0, cpuinfo_get_l1i_caches_count()); -diff --git a/test/mock/galaxy-s9-global.cc b/test/mock/galaxy-s9-global.cc -index 7a67129..6c72513 100644 ---- a/test/mock/galaxy-s9-global.cc -+++ b/test/mock/galaxy-s9-global.cc -@@ -207,7 +207,7 @@ TEST(CORES, uarch) { - case 5: - case 6: - case 7: -- ASSERT_EQ(cpuinfo_uarch_cortex_a55, cpuinfo_get_core(i)->uarch); -+ ASSERT_EQ(cpuinfo_uarch_cortex_a55r0, cpuinfo_get_core(i)->uarch); - break; - } - } -@@ -329,7 +329,7 @@ TEST(CLUSTERS, uarch) { - ASSERT_EQ(cpuinfo_uarch_exynos_m3, cpuinfo_get_cluster(i)->uarch); - break; - case 1: -- ASSERT_EQ(cpuinfo_uarch_cortex_a55, cpuinfo_get_cluster(i)->uarch); -+ ASSERT_EQ(cpuinfo_uarch_cortex_a55r0, cpuinfo_get_cluster(i)->uarch); - break; - } - } -diff --git a/test/mock/galaxy-s9-us.cc b/test/mock/galaxy-s9-us.cc -index 6df7f3c..ceea969 100644 ---- a/test/mock/galaxy-s9-us.cc -+++ b/test/mock/galaxy-s9-us.cc -@@ -168,7 +168,7 @@ TEST(CORES, uarch) { - case 5: - case 6: - case 7: -- ASSERT_EQ(cpuinfo_uarch_cortex_a55, cpuinfo_get_core(i)->uarch); -+ ASSERT_EQ(cpuinfo_uarch_cortex_a55r0, cpuinfo_get_core(i)->uarch); - break; - } - } -@@ -283,7 +283,7 @@ TEST(CLUSTERS, uarch) { - ASSERT_EQ(cpuinfo_uarch_cortex_a75, cpuinfo_get_cluster(i)->uarch); - break; - case 1: -- ASSERT_EQ(cpuinfo_uarch_cortex_a55, cpuinfo_get_cluster(i)->uarch); -+ ASSERT_EQ(cpuinfo_uarch_cortex_a55r0, cpuinfo_get_cluster(i)->uarch); - break; - } - } -@@ -817,4 +817,4 @@ int main(int argc, char* argv[]) { - cpuinfo_initialize(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); --} -\ No newline at end of file -+} -diff --git a/tools/cpu-info.c b/tools/cpu-info.c -index 7fa5187..7963c00 100644 ---- a/tools/cpu-info.c -+++ b/tools/cpu-info.c -@@ -14,6 +14,8 @@ static const char* vendor_to_string(enum cpuinfo_vendor vendor) { - return "Intel"; - case cpuinfo_vendor_amd: - return "AMD"; -+ case cpuinfo_vendor_hygon: -+ return "Hygon"; - case cpuinfo_vendor_arm: - return "ARM"; - case cpuinfo_vendor_qualcomm: -@@ -161,6 +163,8 @@ static const char* uarch_to_string(enum cpuinfo_uarch uarch) { - return 
"Cortex-A35"; - case cpuinfo_uarch_cortex_a53: - return "Cortex-A53"; -+ case cpuinfo_uarch_cortex_a55r0: -+ return "Cortex-A55r0"; - case cpuinfo_uarch_cortex_a55: - return "Cortex-A55"; - case cpuinfo_uarch_cortex_a57: -@@ -223,6 +227,10 @@ static const char* uarch_to_string(enum cpuinfo_uarch uarch) { - return "Vortex"; - case cpuinfo_uarch_tempest: - return "Tempest"; -+ case cpuinfo_uarch_lightning: -+ return "Lightning"; -+ case cpuinfo_uarch_thunder: -+ return "Thunder"; - case cpuinfo_uarch_thunderx: - return "ThunderX"; - case cpuinfo_uarch_thunderx2: -@@ -253,6 +261,17 @@ int main(int argc, char** argv) { - printf("\t%"PRIu32": %s\n", i, cpuinfo_get_package(i)->name); - } - #endif -+ printf("Microarchitectures:\n"); -+ for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) { -+ const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(i); -+ const char* uarch_string = uarch_to_string(uarch_info->uarch); -+ if (uarch_string == NULL) { -+ printf("\t%"PRIu32"x Unknown (0x%08"PRIx32"\n", -+ uarch_info->core_count, (uint32_t) uarch_info->uarch); -+ } else { -+ printf("\t%"PRIu32"x %s\n", uarch_info->core_count, uarch_string); -+ } -+ } - printf("Cores:\n"); - for (uint32_t i = 0; i < cpuinfo_get_cores_count(); i++) { - const struct cpuinfo_core* core = cpuinfo_get_core(i); -@@ -277,17 +296,17 @@ int main(int argc, char** argv) { - } - } - printf("Logical processors"); -- #if defined(__linux__) -- printf(" (System ID)"); -- #endif -- printf(":\n"); -+ #if defined(__linux__) -+ printf(" (System ID)"); -+ #endif -+ printf(":\n"); - for (uint32_t i = 0; i < cpuinfo_get_processors_count(); i++) { - const struct cpuinfo_processor* processor = cpuinfo_get_processor(i); -- printf("\t%"PRIu32"", i); -+ printf("\t%"PRIu32"", i); - -- #if defined(__linux__) -- printf(" (%"PRId32")", processor->linux_id); -- #endif -+ #if defined(__linux__) -+ printf(" (%"PRId32")", processor->linux_id); -+ #endif - - #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 - printf(": APIC ID 0x%08"PRIx32"\n", processor->apic_id); diff --git a/third_party/cpuinfo/workspace.bzl b/third_party/cpuinfo/workspace.bzl index 77aecf5a9a9..922ab022486 100644 --- a/third_party/cpuinfo/workspace.bzl +++ b/third_party/cpuinfo/workspace.bzl @@ -2,20 +2,14 @@ load("//third_party:repo.bzl", "third_party_http_archive") -# Sanitize a dependency so that it works correctly from code that includes -# TensorFlow as a submodule. -def clean_dep(dep): - return str(Label(dep)) - def repo(): third_party_http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-d6c0f915ee737f961915c9d17f1679b6777af207", - sha256 = "146fc61c3cf63d7d88db963876929a4d373f621fb65568b895efa0857f467770", + strip_prefix = "cpuinfo-0cc563acb9baac39f2c1349bc42098c4a1da59e3", + sha256 = "80625d0b69a3d69b70c2236f30db2c542d0922ccf9bb51a61bc39c49fac91a35", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pytorch/cpuinfo/archive/d6c0f915ee737f961915c9d17f1679b6777af207.tar.gz", - "https://github.com/pytorch/cpuinfo/archive/d6c0f915ee737f961915c9d17f1679b6777af207.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pytorch/cpuinfo/archive/0cc563acb9baac39f2c1349bc42098c4a1da59e3.tar.gz", + "https://github.com/pytorch/cpuinfo/archive/0cc563acb9baac39f2c1349bc42098c4a1da59e3.tar.gz", ], build_file = "//third_party/cpuinfo:BUILD.bazel", - patch_file = clean_dep("//third_party/cpuinfo:cpuinfo.patch"), ) From 9478afb61cd1756387be7e0748c62cd68578e2aa Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 23 Mar 2020 10:55:04 -0700 Subject: [PATCH 423/492] Remove usages of cusparse gtsv* PiperOrigin-RevId: 302470219 Change-Id: Idaa6bfaefa7f29f92525109f5170315b2d312901 --- tensorflow/core/kernels/cuda_sparse.cc | 60 ------------------- tensorflow/core/kernels/cuda_sparse.h | 31 ---------- .../kernels/tridiagonal_solve_op_gpu.cu.cc | 15 +---- 3 files changed, 1 insertion(+), 105 deletions(-) diff --git a/tensorflow/core/kernels/cuda_sparse.cc b/tensorflow/core/kernels/cuda_sparse.cc index 7485bef45a2..9d4ddc13d0d 100644 --- a/tensorflow/core/kernels/cuda_sparse.cc +++ b/tensorflow/core/kernels/cuda_sparse.cc @@ -200,66 +200,6 @@ Status GpuSparse::Initialize() { // Check the actual declarations in the cusparse.h header file. //============================================================================= -template -static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle, - int m, int n, const Scalar* dl, const Scalar* d, - const Scalar* du, Scalar* B, int ldb) { - TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl), - AsCudaComplex(d), AsCudaComplex(du), - AsCudaComplex(B), ldb)); - return Status::OK(); -} - -#define GTSV_INSTANCE(Scalar, sparse_prefix) \ - template <> \ - Status GpuSparse::Gtsv(int m, int n, const Scalar* dl, \ - const Scalar* d, const Scalar* du, Scalar* B, \ - int ldb) const { \ - DCHECK(initialized_); \ - return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *gpusparse_handle_, m, n, \ - dl, d, du, B, ldb); \ - } - -TF_CALL_LAPACK_TYPES(GTSV_INSTANCE); - -#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \ - template <> \ - Status GpuSparse::GtsvNoPivot(int m, int n, const Scalar* dl, \ - const Scalar* d, const Scalar* du, \ - Scalar* B, int ldb) const { \ - DCHECK(initialized_); \ - return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), \ - *gpusparse_handle_, m, n, dl, d, du, B, ldb); \ - } - -TF_CALL_LAPACK_TYPES(GTSV_NO_PIVOT_INSTANCE); - -template -static inline Status GtsvStridedBatchImpl(SparseFn op, - cusparseHandle_t cusparse_handle, - int m, const Scalar* dl, - const Scalar* d, const Scalar* du, - Scalar* x, int batchCount, - int batchStride) { - TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl), - AsCudaComplex(d), AsCudaComplex(du), - AsCudaComplex(x), batchCount, batchStride)); - return Status::OK(); -} - -#define GTSV_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \ - template <> \ - Status GpuSparse::GtsvStridedBatch( \ - int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \ - int batchCount, int batchStride) const { \ - DCHECK(initialized_); \ - return GtsvStridedBatchImpl(SPARSE_FN(gtsvStridedBatch, sparse_prefix), \ - *gpusparse_handle_, m, dl, d, du, x, \ - batchCount, batchStride); \ - } - -TF_CALL_LAPACK_TYPES(GTSV_STRIDED_BATCH_INSTANCE); - template static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle, int m, int n, const Scalar* dl, const Scalar* d, diff --git a/tensorflow/core/kernels/cuda_sparse.h b/tensorflow/core/kernels/cuda_sparse.h index f96a9b2187c..5dd62037ff0 100644 --- a/tensorflow/core/kernels/cuda_sparse.h +++ b/tensorflow/core/kernels/cuda_sparse.h @@ -190,37 +190,6 @@ class GpuSparse { // Wrappers for cuSparse start here. // - // Solves tridiagonal system of equations. - // Note: Cuda Toolkit 9.0+ has better-performing gtsv2 routine. gtsv will be - // removed in Cuda Toolkit 11.0. 
- // See: https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv - // Returns Status::OK() if the kernel was launched successfully. - template - Status Gtsv(int m, int n, const Scalar *dl, const Scalar *d, const Scalar *du, - Scalar *B, int ldb) const; - - // Solves tridiagonal system of equations without pivoting. - // Note: Cuda Toolkit 9.0+ has better-performing gtsv2_nopivot routine. - // gtsv_nopivot will be removed in Cuda Toolkit 11.0. - // See: - // https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv_nopivot - // Returns Status::OK() if the kernel was launched successfully. - template - Status GtsvNoPivot(int m, int n, const Scalar *dl, const Scalar *d, - const Scalar *du, Scalar *B, int ldb) const; - - // Solves a batch of tridiagonal systems of equations. Doesn't support - // multiple right-hand sides per each system. Doesn't do pivoting. - // Note: Cuda Toolkit 9.0+ has better-performing gtsv2StridedBatch routine. - // gtsvStridedBatch will be removed in Cuda Toolkit 11.0. - // See: - // https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsvstridedbatch - // Returns Status::OK() if the kernel was launched successfully. - template - Status GtsvStridedBatch(int m, const Scalar *dl, const Scalar *d, - const Scalar *du, Scalar *x, int batchCount, - int batchStride) const; - // Solves tridiagonal system of equations. // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2 template diff --git a/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc b/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc index 3825e29189a..089fa8c040f 100644 --- a/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc +++ b/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc @@ -200,13 +200,6 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp { const Scalar* superdiag, const Scalar* diag, const Scalar* subdiag, Scalar* rhs, const int num_eqs, const int num_rhs) const { -#if CUDA_VERSION < 9000 - auto function = - pivoting_ ? &GpuSparse::Gtsv : &GpuSparse::GtsvNoPivot; - OP_REQUIRES_OK( - context, (cusparse_solver.get()->*function)( - num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs)); -#else auto buffer_function = pivoting_ ? &GpuSparse::Gtsv2BufferSizeExt : &GpuSparse::Gtsv2NoPivotBufferSizeExt; @@ -225,7 +218,6 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp { OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)( num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs, buffer)); -#endif // CUDA_VERSION < 9000 } void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals, @@ -318,11 +310,7 @@ class TridiagonalSolveOpGpu : public OpKernel { std::unique_ptr cusparse_solver(new GpuSparse(context)); OP_REQUIRES_OK(context, cusparse_solver->Initialize()); -#if CUDA_VERSION < 9000 - OP_REQUIRES_OK(context, cusparse_solver->GtsvStridedBatch( - matrix_size, subdiag, diag, superdiag, x, - batch_size, matrix_size)); -#else + size_t buffer_size; OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt( matrix_size, subdiag, diag, superdiag, x, @@ -335,7 +323,6 @@ class TridiagonalSolveOpGpu : public OpKernel { OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch( matrix_size, subdiag, diag, superdiag, x, batch_size, matrix_size, buffer)); -#endif // CUDA_VERSION < 9000 } void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs, From f1913bcaaaee0c667ac9cf97f54a71e148ebcab5 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 23 Mar 2020 11:00:23 -0700 Subject: [PATCH 424/492] add python tracer profiler interface to OSS. PiperOrigin-RevId: 302471475 Change-Id: I9bacb612f9e4d8d490663bff9de5087bc73e5280 --- tensorflow/core/profiler/internal/cpu/BUILD | 21 +++ .../profiler/internal/cpu/python_tracer.cc | 127 ++++++++++++++++++ .../profiler/internal/profiler_interface.h | 3 + tensorflow/python/BUILD | 1 + 4 files changed, 152 insertions(+) create mode 100644 tensorflow/core/profiler/internal/cpu/python_tracer.cc diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD index fe028d85cf7..d81e5d82be5 100644 --- a/tensorflow/core/profiler/internal/cpu/BUILD +++ b/tensorflow/core/profiler/internal/cpu/BUILD @@ -57,3 +57,24 @@ tf_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "python_tracer", + srcs = ["python_tracer.cc"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/internal:profiler_factory", + "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:xplane_schema", + "//tensorflow/core/profiler/utils:xplane_utils", + "//tensorflow/python/profiler/internal:python_hooks", + "@com_google_absl//absl/strings", + ], + alwayslink = True, +) diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc new file mode 100644 index 00000000000..e6a910ccc69 --- /dev/null +++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "absl/strings/str_split.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env_time.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/profiler/internal/profiler_factory.h" +#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/python/profiler/internal/python_hooks.h" + +namespace tensorflow { +namespace profiler { +namespace { + +// This profiler interface enable Python function call tracing, and forward +// the events to TraceMeRecorder. +class PythonTracer : public ProfilerInterface { + public: + explicit PythonTracer() = default; + ~PythonTracer() override; + + // Starts recording TraceMes. + Status Start() override; + + // Stops recording TraceMes. + Status Stop() override; + + // Populates user traces and thread names in response. 
+ // The user traces and thread names are in no particular order. + Status CollectData(RunMetadata* run_metadata) override; + + Status CollectData(XSpace* space) override; + + DeviceType GetDeviceType() override { return DeviceType::kCpu; } + + private: + bool recording_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(PythonTracer); +}; + +PythonTracer::~PythonTracer() { + Stop().IgnoreError(); + PythonHooks::GetSingleton()->Finalize(); +} + +Status PythonTracer::Start() { + if (recording_) { + return errors::Internal("TraceMeRecorder already started"); + } + VLOG(1) << __FUNCTION__; + recording_ = true; + PythonHooks::GetSingleton()->Start(); + return Status::OK(); +} + +Status PythonTracer::Stop() { + if (!recording_) { + return errors::Internal("TraceMeRecorder not started"); + } + VLOG(1) << __FUNCTION__; + PythonHooks::GetSingleton()->Stop(); + recording_ = false; + return Status::OK(); +} + +Status PythonTracer::CollectData(RunMetadata* run_metadata) { + // This ProfilerInterface rely on HostTracer to serialize its trace. + // Make sure unpaired traceme don't get recorded, because it will end up + // in the wrong threads. + // We had assumed HostTracer::Stop is called when ProfilerSession try to + // serialize PythonTracer. + PythonHooks::GetSingleton()->Finalize(); + return Status::OK(); +} + +Status PythonTracer::CollectData(XSpace* space) { + // This ProfilerInterface rely on HostTracer to serialize its trace. + // Make sure unpaired traceme don't get recorded, because it will end up + // in the wrong threads. + // We had assumed HostTracer::Stop is called when ProfilerSession try to + // serialize PythonTracer. + PythonHooks::GetSingleton()->Finalize(); + return Status::OK(); +} + +} // namespace + +// Not in anonymous namespace for testing purposes. +std::unique_ptr CreatePythonTracer( + const profiler::ProfilerOptions& options) { + if (!options.enable_python_tracer) return nullptr; + // This ProfilerInterface rely on TraceMeRecorder to be active. + if (options.host_tracer_level == 0) return nullptr; + return absl::make_unique(); +} + +auto register_python_tracer_factory = [] { + bool enable; + TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_OSS_PYTHON_TRACER", true, &enable)); + if (enable) { + RegisterProfilerFactory(&CreatePythonTracer); + } + return 0; +}(); + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/internal/profiler_interface.h index 081054f03fd..9e7819de03a 100644 --- a/tensorflow/core/profiler/internal/profiler_interface.h +++ b/tensorflow/core/profiler/internal/profiler_interface.h @@ -38,6 +38,9 @@ struct ProfilerOptions { // Inexpensive ops are not traced by default. int host_tracer_level = 2; + + // Whether to enable python function calls tracer. + bool enable_python_tracer = false; }; // Interface for tensorflow profiler plugins. 
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 46696993f99..c784cb47611 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5858,6 +5858,7 @@ tf_py_wrap_cc( "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/platform:stacktrace_handler", "//tensorflow/core/profiler/internal:print_model_analysis", + "//tensorflow/core/profiler/internal/cpu:python_tracer", "//tensorflow/tools/graph_transforms:transform_graph_lib", "//tensorflow/lite/toco/python:toco_python_api", "//tensorflow/python/eager:pywrap_tfe_lib", From 71ec683225c589880a90483b217b3ea0a33f408f Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 23 Mar 2020 11:29:17 -0700 Subject: [PATCH 425/492] Use JIT device for option in xla-legalize-tf-with-tf2xla pass XlaCompiler::Options uses the JIT device so using JIT device for the pass option avoids intermediate conversion. PiperOrigin-RevId: 302478250 Change-Id: I933616826e962f0b06800ab41cf58a4966d95593 --- .../mlir/xla/tests/legalize-tf-with-tf2xla.mlir | 2 +- .../xla/transforms/legalize_tf_with_tf2xla.cc | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 53df0d0a0fc..e271340f247 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU_JIT %s | FileCheck %s --dump-input-on-failure // INVALID_DEVICE: tf-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 327040a087f..2e9f1d61dd4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -81,10 +81,10 @@ static bool IsOpWhitelisted(Operation* op) { isa(op) || isa(op); } -static llvm::Optional GetJitDevice( +static llvm::Optional GetExecutionDevice( const std::string& device_type, const Location& loc) { - if (device_type == "XLA_CPU") return absl::string_view("XLA_CPU_JIT"); - if (device_type == "TPU") return absl::string_view("XLA_TPU_JIT"); + if (device_type == "XLA_CPU_JIT") return std::string("XLA_CPU"); + if (device_type == "XLA_TPU_JIT") return std::string("TPU"); // TODO(hinsu): Support GPU device along with a test for it. 
emitError(loc) << "unsupported device for legalization with tf2xla kernels: " @@ -94,10 +94,10 @@ static llvm::Optional GetJitDevice( static std::unique_ptr CreateDeviceMgr( const std::string& device_type, const Location& loc) { - auto jit_device_or = GetJitDevice(device_type, loc); - if (!jit_device_or) return nullptr; + auto device_or = GetExecutionDevice(device_type, loc); + if (!device_or) return nullptr; - auto* factory = tensorflow::DeviceFactory::GetFactory(device_type); + auto* factory = tensorflow::DeviceFactory::GetFactory(*device_or); if (!factory) { emitError(loc) << "failed to create DeviceFactory for device: " << device_type; @@ -113,7 +113,7 @@ static std::unique_ptr CreateDeviceMgr( } auto device = absl::make_unique( - tensorflow::SessionOptions(), tensorflow::DeviceType(*jit_device_or)); + tensorflow::SessionOptions(), tensorflow::DeviceType(device_type)); return absl::make_unique(std::move(device)); } @@ -376,7 +376,7 @@ class LegalizeTF : public FunctionPass { Option device_type_{ *this, "device-type", llvm::cl::desc("XLA device type for execution of TensorFlow ops. " - "Supports XLA_CPU and TPU for now.")}; + "Supports XLA_CPU_JIT and XLA_TPU_JIT for now.")}; }; static PassRegistration pass( From 0bea7faef9c354a0aedaba810accc9b6a9e841ed Mon Sep 17 00:00:00 2001 From: Jiho Choi Date: Mon, 23 Mar 2020 11:29:57 -0700 Subject: [PATCH 426/492] Rename CreateTfMetricsDbFromHloMetricsDb to CreateTfMetricsDbFromDeviceOpMetricsDb. PiperOrigin-RevId: 302478422 Change-Id: I3b60a7d521295c62acbefd9bf7c9ffe97067ce0f --- .../convert/op_stats_to_overview_page.cc | 2 +- .../profiler/convert/op_stats_to_tf_stats.cc | 2 +- .../profiler/utils/op_metrics_db_utils.cc | 53 ++++++++++--------- .../core/profiler/utils/op_metrics_db_utils.h | 6 +-- 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc index 3021b482d55..cb93d440249 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc @@ -149,7 +149,7 @@ OverviewPageRecommendation ComputeGenericRecommendation( OverviewPageAnalysis ComputeAnalysisResult(const OpStats& op_stats) { OverviewPageAnalysis analysis; - OpMetricsDb metrics_db = CreateTfMetricsDbFromHloMetricsDb( + OpMetricsDb metrics_db = CreateTfMetricsDbFromDeviceOpMetricsDb( op_stats.device_op_metrics_db(), /*with_idle=*/false); uint64 total_device_time_ps = metrics_db.total_time_ps(); constexpr int kNumTopOpsShown = 10; diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc index 08e73b2fea4..4ce14f54d47 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc @@ -88,7 +88,7 @@ TfStatsTable GenerateTfStatsTable(const OpMetricsDb& host_tf_metrics_db, TfStatsDatabase ConvertOpStatsToTfStats(const OpStats& op_stats) { const OpMetricsDb& host_tf_metrics_db = op_stats.host_op_metrics_db(); OpMetricsDb device_tf_metrics_db = - CreateTfMetricsDbFromHloMetricsDb(op_stats.device_op_metrics_db()); + CreateTfMetricsDbFromDeviceOpMetricsDb(op_stats.device_op_metrics_db()); double ridge_point = op_stats.perf_env().ridge_point(); TfStatsDatabase tf_stats_db; *tf_stats_db.mutable_with_idle() = diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc index 
9a8973a2424..dee33f1d1ce 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc @@ -30,9 +30,9 @@ class DeviceTfOpMetricsDbBuilder : public OpMetricsDbBuilder { explicit DeviceTfOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {} - void UpdateTfOpMetricsWithHloOpMetrics(absl::string_view tf_op_name, - absl::string_view tf_op_type, - const OpMetrics& hlo_op_metrics) { + void UpdateTfOpMetricsWithDeviceOpMetrics( + absl::string_view tf_op_name, absl::string_view tf_op_type, + const OpMetrics& device_op_metrics) { OpMetrics* tf_op_metrics = OpMetricsDbBuilder::LookupOrInsertNewOpMetrics( /*hlo_module_id=*/0, tf_op_name); if (tf_op_metrics->category().empty()) { @@ -40,23 +40,23 @@ class DeviceTfOpMetricsDbBuilder : public OpMetricsDbBuilder { tf_op_type == kUnknownOp ? "Unknown" : string(tf_op_type)); } // The occurrences of a TF-op is the maximum among the occurrences of all - // HLO-ops that it contains. - tf_op_metrics->set_occurrences( - std::max(tf_op_metrics->occurrences(), hlo_op_metrics.occurrences())); + // device ops that it contains. + tf_op_metrics->set_occurrences(std::max(tf_op_metrics->occurrences(), + device_op_metrics.occurrences())); tf_op_metrics->set_time_ps(tf_op_metrics->time_ps() + - hlo_op_metrics.time_ps()); + device_op_metrics.time_ps()); tf_op_metrics->set_self_time_ps(tf_op_metrics->self_time_ps() + - hlo_op_metrics.self_time_ps()); - tf_op_metrics->set_flops(tf_op_metrics->flops() + hlo_op_metrics.flops()); + device_op_metrics.self_time_ps()); + tf_op_metrics->set_flops(tf_op_metrics->flops() + + device_op_metrics.flops()); tf_op_metrics->set_bytes_accessed(tf_op_metrics->bytes_accessed() + - hlo_op_metrics.bytes_accessed()); + device_op_metrics.bytes_accessed()); } }; } // namespace -OpMetricsDbBuilder::OpMetricsDbBuilder(OpMetricsDb* db) - : db_(db) { +OpMetricsDbBuilder::OpMetricsDbBuilder(OpMetricsDb* db) : db_(db) { DCHECK_NE(db_, nullptr); DCHECK_EQ(db_->metrics_db_size(), 0); } @@ -92,28 +92,29 @@ void AddIdleOp(OpMetricsDb* db) { metrics->set_self_time_ps(idle_time_ps); } -OpMetricsDb CreateTfMetricsDbFromHloMetricsDb(const OpMetricsDb& hlo_metrics_db, - bool with_idle) { +OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb( + const OpMetricsDb& device_op_metrics_db, bool with_idle) { OpMetricsDb tf_op_metrics_db; DeviceTfOpMetricsDbBuilder builder(&tf_op_metrics_db); - for (const auto& hlo_op_metrics : hlo_metrics_db.metrics_db()) { - if (!hlo_op_metrics.provenance().empty()) { - TfOp tf_op = ParseTfOpFullname(hlo_op_metrics.provenance()); - builder.UpdateTfOpMetricsWithHloOpMetrics(tf_op.name, tf_op.type, - hlo_op_metrics); + for (const auto& device_op_metrics : device_op_metrics_db.metrics_db()) { + if (!device_op_metrics.provenance().empty()) { + TfOp tf_op = ParseTfOpFullname(device_op_metrics.provenance()); + builder.UpdateTfOpMetricsWithDeviceOpMetrics(tf_op.name, tf_op.type, + device_op_metrics); } else { - DCHECK_EQ(hlo_op_metrics.name(), "IDLE"); + DCHECK_EQ(device_op_metrics.name(), "IDLE"); if (with_idle) { - builder.UpdateTfOpMetricsWithHloOpMetrics("IDLE", "IDLE", - hlo_op_metrics); + builder.UpdateTfOpMetricsWithDeviceOpMetrics("IDLE", "IDLE", + device_op_metrics); } } } - tf_op_metrics_db.set_total_op_time_ps(hlo_metrics_db.total_op_time_ps()); + tf_op_metrics_db.set_total_op_time_ps( + device_op_metrics_db.total_op_time_ps()); - tf_op_metrics_db.set_total_time_ps(with_idle - ? 
hlo_metrics_db.total_time_ps() - : hlo_metrics_db.total_op_time_ps()); + tf_op_metrics_db.set_total_time_ps( + with_idle ? device_op_metrics_db.total_time_ps() + : device_op_metrics_db.total_op_time_ps()); return tf_op_metrics_db; } diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/tensorflow/core/profiler/utils/op_metrics_db_utils.h index 8cd4737359c..a1f1a045cdd 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.h +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.h @@ -67,9 +67,9 @@ uint64 IdleTimePs(const OpMetricsDb& metrics_db); // must have been set. void AddIdleOp(OpMetricsDb* db); -// Converts from Hlo-op metrics to Tf-op metrics. -OpMetricsDb CreateTfMetricsDbFromHloMetricsDb(const OpMetricsDb& hlo_metrics_db, - bool with_idle = true); +// Converts from the device op metrics to Tf-op metrics. +OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb( + const OpMetricsDb& device_op_metrics_db, bool with_idle = true); } // namespace profiler } // namespace tensorflow From 46f815205fcccb024af606a11f9698abe3ff8648 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 11:42:38 -0700 Subject: [PATCH 427/492] Make python monitoring API visible publicly for end users to instrument their python code. PiperOrigin-RevId: 302481767 Change-Id: I778cfbda13cb83181c876c4da399c1721ff5b693 --- tensorflow/python/eager/BUILD | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 55bc5942253..315e85feb3d 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -178,10 +178,7 @@ py_library( name = "monitoring", srcs = ["monitoring.py"], srcs_version = "PY2AND3", - visibility = [ - "//tensorflow:internal", - "//third_party/py/tf_agents:__subpackages__", - ], + visibility = ["//visibility:public"], deps = [ "//tensorflow/python:c_api_util", "//tensorflow/python:pywrap_tf_session", From c6ec2565db03ff3a02e696ca889bf0cef104eed7 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Mon, 23 Mar 2020 11:49:54 -0700 Subject: [PATCH 428/492] Disable flaky rebatch_dataset_test PiperOrigin-RevId: 302483379 Change-Id: I1cd480ab6917f58899568d5ded8c07fb5625e22b --- tensorflow/python/data/experimental/kernel_tests/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index c9facc23ae5..a8e87326743 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -580,6 +580,10 @@ tf_py_test( name = "rebatch_dataset_test", size = "small", srcs = ["rebatch_dataset_test.py"], + tags = [ + "manual", # TODO(b/152215379) + "notap", + ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", From 828fe43cf3aaf1a60ed0b263d4af4b3442e79121 Mon Sep 17 00:00:00 2001 From: "T.J. 
Alumbaugh" Date: Mon, 23 Mar 2020 12:19:59 -0700 Subject: [PATCH 429/492] Add BatchMatMul built-in op for TF Lite PiperOrigin-RevId: 302489633 Change-Id: Ie4ad2abad069b1e5bc654fc51caf0bcbc99b714f --- tensorflow/lite/builtin_ops.h | 1 + .../lite/core/api/flatbuffer_conversions.cc | 1 + tensorflow/lite/kernels/BUILD | 14 ++ tensorflow/lite/kernels/batch_matmul.cc | 156 ++++++++++++++++ tensorflow/lite/kernels/batch_matmul_test.cc | 169 ++++++++++++++++++ tensorflow/lite/kernels/builtin_op_kernels.h | 1 + tensorflow/lite/kernels/internal/BUILD | 2 + .../kernels/internal/optimized/batch_matmul.h | 118 ++++++++++++ .../kernels/internal/reference/batch_matmul.h | 105 +++++++++++ tensorflow/lite/kernels/register.cc | 1 + tensorflow/lite/kernels/register_ref.cc | 1 + tensorflow/lite/schema/schema.fbs | 9 +- tensorflow/lite/schema/schema_generated.h | 132 ++++++++++++-- tensorflow/lite/toco/tflite/op_version.cc | 1 + 14 files changed, 699 insertions(+), 12 deletions(-) create mode 100644 tensorflow/lite/kernels/batch_matmul.cc create mode 100644 tensorflow/lite/kernels/batch_matmul_test.cc create mode 100644 tensorflow/lite/kernels/internal/optimized/batch_matmul.h create mode 100644 tensorflow/lite/kernels/internal/reference/batch_matmul.h diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index c4e2907ffa9..85140289ac1 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -152,6 +152,7 @@ typedef enum { kTfLiteBuiltinSelectV2 = 123, kTfLiteBuiltinDensify = 124, kTfLiteBuiltinSegmentSum = 125, + kTfLiteBuiltinBatchMatmul = 126, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 6621e608d35..83b4159cce0 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -840,6 +840,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SCATTER_ND: case BuiltinOperator_DENSIFY: case BuiltinOperator_SEGMENT_SUM: + case BuiltinOperator_BATCH_MATMUL: break; } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 1f04cc3ee47..872d3c0822b 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -426,6 +426,7 @@ cc_library( "arg_min_max.cc", "audio_spectrogram.cc", "basic_rnn.cc", + "batch_matmul.cc", "batch_to_space_nd.cc", "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", @@ -849,6 +850,19 @@ cc_test( ], ) +cc_test( + name = "batch_matmul_test", + size = "small", + srcs = ["batch_matmul_test.cc"], + deps = [ + ":builtin_ops", + ":test_main", + ":test_util", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + ], +) + cc_test( name = "cast_test", size = "small", diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc new file mode 100644 index 00000000000..30bc624a218 --- /dev/null +++ b/tensorflow/lite/kernels/batch_matmul.cc @@ -0,0 +1,156 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace batch_matmul { + +static const int kInputLHSTensor = 0; +static const int kInputRHSTensor = 1; +static const int kOutputTensor = 0; + +// This file has two implementations of Transpose. +enum KernelType { + kReference, + kGenericOptimized, +}; + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const RuntimeShape& extended_lhs_shape, + const RuntimeShape& extended_rhs_shape, + int output_rank, TfLiteTensor* output) { + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); + // Fill in any broadcast dimensions. + for (int i = 0; i < output_rank - 2; ++i) { + const int lhs_dim = extended_lhs_shape.Dims(i); + const int rhs_dim = extended_rhs_shape.Dims(i); + int broadcast_dim = lhs_dim; + if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) { + broadcast_dim = rhs_dim; + } + output_shape->data[i] = broadcast_dim; + } + // Fill in the matmul dimensions. + output_shape->data[output_rank - 2] = + extended_lhs_shape.Dims(output_rank - 2); + output_shape->data[output_rank - 1] = + extended_rhs_shape.Dims(output_rank - 1); + TfLiteStatus stat = context->ResizeTensor(context, output, output_shape); + return stat; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* lhs_data = GetInput(context, node, kInputLHSTensor); + const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, lhs_data->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, rhs_data->type, kTfLiteFloat32); + // Support dimensions between 2 and 5, inclusive. + TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2); + TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5); + TF_LITE_ENSURE(context, NumDimensions(rhs_data) >= 2); + TF_LITE_ENSURE(context, NumDimensions(rhs_data) <= 5); + + const int lhs_rank = NumDimensions(lhs_data); + const int rhs_rank = NumDimensions(rhs_data); + const int output_rank = std::max(lhs_rank, rhs_rank); + const RuntimeShape extended_lhs_shape = + RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data)); + const RuntimeShape extended_rhs_shape = + RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data)); + + // Ensure any batch dimensions obey broacasting rules. 
+ for (int i = 0; i < output_rank - 2; ++i) { + const int lhs_dim = extended_lhs_shape.Dims(i); + const int rhs_dim = extended_rhs_shape.Dims(i); + if (lhs_dim != rhs_dim) { + if (lhs_dim != 1) { + TF_LITE_ENSURE_EQ(context, rhs_dim, 1); + } + } + } + // Ensure other dimensions work for matrix multiplication. + TF_LITE_ENSURE_EQ(context, extended_lhs_shape.Dims(output_rank - 1), + extended_rhs_shape.Dims(output_rank - 2)); + return ResizeOutputTensor(context, extended_lhs_shape, extended_rhs_shape, + output_rank, output); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor); + const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + switch (lhs->type) { + case kTfLiteFloat32: + if (kernel_type == kGenericOptimized) { + optimized_ops::BatchMatMul( + GetTensorShape(lhs), GetTensorData(lhs), GetTensorShape(rhs), + GetTensorData(rhs), GetTensorShape(output), + GetTensorData(output), + CpuBackendContext::GetFromContext(context)); + } else { + reference_ops::BatchMatMul( + GetTensorShape(lhs), GetTensorData(lhs), GetTensorShape(rhs), + GetTensorData(rhs), GetTensorShape(output), + GetTensorData(output)); + } + break; + default: + TF_LITE_KERNEL_LOG(context, + "Currently BatchMatMul doesn't support type: %s", + TfLiteTypeGetName(lhs->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace batch_matmul + +TfLiteRegistration* Register_BATCH_MATMUL_REF() { + static TfLiteRegistration r = {nullptr, nullptr, batch_matmul::Prepare, + batch_matmul::Eval}; + return &r; +} + +TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_matmul::Prepare, + batch_matmul::Eval}; + return &r; +} + +TfLiteRegistration* Register_BATCH_MATMUL() { + return Register_BATCH_MATMUL_GENERIC_OPTIMIZED(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/batch_matmul_test.cc b/tensorflow/lite/kernels/batch_matmul_test.cc new file mode 100644 index 00000000000..9b33ebef542 --- /dev/null +++ b/tensorflow/lite/kernels/batch_matmul_test.cc @@ -0,0 +1,169 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class BatchMatMulOpModel : public SingleOpModel { + public: + BatchMatMulOpModel(const TensorData& lhs, const TensorData& rhs) { + lhs_id_ = AddInput(lhs); + rhs_id_ = AddInput(rhs); + output_id_ = AddOutput(lhs.type); + SetBuiltinOp(BuiltinOperator_BATCH_MATMUL, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)}); + } + + int lhs() const { return lhs_id_; } + int rhs() const { return rhs_id_; } + std::vector GetOutput() { return ExtractVector(output_id_); } + std::vector GetOutputShape() { return GetTensorShape(output_id_); } + + protected: + int lhs_id_; + int rhs_id_; + int output_id_; +}; + +TEST(BatchMatMulOpModelTest, Float32Test_Simple) { + BatchMatMulOpModel model({TensorType_FLOAT32, {1, 2, 3}}, + {TensorType_FLOAT32, {1, 3, 4}}); + model.PopulateTensor(model.lhs(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor(model.rhs(), + {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f, + 104.0f, 257.0f})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4})); +} + +TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) { + BatchMatMulOpModel model({TensorType_FLOAT32, {2, 2, 3}}, + {TensorType_FLOAT32, {2, 3, 4}}); + model.PopulateTensor(model.lhs(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + model.PopulateTensor(model.rhs(), + {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f, + 104.0f, 257.0f, 482.0f, 662.0f, 554.0f, 761.0f, + 626.0f, 860.0f, 698.0f, 959.0f})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4})); +} + +TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) { + BatchMatMulOpModel model({TensorType_FLOAT32, {2, 2, 3}}, + {TensorType_FLOAT32, {3, 4}}); + model.PopulateTensor(model.lhs(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + model.PopulateTensor(model.rhs(), + {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f, + 104.0f, 257.0f, 194.0f, 266.0f, 266.0f, 365.0f, + 338.0f, 464.0f, 410.0f, 563.0f})); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4})); +} + +TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) { + BatchMatMulOpModel model({TensorType_FLOAT32, {2, 1, 3, 2}}, + {TensorType_FLOAT32, {3, 2, 4}}); + model.PopulateTensor(model.lhs(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + model.PopulateTensor(model.rhs(), + {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}); + + model.Invoke(); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray( + {23.0f, 53.0f, 83.0f, 29.0f, 67.0f, 105.0f, 35.0f, 81.0f, + 127.0f, 41.0f, 95.0f, 149.0f, 47.0f, 109.0f, 171.0f, 53.0f, + 123.0f, 193.0f, 59.0f, 137.0f, 215.0f, 65.0f, 151.0f, 237.0f, + 71.0f, 165.0f, 259.0f, 77.0f, 179.0f, 281.0f, 83.0f, 193.0f, + 303.0f, 89.0f, 207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f, + 181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 
311.0f, + 233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f, + 449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f, + 485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f})); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4})); +} + +TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFiveD) { + BatchMatMulOpModel model({TensorType_FLOAT32, {1, 2, 1, 3, 2}}, + {TensorType_FLOAT32, {3, 2, 4}}); + model.PopulateTensor(model.lhs(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + model.PopulateTensor(model.rhs(), + {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}); + + model.Invoke(); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray( + {23.0f, 53.0f, 83.0f, 29.0f, 67.0f, 105.0f, 35.0f, 81.0f, + 127.0f, 41.0f, 95.0f, 149.0f, 47.0f, 109.0f, 171.0f, 53.0f, + 123.0f, 193.0f, 59.0f, 137.0f, 215.0f, 65.0f, 151.0f, 237.0f, + 71.0f, 165.0f, 259.0f, 77.0f, 179.0f, 281.0f, 83.0f, 193.0f, + 303.0f, 89.0f, 207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f, + 181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f, + 233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f, + 449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f, + 485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f})); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 3, 4})); +} + +TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) { + BatchMatMulOpModel model({TensorType_FLOAT32, {4, 5}}, + {TensorType_FLOAT32, {3, 1, 5, 2}}); + model.PopulateTensor( + model.lhs(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}); + model.PopulateTensor( + model.rhs(), + {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}); + + model.Invoke(); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({145.0f, 370.0f, 595.0f, 820.0f, 220.0f, 570.0f, + 920.0f, 1270.0f, 295.0f, 770.0f, 1245.0f, 1720.0f, + 370.0f, 970.0f, 1570.0f, 2170.0f, 445.0f, 1170.0f, + 1895.0f, 2620.0f, 520.0f, 1370.0f, 2220.0f, 3070.0f})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1, 4, 2})); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h index e5f00ddd229..1c73f06487b 100644 --- a/tensorflow/lite/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/kernels/builtin_op_kernels.h @@ -36,6 +36,7 @@ TfLiteRegistration* Register_ARG_MAX(); TfLiteRegistration* Register_ARG_MIN(); TfLiteRegistration* Register_AVERAGE_POOL_2D(); TfLiteRegistration* Register_BATCH_TO_SPACE_ND(); +TfLiteRegistration* Register_BATCH_MATMUL(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN(); TfLiteRegistration* Register_CAST(); diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index c9e6c082b53..e7612e39c71 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -213,6 +213,7 @@ cc_library( name = "optimized_base", srcs = [], hdrs = [ + "optimized/batch_matmul.h", "optimized/depthwiseconv_3x3_filter_common.h", "optimized/depthwiseconv_float.h", "optimized/depthwiseconv_multithread.h", @@ -416,6 +417,7 @@ cc_library( hdrs = [ "reference/add.h", "reference/arg_min_max.h", + "reference/batch_matmul.h", "reference/binary_function.h", "reference/ceil.h", "reference/comparisons.h", diff --git 
a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h new file mode 100644 index 00000000000..03cef848026 --- /dev/null +++ b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h @@ -0,0 +1,118 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_ + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/cpu_backend_gemm.h" +#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, + const RuntimeShape& rhs_shape, const float* rhs_data, + const RuntimeShape& output_shape, float* output_data, + CpuBackendContext* context) { + using ::tflite::cpu_backend_gemm::Gemm; + using ::tflite::cpu_backend_gemm::GemmParams; + using ::tflite::cpu_backend_gemm::MatrixParams; + const RuntimeShape extended_lhs_shape = + RuntimeShape::ExtendedShape(5, lhs_shape); + const RuntimeShape extended_rhs_shape = + RuntimeShape::ExtendedShape(5, rhs_shape); + + // Determine which dimension is the broadcast dimension. + auto broadcast_dim = [](int lhs_dim, int rhs_dim) { + if (lhs_dim == rhs_dim) return lhs_dim; + if (lhs_dim == 1) return rhs_dim; + TFLITE_DCHECK_EQ(rhs_dim, 1); + return lhs_dim; + }; + + // Compute the "extent" for iterating on this dimension. + // If we are broadcasting, then don't advance (i.e return 0). + auto extent = [](const RuntimeShape& shape, int x) { + if (shape.Dims(x) == 1) { + return 0; + } + int prod = 1; + for (int i = x + 1; i < shape.DimensionsCount(); ++i) { + prod *= shape.Dims(i); + } + return prod; + }; + + const int batch_dim0 = + broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0)); + const int batch_dim1 = + broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1)); + const int batch_dim2 = + broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2)); + + const int lhs_ext0 = extent(extended_lhs_shape, 0); + const int lhs_ext1 = extent(extended_lhs_shape, 1); + const int lhs_ext2 = extent(extended_lhs_shape, 2); + const int rhs_ext0 = extent(extended_rhs_shape, 0); + const int rhs_ext1 = extent(extended_rhs_shape, 1); + const int rhs_ext2 = extent(extended_rhs_shape, 2); + + // Set params for each matrix multiply. 
+ const int lhs_rows = extended_lhs_shape.Dims(3); + const int rhs_cols = extended_rhs_shape.Dims(4); + const int accum_depth = extended_lhs_shape.Dims(4); + + MatrixParams lhs_params; + lhs_params.order = cpu_backend_gemm::Order::kRowMajor; + lhs_params.rows = lhs_rows; + lhs_params.cols = accum_depth; + + MatrixParams rhs_params; + rhs_params.order = cpu_backend_gemm::Order::kColMajor; + rhs_params.rows = accum_depth; + rhs_params.cols = rhs_cols; + + MatrixParams dst_params; + dst_params.order = cpu_backend_gemm::Order::kColMajor; + dst_params.rows = lhs_rows; + dst_params.cols = rhs_cols; + + for (int b0 = 0; b0 < batch_dim0; ++b0) { + const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); + const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); + for (int b1 = 0; b1 < batch_dim1; ++b1) { + const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; + const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; + for (int b2 = 0; b2 < batch_dim2; ++b2) { + const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; + const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; + float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + + b1 * batch_dim2 + b2) * + lhs_rows * rhs_cols; + GemmParams gemm_params; + cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2, + dst_params, out_ptr, gemm_params, context); + } + } + } +} + +} // namespace optimized_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_ diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h new file mode 100644 index 00000000000..4fe84aa3388 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_ + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, + const RuntimeShape& rhs_shape, const float* rhs_data, + const RuntimeShape& output_shape, float* output_data) { + const RuntimeShape extended_lhs_shape = + RuntimeShape::ExtendedShape(5, lhs_shape); + const RuntimeShape extended_rhs_shape = + RuntimeShape::ExtendedShape(5, rhs_shape); + + // Determine which dimension is the broadcast dimension. + auto broadcast_dim = [](int lhs_dim, int rhs_dim) { + if (lhs_dim == rhs_dim) return lhs_dim; + if (lhs_dim == 1) return rhs_dim; + TFLITE_DCHECK_EQ(rhs_dim, 1); + return lhs_dim; + }; + + // Compute the "extent" for iterating on this dimension. + // If we are broadcasting, then don't advance (i.e return 0). 
+ auto extent = [](const RuntimeShape& shape, int x) { + if (shape.Dims(x) == 1) { + return 0; + } + int prod = 1; + for (int i = x + 1; i < shape.DimensionsCount(); ++i) { + prod *= shape.Dims(i); + } + return prod; + }; + + const int batch_dim0 = + broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0)); + const int batch_dim1 = + broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1)); + const int batch_dim2 = + broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2)); + + const int lhs_ext0 = extent(extended_lhs_shape, 0); + const int lhs_ext1 = extent(extended_lhs_shape, 1); + const int lhs_ext2 = extent(extended_lhs_shape, 2); + const int rhs_ext0 = extent(extended_rhs_shape, 0); + const int rhs_ext1 = extent(extended_rhs_shape, 1); + const int rhs_ext2 = extent(extended_rhs_shape, 2); + + // Set params for each matrix multiply. + const int lhs_rows = extended_lhs_shape.Dims(3); + const int rhs_cols = extended_rhs_shape.Dims(4); + const int accum_depth = extended_lhs_shape.Dims(4); + + for (int b0 = 0; b0 < batch_dim0; ++b0) { + const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); + const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); + for (int b1 = 0; b1 < batch_dim1; ++b1) { + const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; + const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; + for (int b2 = 0; b2 < batch_dim2; ++b2) { + const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; + const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; + float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + + b1 * batch_dim2 + b2) * + lhs_rows * rhs_cols; + for (int j = 0; j < rhs_cols; ++j) { + for (int i = 0; i < lhs_rows; ++i) { + float total = 0.f; + for (int k = 0; k < accum_depth; ++k) { + total += + lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k]; + } + int idx = lhs_rows * j + i; + out_ptr[idx] = total; + } + } + } + } + } +} + +} // namespace reference_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_ diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index cf9f8b99ee4..1e148a0c1f5 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -278,6 +278,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND()); AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY()); AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM()); + AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. 
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 2381e8f8c9d..426f8a8e896 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -134,6 +134,7 @@ TfLiteRegistration* Register_HARD_SWISH_REF(); TfLiteRegistration* Register_DEPTH_TO_SPACE_REF(); TfLiteRegistration* Register_SELECT_V2(); TfLiteRegistration* Register_SEGMENT_SUM(); +TfLiteRegistration* Register_BATCH_MATMUL_REF(); namespace { diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 5c12d74c067..24cd73eef7a 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -346,7 +346,8 @@ enum BuiltinOperator : byte { SCATTER_ND = 122, SELECT_V2 = 123, DENSIFY = 124, - SEGMENT_SUM = 125 + SEGMENT_SUM = 125, + BATCH_MATMUL = 126 } @@ -451,7 +452,8 @@ union BuiltinOptions { ScatterNdOptions, SelectV2Options, DensifyOptions, - SegmentSumOptions + SegmentSumOptions, + BatchMatMulOptions } enum Padding : byte { SAME, VALID } @@ -945,6 +947,9 @@ table DensifyOptions { table SegmentSumOptions { } +table BatchMatMulOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 8caf2409b96..609eac198fb 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -346,6 +346,9 @@ struct DensifyOptionsT; struct SegmentSumOptions; struct SegmentSumOptionsT; +struct BatchMatMulOptions; +struct BatchMatMulOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -771,11 +774,12 @@ enum BuiltinOperator { BuiltinOperator_SELECT_V2 = 123, BuiltinOperator_DENSIFY = 124, BuiltinOperator_SEGMENT_SUM = 125, + BuiltinOperator_BATCH_MATMUL = 126, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_SEGMENT_SUM + BuiltinOperator_MAX = BuiltinOperator_BATCH_MATMUL }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -902,13 +906,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] { BuiltinOperator_SCATTER_ND, BuiltinOperator_SELECT_V2, BuiltinOperator_DENSIFY, - BuiltinOperator_SEGMENT_SUM + BuiltinOperator_SEGMENT_SUM, + BuiltinOperator_BATCH_MATMUL }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[127] = { + static const char * const names[128] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1035,13 +1040,14 @@ inline const char * const *EnumNamesBuiltinOperator() { "SELECT_V2", "DENSIFY", "SEGMENT_SUM", + "BATCH_MATMUL", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_SEGMENT_SUM)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BATCH_MATMUL)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -1148,11 +1154,12 @@ enum BuiltinOptions { BuiltinOptions_SelectV2Options = 98, BuiltinOptions_DensifyOptions = 99, BuiltinOptions_SegmentSumOptions = 100, + BuiltinOptions_BatchMatMulOptions = 101, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = 
BuiltinOptions_SegmentSumOptions + BuiltinOptions_MAX = BuiltinOptions_BatchMatMulOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1254,13 +1261,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] { BuiltinOptions_ScatterNdOptions, BuiltinOptions_SelectV2Options, BuiltinOptions_DensifyOptions, - BuiltinOptions_SegmentSumOptions + BuiltinOptions_SegmentSumOptions, + BuiltinOptions_BatchMatMulOptions }; return values; } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[102] = { + static const char * const names[103] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1362,13 +1370,14 @@ inline const char * const *EnumNamesBuiltinOptions() { "SelectV2Options", "DensifyOptions", "SegmentSumOptions", + "BatchMatMulOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_SegmentSumOptions)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BatchMatMulOptions)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; } @@ -1777,6 +1786,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2609,6 +2622,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_SegmentSumOptions ? reinterpret_cast(value) : nullptr; } + tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() { + return type == BuiltinOptions_BatchMatMulOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() const { + return type == BuiltinOptions_BatchMatMulOptions ? 
+ reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -9109,6 +9130,46 @@ inline flatbuffers::Offset CreateSegmentSumOptions( flatbuffers::Offset CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct BatchMatMulOptionsT : public flatbuffers::NativeTable { + typedef BatchMatMulOptions TableType; + BatchMatMulOptionsT() { + } +}; + +struct BatchMatMulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BatchMatMulOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + BatchMatMulOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BatchMatMulOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit BatchMatMulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BatchMatMulOptionsBuilder &operator=(const BatchMatMulOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateBatchMatMulOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + BatchMatMulOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; tflite::BuiltinOperator builtin_code; @@ -9545,6 +9606,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const { return builtin_options_type() == tflite::BuiltinOptions_SegmentSumOptions ? static_cast(builtin_options()) : nullptr; } + const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ? 
static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -9981,6 +10045,10 @@ template<> inline const tflite::SegmentSumOptions *Operator::builtin_options_as< return builtin_options_as_SegmentSumOptions(); } +template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as() const { + return builtin_options_as_BatchMatMulOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -13392,6 +13460,29 @@ inline flatbuffers::Offset CreateSegmentSumOptions(flatbuffer _fbb); } +inline BatchMatMulOptionsT *BatchMatMulOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BatchMatMulOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void BatchMatMulOptions::UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset BatchMatMulOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBatchMatMulOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BatchMatMulOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateBatchMatMulOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -14197,6 +14288,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_BatchMatMulOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } @@ -14615,6 +14710,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_BatchMatMulOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -15021,6 +15120,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_BatchMatMulOptions: { + auto ptr = reinterpret_cast(value); + return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -15427,6 +15530,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new tflite::SegmentSumOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_BatchMatMulOptions: { + value = new tflite::BatchMatMulOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -15934,6 +16041,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_BatchMatMulOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index b375706f6c7..bbec4f91646 100644 
--- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -57,6 +57,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kDiv, 1}, "1.6.0"}, {{OperatorType::kBatchToSpaceND, 1}, "1.6.0"}, {{OperatorType::kBatchToSpaceND, 2}, "1.14.0"}, + {{OperatorType::kBatchMatMul, 1}, kPendingReleaseOpVersion}, {{OperatorType::kCast, 1}, "1.5.0"}, {{OperatorType::kConcatenation, 1}, "1.5.0"}, {{OperatorType::kConcatenation, 2}, "1.14.0"}, From 54d0a198a7aeec70da6f5d3e0338940585e3810f Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 23 Mar 2020 12:42:40 -0700 Subject: [PATCH 430/492] Bump open source LLVM revision to a711a3a46039154c38eade8bef1138b77fdb05ee PiperOrigin-RevId: 302494638 Change-Id: I9ab5519582eb32e58e8482a1ead63036a1134b8e --- tensorflow/compiler/mlir/BUILD | 2 +- tensorflow/compiler/mlir/xla/BUILD | 2 +- .../xla/transforms/lhlo_legalize_to_affine.cc | 2 +- .../mlir_gpu/experimental/conv_emitter/BUILD | 2 +- .../experimental/conv_emitter/conv_emitter.cc | 2 +- tensorflow/workspace.bzl | 4 +- third_party/mlir/BUILD | 292 ++++++++++++++++-- third_party/mlir/test.BUILD | 22 +- 8 files changed, 290 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 7ad8a80695d..b776ee77493 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -89,7 +89,7 @@ cc_library( "//tensorflow/compiler/mlir/xla:xla_lower", "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts", "//tensorflow/compiler/mlir/xla:xla_test_passes", - "@llvm-project//mlir:AffineOps", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:QuantOps", ], ) diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 2a76a75da50..6597eeaa967 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -221,7 +221,7 @@ cc_library( "//tensorflow/compiler/xla:status", "@com_google_absl//absl/memory", "@llvm-project//llvm:support", - "@llvm-project//mlir:AffineOps", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index b01f573a9a5..15b91edbd8d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -16,7 +16,7 @@ limitations under the License. // This file implements logic for lowering LHLO dialect to Affine dialect. 
#include "absl/memory/memory.h" -#include "mlir/Dialect/AffineOps/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD index 2eac0018a4b..ab02cfae96b 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD @@ -32,7 +32,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/types:span", "@llvm-project//llvm:support", - "@llvm-project//mlir:AffineOps", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:TransformUtils", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc index 79e90b74208..5ec8d3bb334 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc @@ -30,7 +30,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/AffineOps/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index d89b96d3fc8..61527d8c976 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -589,8 +589,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "4a7f2032a350bc7eefd26709563f65216df3e2ce" - LLVM_SHA256 = "e43e9067427a331542733d5863b2e94369ed95b59af9999dcabdd5315ff46373" + LLVM_COMMIT = "a711a3a46039154c38eade8bef1138b77fdb05ee" + LLVM_SHA256 = "b070be6653ac61e42649afcda0a02dee027cd610c1e2929663ca67fdb1301679" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index fbbdb73cecc..819aed65efe 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -139,11 +139,15 @@ filegroup( ], ) +##---------------------------------------------------------------------------## +# Affine dialect. 
+##---------------------------------------------------------------------------## + filegroup( name = "AffineOpsTdFiles", srcs = [ - "include/mlir/Dialect/AffineOps/AffineOps.td", - "include/mlir/Dialect/AffineOps/AffineOpsBase.td", + "include/mlir/Dialect/Affine/IR/AffineOps.td", + "include/mlir/Dialect/Affine/IR/AffineOpsBase.td", "include/mlir/Interfaces/LoopLikeInterface.td", "include/mlir/Interfaces/SideEffects.td", ":OpBaseTdFiles", @@ -156,24 +160,112 @@ gentbl( tbl_outs = [ ( "-gen-op-decls", - "include/mlir/Dialect/AffineOps/AffineOps.h.inc", + "include/mlir/Dialect/Affine/IR/AffineOps.h.inc", ), ( "-gen-op-defs", - "include/mlir/Dialect/AffineOps/AffineOps.cpp.inc", + "include/mlir/Dialect/Affine/IR/AffineOps.cpp.inc", ), ( "-gen-dialect-decls", - "include/mlir/Dialect/AffineOps/AffineOpsDialect.h.inc", + "include/mlir/Dialect/Affine/IR/AffineOpsDialect.h.inc", ), ], tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/AffineOps/AffineOps.td", + td_file = "include/mlir/Dialect/Affine/IR/AffineOps.td", td_srcs = [ ":AffineOpsTdFiles", ], ) +##---------------------------------------------------------------------------## +# AVX512 dialect. +##---------------------------------------------------------------------------## + +filegroup( + name = "AVX512TdFiles", + srcs = [ + "include/mlir/Dialect/AVX512/AVX512.td", + "include/mlir/Dialect/LLVMIR/LLVMOpBase.td", + "include/mlir/IR/OpBase.td", + "include/mlir/Interfaces/SideEffects.td", + ], +) + +gentbl( + name = "AVX512IncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-dialect-decls -dialect=avx512", + "include/mlir/Dialect/AVX512/AVX512Dialect.h.inc", + ), + ( + "-gen-op-decls", + "include/mlir/Dialect/AVX512/AVX512.h.inc", + ), + ( + "-gen-op-defs", + "include/mlir/Dialect/AVX512/AVX512.cpp.inc", + ), + ( + "-gen-op-doc", + "g3doc/Dialects/AVX512/AVX512.md", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/AVX512/AVX512.td", + td_srcs = [ + ":AVX512TdFiles", + ], +) + +cc_library( + name = "AVX512", + srcs = [ + "lib/Dialect/AVX512/IR/AVX512Dialect.cpp", + ], + hdrs = [ + "include/mlir/Dialect/AVX512/AVX512Dialect.h", + ], + includes = ["include"], + deps = [ + ":AVX512IncGen", + ":IR", + ":SideEffects", + ":VectorOps", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + ], +) + +cc_library( + name = "AVX512ToLLVM", + srcs = glob([ + "lib/Conversion/AVX512ToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/mlir/Conversion/AVX512ToLLVM/*.h", + ]), + includes = ["include"], + deps = [ + ":AVX512", + ":EDSC", + ":IR", + ":LLVMAVX512", + ":LLVMDialect", + ":LLVMTransforms", + ":Pass", + ":StandardOps", + ":Support", + ":Transforms", + ":VectorOps", + ":VectorToLLVM", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + ], +) + filegroup( name = "LoopOpsTdFiles", srcs = [ @@ -214,7 +306,7 @@ cc_library( hdrs = ["include/mlir/Dialect/LoopOps/Passes.h"], includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":IR", ":LoopOps", ":Pass", @@ -303,19 +395,19 @@ cc_library( ) cc_library( - name = "AffineOps", + name = "Affine", srcs = glob( [ - "lib/Dialect/AffineOps/*.cpp", - "lib/Dialect/AffineOps/*.h", - "lib/Dialect/AffineOps/EDSC/*.cpp", + "lib/Dialect/Affine/IR/*.cpp", + "lib/Dialect/Affine/IR/*.h", + "lib/Dialect/Affine/EDSC/*.cpp", ], ) + [ "include/mlir/Transforms/InliningUtils.h", ], hdrs = glob([ - "include/mlir/Dialect/AffineOps/*.h", - "include/mlir/Dialect/AffineOps/EDSC/*.h", + "include/mlir/Dialect/Affine/IR/*.h", + "include/mlir/Dialect/Affine/EDSC/*.h", 
]), includes = ["include"], deps = [ @@ -330,6 +422,29 @@ cc_library( ], ) +cc_library( + name = "AffineTransforms", + srcs = glob([ + "lib/Dialect/Affine/Transforms/*.cpp", + ]), + hdrs = [ + "include/mlir/Dialect/Affine/Passes.h", + ], + includes = ["include"], + deps = [ + ":Affine", + ":Analysis", + ":IR", + ":LoopOps", + ":Pass", + ":StandardOps", + ":Support", + ":Transforms", + ":VectorOps", + "@llvm-project//llvm:support", + ], +) + cc_library( name = "AffineToStandardTransforms", srcs = glob([ @@ -339,7 +454,7 @@ cc_library( hdrs = glob(["include/mlir/Conversion/AffineToStandard/*.h"]), includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":IR", ":LoopOps", ":Pass", @@ -481,6 +596,28 @@ cc_library( ], ) +cc_library( + name = "StandardOpsTransforms", + srcs = glob( + [ + "lib/Dialect/StandardOps/Transforms/*.cpp", + "lib/Dialect/StandardOps/Transforms/*.h", + ], + ), + hdrs = glob([ + "include/mlir/Dialect/StandardOps/Transforms/*.h", + ]), + includes = ["include"], + deps = [ + ":Analysis", + ":ControlFlowInterfaces", + ":IR", + ":StandardOps", + ":Support", + "@llvm-project//llvm:support", + ], +) + cc_library( name = "VectorOps", srcs = glob( @@ -497,7 +634,7 @@ cc_library( ]), includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":Analysis", ":DialectUtils", ":EDSC", @@ -567,6 +704,94 @@ cc_library( ], ) +filegroup( + name = "LLVMAVX512TdFiles", + srcs = [ + "include/mlir/Dialect/LLVMIR/LLVMAVX512.td", + ":LLVMOpsTdFiles", + ], +) + +gentbl( + name = "LLVMAVX512IncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-dialect-decls -dialect=llvm_avx512", + "include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h.inc", + ), + ( + "-gen-op-decls", + "include/mlir/Dialect/LLVMIR/LLVMAVX512.h.inc", + ), + ( + "-gen-op-defs", + "include/mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc", + ), + ( + "-gen-op-doc", + "g3doc/Dialects/LLVMIR/LLVMAVX512.md", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/LLVMIR/LLVMAVX512.td", + td_srcs = [ + ":LLVMAVX512TdFiles", + ], +) + +cc_library( + name = "LLVMAVX512", + srcs = [ + "lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp", + ], + hdrs = [ + "include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h", + ], + includes = ["include"], + deps = [ + ":IR", + ":LLVMAVX512IncGen", + ":LLVMDialect", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + ], +) + +gentbl( + name = "LLVMAVX512ConversionIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-llvmir-conversions", + "include/mlir/Dialect/LLVMIR/LLVMAVX512Conversions.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/LLVMIR/LLVMAVX512.td", + td_srcs = [ + ":LLVMAVX512TdFiles", + ], +) + +cc_library( + name = "TargetLLVMAVX512Intr", + srcs = [ + "lib/Target/LLVMIR/LLVMAVX512Intr.cpp", + ], + includes = ["include"], + deps = [ + ":IR", + ":LLVMAVX512", + ":LLVMAVX512ConversionIncGen", + ":LLVMIRModuleTranslation", + ":Translation", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + ], + alwayslink = 1, +) + cc_library( name = "LLVMDialect", srcs = glob( @@ -575,6 +800,8 @@ cc_library( "lib/Dialect/LLVMIR/IR/*.h", ], exclude = [ + "lib/Dialect/LLVMIR/IR/*AVX512*.cpp", + "lib/Dialect/LLVMIR/IR/*AVX512*.h", "lib/Dialect/LLVMIR/IR/NVVM*.cpp", "lib/Dialect/LLVMIR/IR/NVVM*.h", "lib/Dialect/LLVMIR/IR/ROCDL*.cpp", @@ -586,6 +813,7 @@ cc_library( "include/mlir/Dialect/LLVMIR/*.h", ], exclude = [ + "include/mlir/Dialect/LLVMIR/*AVX512*.h", "include/mlir/Dialect/LLVMIR/NVVM*.h", 
"include/mlir/Dialect/LLVMIR/ROCDL*.h", ], @@ -1393,7 +1621,7 @@ cc_library( ]), includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":Analysis", ":ControlFlowInterfaces", ":IR", @@ -1474,7 +1702,7 @@ cc_library( ]), includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":Analysis", ":IR", ":LoopLikeInterface", @@ -1513,7 +1741,7 @@ cc_library( ], includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":AffineToStandardTransforms", ":GPUDialect", ":IR", @@ -1537,7 +1765,7 @@ cc_library( ], includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":GPUDialect", ":LoopOps", ":LoopsToGPU", @@ -1552,7 +1780,7 @@ cc_library( cc_library( name = "CFGTransforms", srcs = [ - "lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp", + "lib/Conversion/LoopToStandard/LoopToStandard.cpp", ], hdrs = [ "include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h", @@ -1573,7 +1801,7 @@ cc_library( cc_library( name = "LLVMTransforms", srcs = [ - "lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp", + "lib/Conversion/StandardToLLVM/StandardToLLVM.cpp", ], hdrs = [ "include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h", @@ -1763,7 +1991,7 @@ cc_library( ), includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":CallOpInterfaces", ":IR", ":LoopOps", @@ -1833,6 +2061,7 @@ cc_library( ":LLVMDialect", ":LLVMIRModuleTranslation", ":Support", + ":TargetLLVMAVX512Intr", ":Translation", "@llvm-project//llvm:core", "@llvm-project//llvm:ir_reader", @@ -1943,9 +2172,6 @@ cc_library( srcs = [ "lib/Support/MlirOptMain.cpp", ], - hdrs = [ - "include/mlir/Analysis/Passes.h", - ], includes = ["include"], deps = [ ":Analysis", @@ -1970,6 +2196,7 @@ cc_library( ":VectorToLLVM", ":VectorToLoops", "@llvm-project//llvm:support", + "@llvm-project//mlir/test:TestAffine", "@llvm-project//mlir/test:TestDialect", "@llvm-project//mlir/test:TestIR", "@llvm-project//mlir/test:TestPass", @@ -2026,7 +2253,10 @@ cc_library( ], defines = ["MLIR_CUDA_CONVERSIONS_ENABLED"], deps = [ - ":AffineOps", + ":AVX512", + ":AVX512ToLLVM", + ":Affine", + ":AffineTransforms", ":Analysis", ":FxpMathOps", ":GPUDialect", @@ -2037,6 +2267,7 @@ cc_library( ":GPUToVulkanTransforms", ":GPUTransforms", ":IR", + ":LLVMAVX512", ":LLVMDialect", ":LLVMIRTransforms", ":LinalgOps", @@ -2103,6 +2334,7 @@ cc_binary( ":Transforms", "@llvm-project//llvm:all_targets", "@llvm-project//llvm:support", + "@llvm-project//mlir/test:TestAffine", "@llvm-project//mlir/test:TestDialect", "@llvm-project//mlir/test:TestIR", "@llvm-project//mlir/test:TestPass", @@ -2661,7 +2893,7 @@ cc_library( ], includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":AffineToStandardTransforms", ":Analysis", ":CFGTransforms", @@ -2832,7 +3064,7 @@ cc_library( ]), includes = ["include"], deps = [ - ":AffineOps", + ":Affine", ":EDSC", ":IR", ":LLVMDialect", @@ -2867,7 +3099,7 @@ exports_files( "include/mlir/IR/OpBase.td", "include/mlir/Transforms/InliningUtils.h", ], - visibility = ["@llvm-project//mlir:friends"], + visibility = [":friends"], ) exports_files( @@ -2875,5 +3107,5 @@ exports_files( "include/mlir/Interfaces/InferTypeOpInterface.td", "include/mlir/Interfaces/LoopLikeInterface.td", ], - visibility = ["@llvm-project//mlir:friends"], + visibility = [":friends"], ) diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD index 3ba2bbf0fb1..5d569827860 100644 --- a/third_party/mlir/test.BUILD +++ b/third_party/mlir/test.BUILD @@ -153,16 +153,18 @@ cc_library( srcs = glob([ "lib/Transforms/*.cpp", ]), + defines = 
["MLIR_CUDA_CONVERSIONS_ENABLED"], includes = ["lib/TestDialect"], deps = [ ":TestDialect", ":TestLinalgTransformPatternsIncGen", ":TestVectorTransformPatternsIncGen", "@llvm-project//llvm:support", - "@llvm-project//mlir:AffineOps", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:EDSC", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToCUDATransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", @@ -179,6 +181,24 @@ cc_library( ], ) +cc_library( + name = "TestAffine", + srcs = glob([ + "lib/Dialect/Affine/*.cpp", + ]), + deps = [ + "@llvm-project//llvm:support", + "@llvm-project//mlir:Affine", + "@llvm-project//mlir:AffineTransforms", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorOps", + ], +) + cc_library( name = "TestSPIRV", srcs = glob([ From 1b6491369cf517d81fa10d48be7df18717eafd5e Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Mon, 23 Mar 2020 12:56:39 -0700 Subject: [PATCH 431/492] Small fixes for mean and mul in Metal backend. PiperOrigin-RevId: 302498070 Change-Id: I97c6000bbdbf3187f18177e1524be7ea0d7f3b74 --- .../lite/delegates/gpu/metal/kernels/mean.cc | 2 +- .../lite/delegates/gpu/metal/kernels/mul.cc | 23 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc b/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc index 8c888d0bca1..20ad71eb123 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc @@ -62,7 +62,7 @@ std::string GetMeanCode() { for (int h = 0; h < params.src_size.y; h++) { const int buffer_index = (gid.z * params.src_size.y + h) * params.src_size.x + w; - sum += src_buffer[buffer_index]; + sum += float4(src_buffer[buffer_index]); } } sum /= size; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc b/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc index 15d03e103ca..21a04f2fc35 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc @@ -37,12 +37,13 @@ namespace gpu { namespace metal { namespace { -std::string GetMaxUnpoolingCode() { +std::string GetApplyMaskCode() { std::string shader_source = R"( #include using namespace metal; struct uniforms { - int4 src_size; + int4 src_0_size; + int4 src_1_size; int4 dst_size; }; @@ -55,12 +56,12 @@ std::string GetMaxUnpoolingCode() { if (X >= params.dst_size.x || Y >= params.dst_size.y) { return; } - int src_0_index = (gid.z * params.src_size.y + static_cast(gid.y)) * - params.src_size.x + static_cast(gid.x); + int src_0_index = (gid.z * params.src_0_size.y + static_cast(gid.y)) * + params.src_0_size.x + static_cast(gid.x); int src_1_index = 0; if (params.dst_size.z == 1) { // [H, W, C] x [H, W, 0][0] - src_1_index = static_cast(gid.y) * params.src_size.x + + src_1_index = static_cast(gid.y) * params.src_1_size.x + static_cast(gid.x); } else if (params.src_0_size.y == params.src_1_size.y && params.src_0_size.x == params.src_1_size.x) { @@ -68,11 +69,13 @@ std::string GetMaxUnpoolingCode() { src_1_index = src_0_index; } else { // [H, W, C] x [0, 0, C] - src_1_index = gid.z * params.src_size.y * params.src_size.x ; + src_1_index = gid.z * params.src_1_size.y * params.src_1_size.x ; } - FLT4 value = src_buffer_0[src_index] * src_buffer_1[src_1_index]; + FLT4 
value = src_buffer_0[src_0_index] * src_buffer_1[src_1_index]; + int linear_index = (gid.z * params.dst_size.y + static_cast(gid.y)) * + params.dst_size.x + static_cast(gid.x); $2 - output_buffer[linear_index] = value; + dst_buffer[linear_index] = value; } )"; return shader_source; @@ -86,7 +89,7 @@ std::vector ApplyMask(int id, ValueId input_id_0, auto desc = std::make_shared(); desc->id = id; desc->is_linkable = false; - desc->shader_source = GetMaxUnpoolingCode(); + desc->shader_source = GetApplyMaskCode(); desc->input_buffers = { {input_id_0, "device FLT4* const src_buffer_0"}, // data @@ -94,7 +97,7 @@ std::vector ApplyMask(int id, ValueId input_id_0, }; desc->output_buffer = { - output_id, "device FLT4* output_buffer", + output_id, "device FLT4* dst_buffer", [input_id_0, input_id_1](const std::map& buffers) { return buffers.find(input_id_0)->second; }}; From 7f6bd0a85c7ad46fccae70feaf194d15afd1b564 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 13:09:53 -0700 Subject: [PATCH 432/492] Revert attempt at implementing "safe gradients" for functions with singularities. After some debate it was decided that not suppressing NaNs is safer. PiperOrigin-RevId: 302501155 Change-Id: I3715d5b54c710419f628c9a43af26f2e85faf8f4 --- .../python/kernel_tests/cwise_ops_test.py | 57 ----- tensorflow/python/ops/math_grad.py | 203 ++++-------------- 2 files changed, 46 insertions(+), 214 deletions(-) diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 743a413d08c..8d29b464f85 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors @@ -28,9 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -1235,59 +1232,5 @@ class PolyvalTest(test.TestCase): math_ops.polyval(coeffs, x) -class SingularGradientOpTest(test.TestCase): - - @test_util.run_deprecated_v1 - def testGradientAtSingularity(self): - if not compat.forward_compatible(2020, 6, 14): - self.skipTest("Skipping test for future functionality.") - - ops_and_singularity = [ - (gen_math_ops.reciprocal, (0.,)), - (gen_math_ops.rsqrt, (0.,)), - (gen_math_ops.sqrt, (0.,)), - (gen_math_ops.sqrt_grad, ( - 0., - 0., - )), - (gen_math_ops.reciprocal_grad, ( - 1., - 0., - )), - (gen_math_ops.tan, (np.pi / 2,)), - (gen_math_ops.log, (0.,)), - (gen_math_ops.log1p, (-1.,)), - (gen_math_ops.acosh, (0.,)), - (gen_math_ops.asin, (1.,)), - (gen_math_ops.acos, (1.,)), - (gen_math_ops.atan2, (0., 0.)), - (gen_math_ops.div, (1., 0.)), - (gen_math_ops.div_no_nan, (1., 0.)), - (gen_math_ops.real_div, (1., 0.)), - (math_ops.pow, (0., -1.)), - ] - for op, singularity in ops_and_singularity: - for dtype in (dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64, - dtypes_lib.complex64, 
dtypes_lib.complex128): - if dtype.is_complex and op in [ - gen_math_ops.asin, gen_math_ops.acos, gen_math_ops.atan2 - ]: - continue - if dtype == dtypes_lib.half and op in [ - gen_math_ops.acosh, gen_math_ops.asin, gen_math_ops.acos, - gen_math_ops.atan2 - ]: - continue - with self.cached_session(): - print("op = ", op, ", singularity = ", singularity, ", type = ", - dtype) - args = [constant_op.constant(s, dtype=dtype) for s in singularity] - grad_y = constant_op.constant(0, dtype=dtype) - y = op(*args) - g = gradients_impl.gradients(y, args, grad_ys=grad_y) - g_val = self.evaluate(g) - self.assertAllEqual(g_val, np.zeros(len(singularity))) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index cc39861f91e..8ce35de006a 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np from tensorflow.python.client import pywrap_tf_session as c_api -from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -597,12 +596,8 @@ def _SqrtGradGrad(op, grad): a = op.inputs[0] y = op.outputs[0] # y = 0.5 * b / conj(a) with ops.control_dependencies([grad]): - if compat.forward_compatible(2020, 6, 14): - ga = gen_math_ops.xdivy(grad, a) - return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga - else: - ga = grad / a - return -math_ops.conj(ga) * y, 0.5 * ga + ga = grad / a + return -math_ops.conj(ga) * y, 0.5 * ga @ops.RegisterGradient("Rsqrt") @@ -631,10 +626,7 @@ def _ExpGrad(op, grad): y = op.outputs[0] # y = e^x with ops.control_dependencies([grad]): y = math_ops.conj(y) - if compat.forward_compatible(2020, 6, 14): - return math_ops.mul_no_nan(y, grad) - else: - return grad * y + return grad * y @ops.RegisterGradient("Expm1") @@ -644,10 +636,7 @@ def _Expm1Grad(op, grad): with ops.control_dependencies([grad]): x = math_ops.conj(x) y = math_ops.exp(x) - if compat.forward_compatible(2020, 6, 14): - return math_ops.mul_no_nan(y, grad) - else: - return grad * y + return grad * y @ops.RegisterGradient("Log") @@ -656,10 +645,7 @@ def _LogGrad(op, grad): x = op.inputs[0] with ops.control_dependencies([grad]): x = math_ops.conj(x) - if compat.forward_compatible(2020, 6, 14): - return gen_math_ops.xdivy(grad, x) - else: - return grad * math_ops.reciprocal(x) + return grad * math_ops.reciprocal(x) @ops.RegisterGradient("Log1p") @@ -668,10 +654,7 @@ def _Log1pGrad(op, grad): x = op.inputs[0] with ops.control_dependencies([grad]): x = math_ops.conj(x) - if compat.forward_compatible(2020, 6, 14): - return gen_math_ops.xdivy(grad, 1 + x) - else: - return grad * math_ops.reciprocal(1 + x) + return grad * math_ops.reciprocal(1 + x) @ops.RegisterGradient("Xlogy") @@ -767,10 +750,7 @@ def _AcoshGrad(op, grad): y = op.outputs[0] with ops.control_dependencies([grad]): y = math_ops.conj(y) - if compat.forward_compatible(2020, 6, 14): - return math_ops.xdivy(grad, math_ops.sinh(y)) - else: - return grad / math_ops.sinh(y) + return grad / math_ops.sinh(y) @ops.RegisterGradient("Atanh") @@ -838,10 +818,7 @@ def _LgammaGrad(op, grad): x = op.inputs[0] with ops.control_dependencies([grad]): x = math_ops.conj(x) - if compat.forward_compatible(2020, 6, 14): - return math_ops.mul_no_nan(math_ops.digamma(x), grad) - else: - return grad * math_ops.digamma(x) + return grad * math_ops.digamma(x) 
@ops.RegisterGradient("Digamma") @@ -851,10 +828,7 @@ def _DigammaGrad(op, grad): with ops.control_dependencies([grad]): x = math_ops.conj(x) partial_x = math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x) - if compat.forward_compatible(2020, 6, 14): - return math_ops.mul_no_nan(partial_x, grad) - else: - return grad * partial_x + return grad * partial_x @ops.RegisterGradient("Dawsn") @@ -908,10 +882,7 @@ def _BesselI0eGrad(op, grad): y = op.outputs[0] with ops.control_dependencies([grad]): partial_x = (math_ops.bessel_i1e(x) - math_ops.sign(x) * y) - if compat.forward_compatible(2020, 6, 14): - return math_ops.mul_no_nan(partial_x, grad) - else: - return grad * partial_x + return grad * partial_x @ops.RegisterGradient("BesselI1e") @@ -932,10 +903,7 @@ def _BesselI1eGrad(op, grad): dy_dx = math_ops.bessel_i0e(safe_x) - y * ( math_ops.sign(safe_x) + math_ops.reciprocal(safe_x)) dy_dx = array_ops.where_v2(x_is_not_tiny, dy_dx, 0.5 + zeros) - if compat.forward_compatible(2020, 6, 14): - return math_ops.mul_no_nan(dy_dx, grad) - else: - return grad * dy_dx + return grad * dy_dx @ops.RegisterGradient("Igamma") @@ -953,15 +921,8 @@ def _IgammaGrad(op, grad): # and Gamma'(a) can grow large. partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a)) - if compat.forward_compatible(2020, 6, 14): - return (array_ops.reshape( - math_ops.reduce_sum(math_ops.mul_no_nan(partial_a, grad), ra), sa), - array_ops.reshape( - math_ops.reduce_sum(math_ops.mul_no_nan(partial_x, grad), rx), - sx)) - else: - return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa), - array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) + return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa), + array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Igammac") @@ -993,18 +954,10 @@ def _BetaincGrad(op, grad): partial_x = math_ops.exp(math_ops.xlog1py(b - 1, -x) + math_ops.xlogy(a - 1, x) - log_beta) - # TODO(b/36815900): Mark None return values as NotImplemented - if compat.forward_compatible(2020, 6, 14): - return ( - None, # da - None, # db - array_ops.reshape( - math_ops.reduce_sum(math_ops.mul_no_nan(partial_x, grad), rx), sx)) - else: - return ( - None, # da - None, # db - array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) + return ( + None, # da + None, # db + array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Zeta") @@ -1022,15 +975,8 @@ def _ZetaGrad(op, grad): x = math_ops.conj(x) q = math_ops.conj(q) partial_q = -x * math_ops.zeta(x + 1, q) - # TODO(b/36815900): Mark None return values as NotImplemented - if compat.forward_compatible(2020, 6, 14): - return (None, - array_ops.reshape( - math_ops.reduce_sum(math_ops.mul_no_nan(partial_q, grad), rq), - sq)) - else: - return (None, - array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq)) + return (None, + array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq)) @ops.RegisterGradient("Polygamma") @@ -1048,15 +994,8 @@ def _PolygammaGrad(op, grad): n = math_ops.conj(n) x = math_ops.conj(x) partial_x = math_ops.polygamma(n + 1, x) - # TODO(b/36815900): Mark None return values as NotImplemented - if compat.forward_compatible(2020, 6, 14): - return (None, - array_ops.reshape( - math_ops.reduce_sum(math_ops.mul_no_nan(partial_x, grad), rx), - sx)) - else: - return (None, - array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) + return (None, + array_ops.reshape(math_ops.reduce_sum(partial_x * 
grad, rx), sx)) @ops.RegisterGradient("Sigmoid") @@ -1110,10 +1049,7 @@ def _TanGrad(op, grad): x = math_ops.conj(x) secx = math_ops.reciprocal(math_ops.cos(x)) secx2 = math_ops.square(secx) - if compat.forward_compatible(2020, 6, 14): - return math_ops.mul_no_nan(secx2, grad) - else: - return secx2 * grad + return secx2 * grad @ops.RegisterGradient("Asin") @@ -1125,11 +1061,8 @@ def _AsinGrad(op, grad): x2 = math_ops.square(x) one = constant_op.constant(1, dtype=grad.dtype) den = math_ops.sqrt(math_ops.subtract(one, x2)) - if compat.forward_compatible(2020, 6, 14): - return math_ops.xdivy(grad, den) - else: - inv = math_ops.reciprocal(den) - return grad * inv + inv = math_ops.reciprocal(den) + return grad * inv @ops.RegisterGradient("Acos") @@ -1141,11 +1074,8 @@ def _AcosGrad(op, grad): x2 = math_ops.square(x) one = constant_op.constant(1, dtype=grad.dtype) den = math_ops.sqrt(math_ops.subtract(one, x2)) - if compat.forward_compatible(2020, 6, 14): - return -math_ops.xdivy(grad, den) - else: - inv = math_ops.reciprocal(den) - return -grad * inv + inv = math_ops.reciprocal(den) + return -grad * inv @ops.RegisterGradient("Atan") @@ -1166,10 +1096,7 @@ def _Atan2Grad(op, grad): y = op.inputs[0] x = op.inputs[1] with ops.control_dependencies([grad]): - if compat.forward_compatible(2020, 6, 14): - grad_inv = math_ops.xdivy(grad, (math_ops.square(x) + math_ops.square(y))) - else: - grad_inv = grad / (math_ops.square(x) + math_ops.square(y)) + grad_inv = grad / (math_ops.square(x) + math_ops.square(y)) return x * grad_inv, -y * grad_inv @@ -1328,20 +1255,11 @@ def _DivGrad(op, grad): rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) x = math_ops.conj(x) y = math_ops.conj(y) - if compat.forward_compatible(2020, 6, 14): - return (array_ops.reshape( - math_ops.reduce_sum(math_ops.xdivy(grad, y), rx), sx), - array_ops.reshape( - math_ops.reduce_sum( - math_ops.mul_no_nan( - math_ops.divide(math_ops.divide(-x, y), y), grad), ry), - sy)) - else: - return (array_ops.reshape( - math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx), - array_ops.reshape( - math_ops.reduce_sum( - grad * math_ops.divide(math_ops.divide(-x, y), y), ry), sy)) + return (array_ops.reshape( + math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.divide(math_ops.divide(-x, y), y), ry), sy)) @ops.RegisterGradient("FloorDiv") @@ -1381,21 +1299,11 @@ def _RealDivGrad(op, grad): rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) x = math_ops.conj(x) y = math_ops.conj(y) - if compat.forward_compatible(2020, 6, 14): - return (array_ops.reshape( - math_ops.reduce_sum(math_ops.xdivy(grad, y), rx), sx), - array_ops.reshape( - math_ops.reduce_sum( - math_ops.mul_no_nan( - math_ops.realdiv(math_ops.realdiv(-x, y), y), grad), - ry), sy)) - else: - return (array_ops.reshape( - math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx), - array_ops.reshape( - math_ops.reduce_sum( - grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), - sy)) + return (array_ops.reshape( + math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) @ops.RegisterGradient("DivNoNan") @@ -1408,21 +1316,12 @@ def _DivNoNanGrad(op, grad): rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) x = math_ops.conj(x) y = math_ops.conj(y) - if compat.forward_compatible(2020, 6, 14): - return (array_ops.reshape( - math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), - array_ops.reshape( 
- math_ops.reduce_sum( - math_ops.mul_no_nan( - math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), - grad), ry), sy)) - else: - return (array_ops.reshape( - math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), - array_ops.reshape( - math_ops.reduce_sum( - grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), - ry), sy)) + return (array_ops.reshape( + math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), + ry), sy)) @ops.RegisterGradient("Pow") @@ -1430,7 +1329,6 @@ def _PowGrad(op, grad): """Returns grad * (y*x^(y-1), z*log(x)).""" x = op.inputs[0] y = op.inputs[1] - use_mul_no_nan = compat.forward_compatible(2020, 6, 14) skip_input_indices = None try: skip_input_indices = op.skip_input_indices @@ -1440,10 +1338,7 @@ def _PowGrad(op, grad): y): x = math_ops.conj(x) y = math_ops.conj(y) - if use_mul_no_nan: - return gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), None - else: - return grad * y * math_ops.pow(x, y - 1), None + return grad * y * math_ops.pow(x, y - 1), None except AttributeError: # No gradient skipping, so do the full gradient computation @@ -1455,10 +1350,7 @@ def _PowGrad(op, grad): y = math_ops.conj(y) if skip_input_indices is None or 0 not in skip_input_indices: - if use_mul_no_nan: - gx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad) - else: - gx = grad * y * math_ops.pow(x, y - 1) + gx = grad * y * math_ops.pow(x, y - 1) if must_reduce_x: gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx) else: @@ -1476,10 +1368,7 @@ def _PowGrad(op, grad): mask = x > 0 safe_x = array_ops.where(mask, x, array_ops.ones_like(x)) log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x)) - if use_mul_no_nan: - gy = gen_math_ops.mul_no_nan(z * log_x, grad) - else: - gy = grad * z * log_x + gy = grad * z * log_x if must_reduce_y: gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy) else: From c45d13f92b921dae044a5639f8b1bc560fe9b71c Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Mon, 23 Mar 2020 13:15:07 -0700 Subject: [PATCH 433/492] [tf.data service] disable data_service_test on windows. This is temporary while we figure out why the test fails on windows. PiperOrigin-RevId: 302502113 Change-Id: I1bfc9231d3181d711278f0be413ccd585279eb30 --- tensorflow/core/data/service/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index b0b6ce3f3e7..46665846acc 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -264,6 +264,7 @@ cc_library( tf_cc_test( name = "data_service_test", srcs = ["data_service_test.cc"], + tags = ["no_windows"], deps = [ ":compression_utils", ":grpc_master_impl", From e7bbb424808eb7ebbeb959b993496adafd024609 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 23 Mar 2020 13:26:06 -0700 Subject: [PATCH 434/492] Remove the unnecessary type check from legacy RNN code. 
PiperOrigin-RevId: 302504219 Change-Id: I3724c843c1fabbb9d8bbf52073b1b3417eae5fde --- tensorflow/python/ops/rnn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index adda1f5e564..031e807e8b0 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -23,7 +23,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -161,7 +160,6 @@ def _is_keras_rnn_cell(rnn_cell): # Keras cells never had zero_state method, which was from the original # interface from TF RNN cell. return (not isinstance(rnn_cell, rnn_cell_impl.RNNCell) and - isinstance(rnn_cell, base_layer.Layer) and getattr(rnn_cell, "zero_state", None) is None) From 28c22acb3a4bbc056b92d38524dc20c76b1b3c0b Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 23 Mar 2020 13:26:12 -0700 Subject: [PATCH 435/492] Remove unneeded symbols from python layers. The input_spec should be exported via tf_export already in keras. PiperOrigin-RevId: 302504252 Change-Id: I73aa4cc9d8adadb23b1f5aa455ddfd1008e2080e --- tensorflow/python/BUILD | 1 - tensorflow/python/layers/layers.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c784cb47611..74df39049cc 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -7105,7 +7105,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":layers_base", - "//tensorflow/python/keras/engine:input_spec", "//tensorflow/python/keras/legacy_tf_layers:convolutional", "//tensorflow/python/keras/legacy_tf_layers:core", "//tensorflow/python/keras/legacy_tf_layers:normalization", diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py index 93eec38a08c..f052ae66a6a 100644 --- a/tensorflow/python/layers/layers.py +++ b/tensorflow/python/layers/layers.py @@ -24,7 +24,6 @@ from __future__ import print_function # Base objects. from tensorflow.python.layers.base import Layer -from tensorflow.python.keras.engine.input_spec import InputSpec # Core layers. from tensorflow.python.layers.core import Dense From 25ae1e130a3b3ad9c9c89030eae1e0f8d8d1b1ab Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 23 Mar 2020 13:26:12 -0700 Subject: [PATCH 436/492] Improves public-facing documentation for `tf.keras.utils.custom_object_scope`. 
PiperOrigin-RevId: 302504282 Change-Id: Ieb892c2681557a99ddec6f6ba2f9db01efde61df --- .../python/keras/utils/generic_utils.py | 73 +++++++------------ ...flow.keras.utils.custom_object_scope.pbtxt | 9 +++ .../golden/v1/tensorflow.keras.utils.pbtxt | 8 +- ...flow.keras.utils.custom_object_scope.pbtxt | 9 +++ .../golden/v2/tensorflow.keras.utils.pbtxt | 8 +- 5 files changed, 53 insertions(+), 54 deletions(-) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.utils.custom_object_scope.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.utils.custom_object_scope.pbtxt diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 970ec755c80..5331806fc23 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -47,25 +47,31 @@ _SKIP_FAILED_SERIALIZATION = False _LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config' -@keras_export('keras.utils.CustomObjectScope') +@keras_export('keras.utils.custom_object_scope', # pylint: disable=g-classes-have-attributes + 'keras.utils.CustomObjectScope') class CustomObjectScope(object): - """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. + """Exposes custom classes/functions to Keras deserialization internals. - Code within a `with` statement will be able to access custom objects - by name. Changes to global custom objects persist - within the enclosing `with` statement. At end of the `with` statement, - global custom objects are reverted to state - at beginning of the `with` statement. + Under a scope `with custom_object_scope(objects_dict)`, Keras methods such + as `tf.keras.models.load_model` or `tf.keras.models.model_from_config` + will be able to deserialize any custom object referenced by a + saved config (e.g. a custom layer or metric). Example: - Consider a custom object `MyObject` (e.g. a class): + Consider a custom regularizer `my_regularizer`: ```python - with CustomObjectScope({'MyObject':MyObject}): - layer = Dense(..., kernel_regularizer='MyObject') - # save, load, etc. will recognize custom object by name + layer = Dense(3, kernel_regularizer=my_regularizer) + config = layer.get_config() # Config contains a reference to `my_regularizer` + ... + # Later: + with custom_object_scope({'my_regularizer': my_regularizer}): + layer = Dense.from_config(config) ``` + + Arguments: + *args: Dictionary or dictionaries of `{name: object}` pairs. """ def __init__(self, *args): @@ -83,50 +89,19 @@ class CustomObjectScope(object): _GLOBAL_CUSTOM_OBJECTS.update(self.backup) -@keras_export('keras.utils.custom_object_scope') -def custom_object_scope(*args): - """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. - - Convenience wrapper for `CustomObjectScope`. - Code within a `with` statement will be able to access custom objects - by name. Changes to global custom objects persist - within the enclosing `with` statement. At end of the `with` statement, - global custom objects are reverted to state - at beginning of the `with` statement. - - Example: - - Consider a custom object `MyObject` - - ```python - with custom_object_scope({'MyObject':MyObject}): - layer = Dense(..., kernel_regularizer='MyObject') - # save, load, etc. will recognize custom object by name - ``` - - Arguments: - *args: Variable length list of dictionaries of name, class pairs to add to - custom objects. - - Returns: - Object of type `CustomObjectScope`. 
- """ - return CustomObjectScope(*args) - - @keras_export('keras.utils.get_custom_objects') def get_custom_objects(): """Retrieves a live reference to the global dictionary of custom objects. Updating and clearing custom objects using `custom_object_scope` is preferred, but `get_custom_objects` can - be used to directly access `_GLOBAL_CUSTOM_OBJECTS`. + be used to directly access the current collection of custom objects. Example: ```python - get_custom_objects().clear() - get_custom_objects()['MyObject'] = MyObject + get_custom_objects().clear() + get_custom_objects()['MyObject'] = MyObject ``` Returns: @@ -158,7 +133,7 @@ def register_keras_serializable(package='Custom', name=None): Arguments: package: The package that this class belongs to. name: The name to serialize this class under in this package. If None, the - class's name will be used. + class' name will be used. Returns: A decorator that registers the decorated class with the passed names. @@ -806,3 +781,9 @@ def default(method): def is_default(method): """Check if a method is decorated with the `default` wrapper.""" return getattr(method, '_is_default', False) + + +# Aliases + + +custom_object_scope = CustomObjectScope # pylint: disable=invalid-name diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.custom_object_scope.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.custom_object_scope.pbtxt new file mode 100644 index 00000000000..bfee151ae28 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.custom_object_scope.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.keras.utils.custom_object_scope" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=args, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt index 6f0000b84fb..ae616a1a620 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt @@ -28,14 +28,14 @@ tf_module { name: "SequenceEnqueuer" mtype: "" } + member { + name: "custom_object_scope" + mtype: "" + } member_method { name: "convert_all_kernels_in_model" argspec: "args=[\'model\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "custom_object_scope" - argspec: "args=[], varargs=args, keywords=None, defaults=None" - } member_method { name: "deserialize_keras_object" argspec: "args=[\'identifier\', \'module_objects\', \'custom_objects\', \'printable_module_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'object\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.custom_object_scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.custom_object_scope.pbtxt new file mode 100644 index 00000000000..bfee151ae28 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.custom_object_scope.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.keras.utils.custom_object_scope" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=args, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt index 6f0000b84fb..ae616a1a620 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt 
@@ -28,14 +28,14 @@ tf_module { name: "SequenceEnqueuer" mtype: "" } + member { + name: "custom_object_scope" + mtype: "" + } member_method { name: "convert_all_kernels_in_model" argspec: "args=[\'model\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "custom_object_scope" - argspec: "args=[], varargs=args, keywords=None, defaults=None" - } member_method { name: "deserialize_keras_object" argspec: "args=[\'identifier\', \'module_objects\', \'custom_objects\', \'printable_module_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'object\'], " From 36b108d7837f7d714ff24698bcb473dbd6b3aadd Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Mon, 23 Mar 2020 13:29:35 -0700 Subject: [PATCH 437/492] Add support for adding custom metric names in model_to_estimator API. PiperOrigin-RevId: 302504968 Change-Id: I7260f0d495ec2b97b1a759ee0ff5c212c55c4fa2 --- tensorflow/python/keras/estimator/__init__.py | 63 ++++++++++++++++--- .../v2/tensorflow.keras.estimator.pbtxt | 2 +- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py index 448cbf18854..4dde07e770b 100644 --- a/tensorflow/python/keras/estimator/__init__.py +++ b/tensorflow/python/keras/estimator/__init__.py @@ -130,13 +130,13 @@ def model_to_estimator( @keras_export('keras.estimator.model_to_estimator', v1=[]) -def model_to_estimator_v2( - keras_model=None, - keras_model_path=None, - custom_objects=None, - model_dir=None, - config=None, - checkpoint_format='checkpoint'): +def model_to_estimator_v2(keras_model=None, + keras_model_path=None, + custom_objects=None, + model_dir=None, + config=None, + checkpoint_format='checkpoint', + metric_names_map=None): """Constructs an `Estimator` instance from given keras model. If you use infrastructure or other tooling that relies on Estimators, you can @@ -169,6 +169,41 @@ def model_to_estimator_v2( estimator.train(input_fn, steps=1) ``` + To customize the estimator `eval_metric_ops` names, you can pass in the + `metric_names_map` dictionary mapping the keras model output metric names + to the custom names as follows: + + ```python + input_a = tf.keras.layers.Input(shape=(16,), name='input_a') + input_b = tf.keras.layers.Input(shape=(16,), name='input_b') + dense = tf.keras.layers.Dense(8, name='dense_1') + interm_a = dense(input_a) + interm_b = dense(input_b) + merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge') + output_a = tf.keras.layers.Dense(3, activation='softmax', name='dense_2')( + merged) + output_b = tf.keras.layers.Dense(2, activation='softmax', name='dense_3')( + merged) + keras_model = tf.keras.models.Model( + inputs=[input_a, input_b], outputs=[output_a, output_b]) + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics={ + 'dense_2': 'categorical_accuracy', + 'dense_3': 'categorical_accuracy' + }) + + metric_names_map = { + 'dense_2_categorical_accuracy': 'acc_1', + 'dense_3_categorical_accuracy': 'acc_2', + } + keras_est = tf.keras.estimator.model_to_estimator( + keras_model=keras_model, + config=config, + metric_names_map=metric_names_map) + ``` + Args: keras_model: A compiled Keras model object. This argument is mutually exclusive with `keras_model_path`. Estimator's `model_fn` uses the @@ -197,6 +232,17 @@ def model_to_estimator_v2( `tf.train.Checkpoint`. Currently, saving object-based checkpoints from `model_to_estimator` is only supported by Functional and Sequential models. 
Defaults to 'checkpoint'. + metric_names_map: Optional dictionary mapping Keras model output metric + names to custom names. This can be used to override the default Keras + model output metrics names in a multi IO model use case and provide custom + names for the `eval_metric_ops` in Estimator. + The Keras model metric names can be obtained using `model.metrics_names` + excluding any loss metrics such as total loss and output losses. + For example, if your Keras model has two outputs `out_1` and `out_2`, + with `mse` loss and `acc` metric, then `model.metrics_names` will be + `['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc']`. + The model metric names excluding the loss metrics will be + `['out_1_acc', 'out_2_acc']`. Returns: An Estimator from given keras model. @@ -223,5 +269,6 @@ def model_to_estimator_v2( model_dir=model_dir, config=config, checkpoint_format=checkpoint_format, - use_v2_estimator=True) + use_v2_estimator=True, + metric_names_map=metric_names_map) # LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py) diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt index 81fcfd87cda..d9415ba4e54 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt @@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator" tf_module { member_method { name: "model_to_estimator" - argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\'], " + argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\', \'metric_names_map\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\', \'None\'], " } } From fd89b050988b7742a8336f8c3b13ded6b77e37dd Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Mon, 23 Mar 2020 13:32:11 -0700 Subject: [PATCH 438/492] TFLite GPU: Replace tflite::gpu::Status with absl::Status. 
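
The change is mechanical but wide: every helper that previously returned the
lightweight tflite::gpu::Status now returns absl::Status, and the local
OkStatus()/InvalidArgumentError()/UnknownError() constructors are replaced by
their absl:: equivalents, while RETURN_IF_ERROR and the surrounding control
flow stay the same. A minimal sketch of the pattern (illustrative only; DoWork
is a placeholder name, not a function touched by this patch):

    // Before: custom lightweight status type and free-function constructors.
    Status DoWork(CLContext* context) {
      if (!context) return InvalidArgumentError("null context");
      return OkStatus();
    }

    // After: canonical absl::Status, identical control flow.
    absl::Status DoWork(CLContext* context) {
      if (!context) return absl::InvalidArgumentError("null context");
      return absl::OkStatus();
    }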
PiperOrigin-RevId: 302505499 Change-Id: I72e21f78b660ef3c983b3863ba7ffbe29b34e66c --- tensorflow/lite/delegates/gpu/BUILD | 6 - tensorflow/lite/delegates/gpu/api.h | 27 +- tensorflow/lite/delegates/gpu/cl/api.cc | 198 +-- tensorflow/lite/delegates/gpu/cl/api.h | 4 +- tensorflow/lite/delegates/gpu/cl/buffer.cc | 22 +- tensorflow/lite/delegates/gpu/cl/buffer.h | 28 +- .../lite/delegates/gpu/cl/cl_command_queue.cc | 101 +- .../lite/delegates/gpu/cl/cl_command_queue.h | 47 +- .../lite/delegates/gpu/cl/cl_context.cc | 23 +- tensorflow/lite/delegates/gpu/cl/cl_context.h | 9 +- tensorflow/lite/delegates/gpu/cl/cl_device.cc | 8 +- tensorflow/lite/delegates/gpu/cl/cl_device.h | 8 +- tensorflow/lite/delegates/gpu/cl/cl_errors.h | 7 +- tensorflow/lite/delegates/gpu/cl/cl_kernel.cc | 68 +- tensorflow/lite/delegates/gpu/cl/cl_kernel.h | 30 +- .../lite/delegates/gpu/cl/cl_program.cc | 54 +- tensorflow/lite/delegates/gpu/cl/cl_program.h | 18 +- tensorflow/lite/delegates/gpu/cl/egl_sync.cc | 17 +- tensorflow/lite/delegates/gpu/cl/egl_sync.h | 6 +- .../lite/delegates/gpu/cl/environment.cc | 24 +- .../lite/delegates/gpu/cl/environment.h | 4 +- .../lite/delegates/gpu/cl/gl_interop.cc | 74 +- tensorflow/lite/delegates/gpu/cl/gl_interop.h | 38 +- .../lite/delegates/gpu/cl/gpu_api_delegate.cc | 18 +- .../delegates/gpu/cl/inference_context.cc | 71 +- .../lite/delegates/gpu/cl/inference_context.h | 40 +- .../lite/delegates/gpu/cl/kernels/add.cc | 6 +- .../lite/delegates/gpu/cl/kernels/add.h | 4 +- .../lite/delegates/gpu/cl/kernels/cl_test.cc | 32 +- .../lite/delegates/gpu/cl/kernels/cl_test.h | 26 +- .../delegates/gpu/cl/kernels/concat_xy.cc | 10 +- .../lite/delegates/gpu/cl/kernels/concat_xy.h | 8 +- .../lite/delegates/gpu/cl/kernels/concat_z.cc | 13 +- .../lite/delegates/gpu/cl/kernels/concat_z.h | 8 +- .../lite/delegates/gpu/cl/kernels/conv_3d.cc | 20 +- .../lite/delegates/gpu/cl/kernels/conv_3d.h | 46 +- .../gpu/cl/kernels/conv_buffer_1x1.cc | 39 +- .../gpu/cl/kernels/conv_buffer_1x1.h | 77 +- .../gpu/cl/kernels/conv_constants.cc | 23 +- .../delegates/gpu/cl/kernels/conv_constants.h | 29 +- .../delegates/gpu/cl/kernels/conv_powervr.cc | 39 +- .../delegates/gpu/cl/kernels/conv_powervr.h | 83 +- .../delegates/gpu/cl/kernels/conv_texture.cc | 34 +- .../delegates/gpu/cl/kernels/conv_texture.h | 75 +- .../delegates/gpu/cl/kernels/converter.cc | 79 +- .../gpu/cl/kernels/convolution_transposed.cc | 22 +- .../gpu/cl/kernels/convolution_transposed.h | 25 +- .../cl/kernels/convolution_transposed_3d.cc | 15 +- .../cl/kernels/convolution_transposed_3d.h | 20 +- .../cl/kernels/convolution_transposed_3x3.cc | 21 +- .../cl/kernels/convolution_transposed_3x3.h | 16 +- .../convolution_transposed_3x3_thin.cc | 19 +- .../kernels/convolution_transposed_3x3_thin.h | 18 +- .../cl/kernels/convolution_transposed_4x4.cc | 18 +- .../cl/kernels/convolution_transposed_4x4.h | 16 +- .../cl/kernels/convolution_transposed_thin.cc | 16 +- .../cl/kernels/convolution_transposed_thin.h | 18 +- .../gpu/cl/kernels/depth_wise_conv.cc | 21 +- .../gpu/cl/kernels/depth_wise_conv.h | 25 +- .../gpu/cl/kernels/depth_wise_conv_3d.cc | 14 +- .../gpu/cl/kernels/depth_wise_conv_3d.h | 20 +- .../gpu/cl/kernels/depth_wise_conv_3x3.cc | 28 +- .../gpu/cl/kernels/depth_wise_conv_3x3.h | 27 +- .../delegates/gpu/cl/kernels/elementwise.cc | 4 +- .../delegates/gpu/cl/kernels/elementwise.h | 2 +- .../gpu/cl/kernels/fully_connected.cc | 17 +- .../gpu/cl/kernels/fully_connected.h | 25 +- .../delegates/gpu/cl/kernels/gpu_operation.cc | 17 +- 
.../delegates/gpu/cl/kernels/gpu_operation.h | 28 +- .../lite/delegates/gpu/cl/kernels/lstm.cc | 11 +- .../lite/delegates/gpu/cl/kernels/lstm.h | 8 +- .../delegates/gpu/cl/kernels/max_unpooling.cc | 22 +- .../delegates/gpu/cl/kernels/max_unpooling.h | 16 +- .../lite/delegates/gpu/cl/kernels/mean.cc | 8 +- .../lite/delegates/gpu/cl/kernels/mean.h | 6 +- .../delegates/gpu/cl/kernels/multiply_add.cc | 48 +- .../delegates/gpu/cl/kernels/multiply_add.h | 80 +- .../lite/delegates/gpu/cl/kernels/padding.cc | 10 +- .../lite/delegates/gpu/cl/kernels/padding.h | 8 +- .../lite/delegates/gpu/cl/kernels/pooling.cc | 24 +- .../lite/delegates/gpu/cl/kernels/pooling.h | 16 +- .../lite/delegates/gpu/cl/kernels/prelu.cc | 14 +- .../lite/delegates/gpu/cl/kernels/prelu.h | 22 +- .../gpu/cl/kernels/quantize_and_dequantize.cc | 14 +- .../gpu/cl/kernels/quantize_and_dequantize.h | 19 +- .../lite/delegates/gpu/cl/kernels/relu.cc | 4 +- .../lite/delegates/gpu/cl/kernels/relu.h | 2 +- .../lite/delegates/gpu/cl/kernels/reshape.cc | 11 +- .../lite/delegates/gpu/cl/kernels/reshape.h | 8 +- .../delegates/gpu/cl/kernels/reshapex4.cc | 11 +- .../lite/delegates/gpu/cl/kernels/reshapex4.h | 8 +- .../lite/delegates/gpu/cl/kernels/resize.cc | 20 +- .../lite/delegates/gpu/cl/kernels/resize.h | 16 +- .../lite/delegates/gpu/cl/kernels/softmax.cc | 10 +- .../lite/delegates/gpu/cl/kernels/softmax.h | 8 +- .../delegates/gpu/cl/kernels/softmax1x1.cc | 4 +- .../delegates/gpu/cl/kernels/softmax1x1.h | 4 +- .../gpu/cl/kernels/space_to_depth.cc | 8 +- .../delegates/gpu/cl/kernels/space_to_depth.h | 8 +- .../delegates/gpu/cl/kernels/strided_slice.cc | 10 +- .../delegates/gpu/cl/kernels/strided_slice.h | 8 +- .../delegates/gpu/cl/kernels/transpose.cc | 11 +- .../lite/delegates/gpu/cl/kernels/transpose.h | 8 +- .../lite/delegates/gpu/cl/kernels/winograd.cc | 45 +- .../lite/delegates/gpu/cl/kernels/winograd.h | 38 +- .../gpu/cl/kernels/work_group_picking.cc | 49 +- .../gpu/cl/kernels/work_group_picking.h | 24 +- .../lite/delegates/gpu/cl/linear_storage.cc | 22 +- .../lite/delegates/gpu/cl/linear_storage.h | 36 +- .../lite/delegates/gpu/cl/opencl_wrapper.cc | 8 +- .../lite/delegates/gpu/cl/opencl_wrapper.h | 2 +- .../lite/delegates/gpu/cl/program_cache.cc | 33 +- .../lite/delegates/gpu/cl/program_cache.h | 19 +- .../gpu/cl/selectors/convolution_selector.cc | 93 +- .../gpu/cl/selectors/convolution_selector.h | 20 +- .../convolution_transposed_selector.cc | 21 +- .../convolution_transposed_selector.h | 8 +- .../cl/selectors/default/default_selector.cc | 13 +- .../gpu/cl/selectors/default_selector.h | 11 +- .../cl/selectors/dw_convolution_selector.cc | 38 +- .../cl/selectors/dw_convolution_selector.h | 8 +- .../cl/selectors/fully_connected_selector.cc | 40 +- .../cl/selectors/fully_connected_selector.h | 8 +- .../gpu/cl/selectors/operation_selector.cc | 61 +- .../gpu/cl/selectors/operation_selector.h | 11 +- .../gpu/cl/selectors/simple_selectors.cc | 83 +- .../gpu/cl/selectors/simple_selectors.h | 59 +- .../delegates/gpu/cl/storage_type_util.cc | 1 + tensorflow/lite/delegates/gpu/cl/tensor.cc | 151 ++- tensorflow/lite/delegates/gpu/cl/tensor.h | 62 +- .../lite/delegates/gpu/cl/tensor_test.cc | 25 +- .../gpu/cl/testing/performance_profiling.cc | 19 +- tensorflow/lite/delegates/gpu/cl/texture2d.cc | 26 +- tensorflow/lite/delegates/gpu/cl/texture2d.h | 36 +- tensorflow/lite/delegates/gpu/common/BUILD | 5 +- .../lite/delegates/gpu/common/convert.cc | 90 +- .../lite/delegates/gpu/common/convert.h | 24 +- .../delegates/gpu/common/custom_parsers.cc | 
8 +- .../delegates/gpu/common/custom_parsers.h | 6 +- .../delegates/gpu/common/memory_management.cc | 35 +- .../delegates/gpu/common/memory_management.h | 21 +- .../memory_management/equality_assignment.h | 8 +- .../greedy_by_breadth_assignment.cc | 6 +- .../greedy_by_breadth_assignment.h | 2 +- .../greedy_by_size_assignment.cc | 14 +- .../greedy_by_size_assignment.h | 4 +- .../greedy_in_order_assignment.h | 10 +- .../min_cost_flow_assignment.cc | 4 +- .../min_cost_flow_assignment.h | 2 +- .../memory_management/naive_assignment.h | 4 +- tensorflow/lite/delegates/gpu/common/model.h | 117 +- .../delegates/gpu/common/model_builder.cc | 1096 +++++++++-------- .../lite/delegates/gpu/common/model_builder.h | 16 +- .../lite/delegates/gpu/common/operations.cc | 15 +- .../lite/delegates/gpu/common/operations.h | 5 +- tensorflow/lite/delegates/gpu/common/status.h | 108 +- .../gpu/common/testing/interpreter_utils.cc | 35 +- .../gpu/common/testing/interpreter_utils.h | 16 +- .../transformations/add_quant_adjustments.cc | 2 +- .../transformations/fuse_add_to_conv.cc | 8 +- .../transformations/fuse_mul_to_conv.cc | 8 +- .../common/transformations/make_padding.cc | 6 +- .../match_dilated_convolution.cc | 2 +- .../transformations/merge_padding_with.cc | 8 +- .../gpu/common/transformations/remove_noop.cc | 8 +- .../gpu/common/workgroup_selection.cc | 13 +- .../gpu/common/workgroup_selection.h | 7 +- tensorflow/lite/delegates/gpu/delegate.cc | 35 +- tensorflow/lite/delegates/gpu/gl/api.cc | 90 +- tensorflow/lite/delegates/gpu/gl/api.h | 24 +- tensorflow/lite/delegates/gpu/gl/api2.cc | 197 +-- tensorflow/lite/delegates/gpu/gl/api2.h | 4 +- .../lite/delegates/gpu/gl/command_queue.cc | 18 +- .../lite/delegates/gpu/gl/command_queue.h | 8 +- tensorflow/lite/delegates/gpu/gl/compiler.cc | 18 +- tensorflow/lite/delegates/gpu/gl/compiler.h | 8 +- .../gpu/gl/compiler/compiled_node.cc | 6 +- .../delegates/gpu/gl/compiler/compiled_node.h | 4 +- .../delegates/gpu/gl/compiler/preprocessor.cc | 16 +- .../delegates/gpu/gl/compiler/preprocessor.h | 2 +- .../lite/delegates/gpu/gl/compiler/rename.cc | 8 +- .../lite/delegates/gpu/gl/compiler/rename.h | 2 +- .../gpu/gl/compiler/shader_codegen.cc | 17 +- .../gpu/gl/compiler/shader_codegen.h | 3 +- .../gpu/gl/converters/bhwc_to_phwc4.cc | 18 +- .../gpu/gl/converters/bhwc_to_phwc4.h | 8 +- .../gpu/gl/converters/bhwc_to_phwc4_test.cc | 6 +- .../gpu/gl/converters/phwc4_to_bhwc.cc | 18 +- .../gpu/gl/converters/phwc4_to_bhwc.h | 8 +- .../gpu/gl/converters/phwc4_to_bhwc_test.cc | 6 +- .../lite/delegates/gpu/gl/egl_context.cc | 42 +- .../lite/delegates/gpu/gl/egl_context.h | 18 +- .../lite/delegates/gpu/gl/egl_environment.cc | 35 +- .../lite/delegates/gpu/gl/egl_environment.h | 10 +- .../lite/delegates/gpu/gl/egl_surface.cc | 11 +- .../lite/delegates/gpu/gl/egl_surface.h | 6 +- tensorflow/lite/delegates/gpu/gl/gl_buffer.cc | 24 +- tensorflow/lite/delegates/gpu/gl/gl_buffer.h | 68 +- .../lite/delegates/gpu/gl/gl_buffer_test.cc | 2 +- tensorflow/lite/delegates/gpu/gl/gl_call.h | 27 +- tensorflow/lite/delegates/gpu/gl/gl_errors.cc | 42 +- tensorflow/lite/delegates/gpu/gl/gl_errors.h | 4 +- .../lite/delegates/gpu/gl/gl_program.cc | 58 +- tensorflow/lite/delegates/gpu/gl/gl_program.h | 13 +- tensorflow/lite/delegates/gpu/gl/gl_shader.cc | 12 +- tensorflow/lite/delegates/gpu/gl/gl_shader.h | 6 +- tensorflow/lite/delegates/gpu/gl/gl_sync.cc | 22 +- tensorflow/lite/delegates/gpu/gl/gl_sync.h | 12 +- .../lite/delegates/gpu/gl/gl_texture.cc | 82 +- tensorflow/lite/delegates/gpu/gl/gl_texture.h | 
50 +- .../lite/delegates/gpu/gl/kernels/add.cc | 12 +- .../lite/delegates/gpu/gl/kernels/concat.cc | 35 +- .../lite/delegates/gpu/gl/kernels/conv.cc | 20 +- .../delegates/gpu/gl/kernels/converter.cc | 86 +- .../gpu/gl/kernels/converter_test.cc | 12 +- .../gpu/gl/kernels/depthwise_conv.cc | 6 +- .../delegates/gpu/gl/kernels/elementwise.cc | 21 +- .../gpu/gl/kernels/fully_connected.cc | 6 +- .../lite/delegates/gpu/gl/kernels/lstm.cc | 6 +- .../delegates/gpu/gl/kernels/max_unpooling.cc | 6 +- .../lite/delegates/gpu/gl/kernels/mean.cc | 8 +- .../lite/delegates/gpu/gl/kernels/mul.cc | 18 +- .../lite/delegates/gpu/gl/kernels/pad.cc | 12 +- .../lite/delegates/gpu/gl/kernels/pooling.cc | 24 +- .../lite/delegates/gpu/gl/kernels/prelu.cc | 25 +- .../gpu/gl/kernels/quantize_and_dequantize.cc | 6 +- .../lite/delegates/gpu/gl/kernels/registry.cc | 10 +- .../lite/delegates/gpu/gl/kernels/relu.cc | 6 +- .../lite/delegates/gpu/gl/kernels/reshape.cc | 10 +- .../lite/delegates/gpu/gl/kernels/resize.cc | 18 +- .../lite/delegates/gpu/gl/kernels/slice.cc | 6 +- .../lite/delegates/gpu/gl/kernels/softmax.cc | 22 +- .../gpu/gl/kernels/space_to_depth.cc | 6 +- .../delegates/gpu/gl/kernels/test_util.cc | 10 +- .../lite/delegates/gpu/gl/kernels/test_util.h | 8 +- .../gpu/gl/kernels/transpose_conv.cc | 12 +- .../lite/delegates/gpu/gl/node_shader.h | 4 +- .../lite/delegates/gpu/gl/object_manager.cc | 19 +- .../lite/delegates/gpu/gl/object_manager.h | 14 +- .../lite/delegates/gpu/gl/request_gpu_info.cc | 4 +- .../lite/delegates/gpu/gl/request_gpu_info.h | 2 +- tensorflow/lite/delegates/gpu/gl/runtime.cc | 128 +- tensorflow/lite/delegates/gpu/gl/runtime.h | 20 +- .../delegates/gpu/gl/runtime/shared_buffer.h | 4 +- .../lite/delegates/gpu/gl/serialization.cc | 54 +- .../lite/delegates/gpu/gl/serialization.h | 16 +- .../delegates/gpu/gl/serialization_test.cc | 15 +- tensorflow/lite/delegates/gpu/gl_delegate.cc | 42 +- tensorflow/lite/delegates/gpu/metal/api.cc | 39 +- tensorflow/lite/delegates/gpu/metal/api.h | 4 +- tensorflow/lite/delegates/gpu/metal/common.h | 7 +- tensorflow/lite/delegates/gpu/metal/common.mm | 14 +- .../lite/delegates/gpu/metal/common_test.mm | 8 +- .../delegates/gpu/metal/compiled_model.cc | 12 +- .../lite/delegates/gpu/metal/compiled_model.h | 7 +- .../gpu/metal/compiled_model_test.mm | 22 +- .../lite/delegates/gpu/metal/compute_task.h | 19 +- .../lite/delegates/gpu/metal/compute_task.mm | 39 +- .../delegates/gpu/metal/inference_context.h | 14 +- .../delegates/gpu/metal/inference_context.mm | 27 +- .../gpu/metal/inference_context_test.mm | 18 +- .../delegates/gpu/metal/kernels/add_test.mm | 13 +- .../gpu/metal/kernels/concat_test.mm | 17 +- .../delegates/gpu/metal/kernels/conv_test.mm | 21 +- .../gpu/metal/kernels/custom_registry.cc | 12 +- .../gpu/metal/kernels/custom_registry.h | 10 +- .../gpu/metal/kernels/depthwise_conv_test.mm | 13 +- .../gpu/metal/kernels/elementwise_test.mm | 77 +- .../gpu/metal/kernels/fully_connected_test.mm | 5 +- .../gpu/metal/kernels/max_unpooling_test.mm | 5 +- .../delegates/gpu/metal/kernels/mean_test.mm | 5 +- .../delegates/gpu/metal/kernels/mul_test.mm | 17 +- .../gpu/metal/kernels/padding_test.mm | 13 +- .../gpu/metal/kernels/pooling_test.mm | 15 +- .../delegates/gpu/metal/kernels/prelu_test.mm | 17 +- .../delegates/gpu/metal/kernels/relu_test.mm | 17 +- .../gpu/metal/kernels/reshape_test.mm | 16 +- .../gpu/metal/kernels/resize_test.mm | 25 +- .../delegates/gpu/metal/kernels/slice_test.mm | 25 +- .../gpu/metal/kernels/softmax_test.mm | 13 +- 
.../gpu/metal/kernels/space_to_depth_test.mm | 8 +- .../delegates/gpu/metal/kernels/test_util.h | 14 +- .../delegates/gpu/metal/kernels/test_util.mm | 22 +- .../gpu/metal/kernels/transpose_conv_test.mm | 25 +- .../lite/delegates/gpu/metal_delegate.mm | 38 +- tensorflow/lite/delegates/gpu/spi.h | 12 +- 286 files changed, 3926 insertions(+), 3881 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 72af2534988..b5fff1d84d5 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -12,12 +12,6 @@ exports_files([ "metal_delegate.h", ]) -# Primary purpose of this config is to replace ::util::Status with our custom -# light implementation ::tflite::gpu::StatusLite to reduce binary size. Besides -# that, certain features that were hard to communicate without full open source -# were hidden away too such as compiled models, serialization, and metadata. -# While the latter will be fully available with the open source release, the -# former will have to stay until absl::Status is released. config_setting( name = "tflite_gpu_binary_release", values = {"copt": "-DTFLITE_GPU_BINARY_RELEASE"}, diff --git a/tensorflow/lite/delegates/gpu/api.h b/tensorflow/lite/delegates/gpu/api.h index 803983214e2..921f2d54006 100644 --- a/tensorflow/lite/delegates/gpu/api.h +++ b/tensorflow/lite/delegates/gpu/api.h @@ -220,7 +220,8 @@ class InferenceBuilder { // Sets new shape for the input if underlying implementation and graph // structure allows dynamic tensors. - virtual Status SetInputShape(int index, const Dimensions& dimensions) = 0; + virtual absl::Status SetInputShape(int index, + const Dimensions& dimensions) = 0; // Updates object definitions for the given index. Implementation may allow // to use different layouts and/or data type conversions between objects @@ -229,21 +230,21 @@ class InferenceBuilder { // A user, however, has an input in DataType::FLOAT16, DataLayout::PHWC4. // An implementation may allow this transformation to happen automatically // under the hood. - virtual Status SetInputObjectDef(int index, ObjectDef def) = 0; - virtual Status SetOutputObjectDef(int index, ObjectDef def) = 0; - virtual Status SetAllInputObjectDefsTo(ObjectDef def) { + virtual absl::Status SetInputObjectDef(int index, ObjectDef def) = 0; + virtual absl::Status SetOutputObjectDef(int index, ObjectDef def) = 0; + virtual absl::Status SetAllInputObjectDefsTo(ObjectDef def) { auto input_defs = inputs(); for (int i = 0; i < input_defs.size(); ++i) { RETURN_IF_ERROR(SetInputObjectDef(i, def)); } - return OkStatus(); + return absl::OkStatus(); } - virtual Status SetAllOutputObjectDefsTo(ObjectDef def) { + virtual absl::Status SetAllOutputObjectDefsTo(ObjectDef def) { auto output_defs = outputs(); for (int i = 0; i < output_defs.size(); ++i) { RETURN_IF_ERROR(SetOutputObjectDef(i, def)); } - return OkStatus(); + return absl::OkStatus(); } // Creates new instance of the inference runner. InferenceBuilder stays valid @@ -251,7 +252,7 @@ class InferenceBuilder { // // This method may take significant time to prepare new inference runner. For // example, it may require to compile OpenGL shaders. - virtual Status Build(std::unique_ptr* runner) = 0; + virtual absl::Status Build(std::unique_ptr* runner) = 0; }; // Runs prepared inference. Every object marked as external needs to be set @@ -268,12 +269,12 @@ class InferenceRunner { // Setters allow to set or change external object for the given index. 
Note, // object need to match object definition set before in InferenceBuilder. - virtual Status GetInputObject(int index, TensorObject* object) = 0; - virtual Status GetOutputObject(int index, TensorObject* object) = 0; - virtual Status SetInputObject(int index, TensorObject object) = 0; - virtual Status SetOutputObject(int index, TensorObject object) = 0; + virtual absl::Status GetInputObject(int index, TensorObject* object) = 0; + virtual absl::Status GetOutputObject(int index, TensorObject* object) = 0; + virtual absl::Status SetInputObject(int index, TensorObject object) = 0; + virtual absl::Status SetOutputObject(int index, TensorObject object) = 0; - virtual Status Run() = 0; + virtual absl::Status Run() = 0; }; // Encapsulated compilation/runtime tradeoffs. diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index 4e85f92c6de..a6488c51ce4 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -54,22 +54,22 @@ class NoopTensorTie : public TensorTie { return def.external_def == def.internal_def; } - Status SetExternalObject(TensorObject obj) final { + absl::Status SetExternalObject(TensorObject obj) final { if (!def().external_def.object_def.user_provided) { - return InvalidArgumentError("Tensor object is readonly."); + return absl::InvalidArgumentError("Tensor object is readonly."); } if (!IsValid(def().external_def, obj)) { - return InvalidArgumentError("Given object is not valid"); + return absl::InvalidArgumentError("Given object is not valid"); } obj_ = obj; - return OkStatus(); + return absl::OkStatus(); } TensorObject GetExternalObject() final { return obj_; } - Status CopyToExternalObject() final { return OkStatus(); } + absl::Status CopyToExternalObject() final { return absl::OkStatus(); } - Status CopyFromExternalObject() final { return OkStatus(); } + absl::Status CopyFromExternalObject() final { return absl::OkStatus(); } private: TensorObject obj_; @@ -93,45 +93,45 @@ class DefaultTensorTie : public TensorTie { converter_builder.IsSupported(def.external_def, def.internal_def); } - static Status New(const TensorTieDef& def, TensorObject internal_object, - TensorObjectConverterBuilder* converter_builder, - Environment* env, std::unique_ptr* tie) { + static absl::Status New(const TensorTieDef& def, TensorObject internal_object, + TensorObjectConverterBuilder* converter_builder, + Environment* env, std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def, internal_object); RETURN_IF_ERROR(tie_impl->Init(converter_builder, env)); *tie = std::move(tie_impl); - return OkStatus(); + return absl::OkStatus(); } - Status CopyToExternalObject() final { + absl::Status CopyToExternalObject() final { if (!converter_to_) { - return UnavailableError("Conversion is not available"); + return absl::UnavailableError("Conversion is not available"); } return converter_to_->Convert(internal_obj_, GetExternalObject()); } - Status CopyFromExternalObject() final { + absl::Status CopyFromExternalObject() final { if (!converter_from_) { - return UnavailableError("Conversion is not available"); + return absl::UnavailableError("Conversion is not available"); } return converter_from_->Convert(GetExternalObject(), internal_obj_); } - Status SetExternalObject(TensorObject obj) final { + absl::Status SetExternalObject(TensorObject obj) final { if (!def().external_def.object_def.user_provided) { - return InvalidArgumentError("External object is read-only"); + return absl::InvalidArgumentError("External object is 
read-only"); } if (!IsValid(def().external_def, obj)) { - return InvalidArgumentError("Given object is not valid"); + return absl::InvalidArgumentError("Given object is not valid"); } external_obj_ = obj; - return OkStatus(); + return absl::OkStatus(); } TensorObject GetExternalObject() final { return external_obj_; } private: - Status Init(TensorObjectConverterBuilder* converter_builder, - Environment* env) { + absl::Status Init(TensorObjectConverterBuilder* converter_builder, + Environment* env) { RETURN_IF_ERROR(converter_builder->MakeConverter( def().internal_def, def().external_def, &converter_to_)); RETURN_IF_ERROR(converter_builder->MakeConverter( @@ -139,10 +139,10 @@ class DefaultTensorTie : public TensorTie { return MaybeAllocateExternalObject(env); } - Status MaybeAllocateExternalObject(Environment* env) { + absl::Status MaybeAllocateExternalObject(Environment* env) { const TensorObjectDef& d = def().external_def; if (d.object_def.user_provided) { - return OkStatus(); + return absl::OkStatus(); } switch (d.object_def.object_type) { case ObjectType::CPU_MEMORY: { @@ -170,9 +170,9 @@ class DefaultTensorTie : public TensorTie { break; } default: - return InternalError("Unexpected object type"); + return absl::InternalError("Unexpected object type"); } - return OkStatus(); + return absl::OkStatus(); } const TensorObject internal_obj_; @@ -198,26 +198,26 @@ class TwoStepTensorTie : public TensorTie { DefaultTensorTie::IsSupported(defs.second, converter_builder); } - static Status New(const TensorTieDef& def, TensorObject internal_object, - TensorObjectConverterBuilder* converter_builder, - Environment* env, std::unique_ptr* tie) { + static absl::Status New(const TensorTieDef& def, TensorObject internal_object, + TensorObjectConverterBuilder* converter_builder, + Environment* env, std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def); RETURN_IF_ERROR(tie_impl->Init(internal_object, converter_builder, env)); *tie = std::move(tie_impl); - return OkStatus(); + return absl::OkStatus(); } - Status CopyToExternalObject() final { + absl::Status CopyToExternalObject() final { RETURN_IF_ERROR(inner_tie_->CopyToExternalObject()); return outer_tie_->CopyToExternalObject(); } - Status CopyFromExternalObject() final { + absl::Status CopyFromExternalObject() final { RETURN_IF_ERROR(outer_tie_->CopyFromExternalObject()); return inner_tie_->CopyFromExternalObject(); } - Status SetExternalObject(TensorObject obj) final { + absl::Status SetExternalObject(TensorObject obj) final { return outer_tie_->SetExternalObject(obj); } @@ -241,9 +241,9 @@ class TwoStepTensorTie : public TensorTie { return std::make_pair(outer_def, inner_def); } - Status Init(TensorObject internal_object, - TensorObjectConverterBuilder* converter_builder, - Environment* env) { + absl::Status Init(TensorObject internal_object, + TensorObjectConverterBuilder* converter_builder, + Environment* env) { auto defs = MakeOuterInnerDefs(def()); RETURN_IF_ERROR(DefaultTensorTie::New(defs.second, internal_object, converter_builder, env, &inner_tie_)); @@ -274,27 +274,27 @@ class GlBufferHolder : public TensorTie { return DefaultTensorTie::IsSupported(MakeClDef(def), converter_builder); } - static Status New(const TensorTieDef& def, TensorObject internal_object, - TensorObjectConverterBuilder* converter_builder, - GlInteropFabric* gl_interop_fabric, Environment* env, - std::unique_ptr* tie) { + static absl::Status New(const TensorTieDef& def, TensorObject internal_object, + TensorObjectConverterBuilder* converter_builder, + 
GlInteropFabric* gl_interop_fabric, Environment* env, + std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def, gl_interop_fabric, env); RETURN_IF_ERROR(DefaultTensorTie::New(MakeClDef(def), internal_object, converter_builder, env, &tie_impl->tie_)); *tie = std::move(tie_impl); - return OkStatus(); + return absl::OkStatus(); } - Status SetExternalObject(TensorObject obj) final { + absl::Status SetExternalObject(TensorObject obj) final { auto ssbo = absl::get_if(&obj); if (!ssbo) { - return InvalidArgumentError("Missing OpenGL SSBO"); + return absl::InvalidArgumentError("Missing OpenGL SSBO"); } auto old_ssbo = absl::get_if(&external_obj_); if (old_ssbo && ssbo->id == old_ssbo->id) { - return OkStatus(); + return absl::OkStatus(); } if (cl_object_.memory()) { gl_interop_fabric_->UnregisterMemory(cl_object_.memory()); @@ -304,16 +304,18 @@ class GlBufferHolder : public TensorTie { external_obj_ = obj; RETURN_IF_ERROR(tie_->SetExternalObject(OpenClBuffer{cl_object_.memory()})); gl_interop_fabric_->RegisterMemory(cl_object_.memory()); - return OkStatus(); + return absl::OkStatus(); } TensorObject GetExternalObject() final { return external_obj_; } - Status CopyFromExternalObject() final { + absl::Status CopyFromExternalObject() final { return tie_->CopyFromExternalObject(); } - Status CopyToExternalObject() final { return tie_->CopyToExternalObject(); } + absl::Status CopyToExternalObject() final { + return tie_->CopyToExternalObject(); + } private: static TensorTieDef MakeClDef(const TensorTieDef& def) { @@ -358,20 +360,20 @@ class TensorTieFactory { TwoStepTensorTie::IsSupported(def, *converter_builder_)); } - Status NewTensorTie(const TensorTieDef& def, - std::unique_ptr* tie) { + absl::Status NewTensorTie(const TensorTieDef& def, + std::unique_ptr* tie) { TensorObject internal_object = TensorToObj(*context_.GetTensor(def.id)); auto converter = converter_builder_.get(); if (NoopTensorTie::IsSupported(def)) { *tie = absl::make_unique(def, internal_object); - return OkStatus(); + return absl::OkStatus(); } if (DefaultTensorTie::IsSupported(def, *converter)) { return DefaultTensorTie::New(def, internal_object, converter, &env_, tie); } if (GlBufferHolder::IsSupported(def, *converter)) { if (!gl_interop_fabric_) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "GL object is used but InferenceEnvironmentOptions does not have " "EGL display and context set."); } @@ -381,7 +383,7 @@ class TensorTieFactory { if (TwoStepTensorTie::IsSupported(def, *converter)) { return TwoStepTensorTie::New(def, internal_object, converter, &env_, tie); } - return UnimplementedError("Unsupported tensor tie definition."); + return absl::UnimplementedError("Unsupported tensor tie definition."); } private: @@ -400,9 +402,9 @@ class InferenceRunnerImpl : public InferenceRunner { context_(std::move(context)), gl_interop_fabric_(std::move(gl_interop_fabric)) {} - Status Initialize(const std::vector& inputs, - const std::vector& outputs, - TensorTieFactory* factory) { + absl::Status Initialize(const std::vector& inputs, + const std::vector& outputs, + TensorTieFactory* factory) { RETURN_IF_ERROR(LinkTensors(inputs, factory, &inputs_)); return LinkTensors(outputs, factory, &outputs_); } @@ -415,37 +417,37 @@ class InferenceRunnerImpl : public InferenceRunner { return GetExternalDefinitions(outputs_); } - Status GetInputObject(int index, TensorObject* object) override { + absl::Status GetInputObject(int index, TensorObject* object) override { if (index < 0 || index >= inputs_.size()) { - return 
OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } *object = inputs_[index]->GetExternalObject(); - return OkStatus(); + return absl::OkStatus(); } - Status GetOutputObject(int index, TensorObject* object) override { + absl::Status GetOutputObject(int index, TensorObject* object) override { if (index < 0 || index >= outputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } *object = outputs_[index]->GetExternalObject(); - return OkStatus(); + return absl::OkStatus(); } - Status SetInputObject(int index, TensorObject object) override { + absl::Status SetInputObject(int index, TensorObject object) override { if (index < 0 || index >= inputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } return inputs_[index]->SetExternalObject(object); } - Status SetOutputObject(int index, TensorObject object) override { + absl::Status SetOutputObject(int index, TensorObject object) override { if (index < 0 || index >= outputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } return outputs_[index]->SetExternalObject(object); } - Status Run() override { + absl::Status Run() override { if (gl_interop_fabric_) { RETURN_IF_ERROR(gl_interop_fabric_->Start()); } @@ -460,20 +462,20 @@ class InferenceRunnerImpl : public InferenceRunner { if (gl_interop_fabric_) { RETURN_IF_ERROR(gl_interop_fabric_->Finish()); } - return OkStatus(); + return absl::OkStatus(); } private: - static Status LinkTensors(const std::vector& defs, - TensorTieFactory* factory, - std::vector>* objects) { + static absl::Status LinkTensors( + const std::vector& defs, TensorTieFactory* factory, + std::vector>* objects) { objects->reserve(defs.size()); for (auto& def : defs) { std::unique_ptr object; RETURN_IF_ERROR(factory->NewTensorTie(def, &object)); objects->push_back(std::move(object)); } - return OkStatus(); + return absl::OkStatus(); } static std::vector GetExternalDefinitions( @@ -511,9 +513,9 @@ class InferenceBuilderImpl : public InferenceBuilder { explicit InferenceBuilderImpl(Environment* environment) : environment_(environment) {} - Status Initialize(const InferenceOptions& options, - const InferenceEnvironmentOptions& env_options, - const GraphFloat32& graph) { + absl::Status Initialize(const InferenceOptions& options, + const InferenceEnvironmentOptions& env_options, + const GraphFloat32& graph) { context_ = absl::make_unique(); InferenceContext::CreateInferenceInfo create_info; create_info.precision = GetPrecision(options); @@ -533,7 +535,7 @@ class InferenceBuilderImpl : public InferenceBuilder { inputs_ = LinkTensors(graph, graph.inputs()); outputs_ = LinkTensors(graph, graph.outputs()); - return OkStatus(); + return absl::OkStatus(); } std::vector inputs() const override { @@ -544,40 +546,42 @@ class InferenceBuilderImpl : public InferenceBuilder { return GetExternalDefinitions(outputs_); } - Status SetInputShape(int index, const Dimensions& dimensions) override { + absl::Status SetInputShape(int index, const Dimensions& dimensions) override { if (index < 0 || index >= inputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } - return UnimplementedError("Changing input shapes is not supported"); + return absl::UnimplementedError("Changing input shapes is not supported"); } - Status 
SetInputObjectDef(int index, ObjectDef new_def) override { + absl::Status SetInputObjectDef(int index, ObjectDef new_def) override { if (index < 0 || index >= inputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } auto def = inputs_[index]; def.external_def.object_def = new_def; if (!tie_factory_->IsSupported(def)) { - return InvalidArgumentError("New object definition is not supported."); + return absl::InvalidArgumentError( + "New object definition is not supported."); } inputs_[index] = def; - return OkStatus(); + return absl::OkStatus(); } - Status SetOutputObjectDef(int index, ObjectDef new_def) override { + absl::Status SetOutputObjectDef(int index, ObjectDef new_def) override { if (index < 0 || index >= outputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } auto def = outputs_[index]; def.external_def.object_def = new_def; if (!tie_factory_->IsSupported(def)) { - return InvalidArgumentError("New object definition is not supported."); + return absl::InvalidArgumentError( + "New object definition is not supported."); } outputs_[index] = def; - return OkStatus(); + return absl::OkStatus(); } - Status Build(std::unique_ptr* runner) override { + absl::Status Build(std::unique_ptr* runner) override { if (gl_interop_fabric_ && !HasGlObjects()) { // destroy interop layer when there are no GL objects to avoid // extra synchronization cost. @@ -588,7 +592,7 @@ class InferenceBuilderImpl : public InferenceBuilder { RETURN_IF_ERROR( runner_impl->Initialize(inputs_, outputs_, tie_factory_.get())); *runner = std::move(runner_impl); - return OkStatus(); + return absl::OkStatus(); } private: @@ -696,7 +700,7 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { explicit InferenceEnvironmentImpl(const InferenceEnvironmentOptions& options) : options_(options) {} - Status Init() { + absl::Status Init() { RETURN_IF_ERROR(LoadOpenCL()); properties_.is_opencl_available = true; @@ -716,13 +720,13 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { properties_.is_cl_to_gl_fast_sync_supported = IsEglSyncFromClEventSupported(); if (options_.IsGlAware() && !properties_.is_gl_sharing_supported) { - return UnavailableError("GL sharing is not supported"); + return absl::UnavailableError("GL sharing is not supported"); } CLContext context; if (options_.context) { if (options_.IsGlAware()) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "OpenCL context and EGL parameters are set in the same time."); } context = CLContext(options_.context, /* has_ownership = */ false); @@ -754,11 +758,11 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { return environment_.Init(); } - Status NewInferenceBuilder(const InferenceOptions& options, - GraphFloat32 model, - std::unique_ptr* builder) final { + absl::Status NewInferenceBuilder( + const InferenceOptions& options, GraphFloat32 model, + std::unique_ptr* builder) final { if (!IsValid(options)) { - return InvalidArgumentError("InferenceOptions are invalid."); + return absl::InvalidArgumentError("InferenceOptions are invalid."); } InferenceOptions resolved_options = options; ResolveAutoPriority(&resolved_options); @@ -776,7 +780,7 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { RETURN_IF_ERROR( builder_impl->Initialize(resolved_options, options_, model)); *builder = std::move(builder_impl); - return OkStatus(); + return absl::OkStatus(); } std::vector 
GetSerializedBinaryCache() const final { @@ -800,18 +804,18 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { } // namespace -Status NewInferenceEnvironment( +absl::Status NewInferenceEnvironment( const InferenceEnvironmentOptions& options, std::unique_ptr* environment, InferenceEnvironmentProperties* properties) { auto env_impl = absl::make_unique(options); - Status status = env_impl->Init(); + absl::Status status = env_impl->Init(); if (properties) { *properties = env_impl->properties(); } RETURN_IF_ERROR(status); *environment = std::move(env_impl); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/api.h b/tensorflow/lite/delegates/gpu/cl/api.h index 2ac5ce2e28b..9d3f9f7214c 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.h +++ b/tensorflow/lite/delegates/gpu/cl/api.h @@ -70,7 +70,7 @@ class InferenceEnvironment { public: virtual ~InferenceEnvironment() {} - virtual Status NewInferenceBuilder( + virtual absl::Status NewInferenceBuilder( const InferenceOptions& options, GraphFloat32 model, std::unique_ptr* builder) = 0; @@ -112,7 +112,7 @@ struct InferenceEnvironmentOptions { // Creates new OpenCL environment that needs to stay around until all inference // runners are destroyed. -Status NewInferenceEnvironment( +absl::Status NewInferenceEnvironment( const InferenceEnvironmentOptions& options, std::unique_ptr* environment, InferenceEnvironmentProperties* properties /* optional */); diff --git a/tensorflow/lite/delegates/gpu/cl/buffer.cc b/tensorflow/lite/delegates/gpu/cl/buffer.cc index 51d9a59e888..207cdec5122 100644 --- a/tensorflow/lite/delegates/gpu/cl/buffer.cc +++ b/tensorflow/lite/delegates/gpu/cl/buffer.cc @@ -21,8 +21,10 @@ namespace tflite { namespace gpu { namespace cl { namespace { -Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, const void* data, - CLContext* context, Buffer* result) { + +absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, + const void* data, CLContext* context, + Buffer* result) { cl_mem_flags flags = gpu_read_only ? 
CL_MEM_READ_ONLY : CL_MEM_READ_WRITE; if (data != nullptr) { flags |= CL_MEM_COPY_HOST_PTR; @@ -31,14 +33,14 @@ Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, const void* data, cl_mem buffer = clCreateBuffer(context->context(), flags, size_in_bytes, const_cast(data), &error_code); if (!buffer) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to allocate device memory with clCreateBuffer", CLErrorCodeToString(error_code))); } *result = Buffer(buffer, size_in_bytes); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -69,18 +71,18 @@ void Buffer::Release() { } } -Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context, - Buffer* result) { +absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context, + Buffer* result) { return CreateBuffer(size_in_bytes, true, nullptr, context, result); } -Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, - CLContext* context, Buffer* result) { +absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, + CLContext* context, Buffer* result) { return CreateBuffer(size_in_bytes, true, data, context, result); } -Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context, - Buffer* result) { +absl::Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context, + Buffer* result) { return CreateBuffer(size_in_bytes, false, nullptr, context, result); } diff --git a/tensorflow/lite/delegates/gpu/cl/buffer.h b/tensorflow/lite/delegates/gpu/cl/buffer.h index 4282d9c0898..84c3292084b 100644 --- a/tensorflow/lite/delegates/gpu/cl/buffer.h +++ b/tensorflow/lite/delegates/gpu/cl/buffer.h @@ -51,11 +51,11 @@ class Buffer { // Writes data to a buffer. Data should point to a region that // has exact size in bytes as size_in_bytes(constructor parameter). template - Status WriteData(CLCommandQueue* queue, const absl::Span data); + absl::Status WriteData(CLCommandQueue* queue, const absl::Span data); // Reads data from Buffer into CPU memory. 
template - Status ReadData(CLCommandQueue* queue, std::vector* result) const; + absl::Status ReadData(CLCommandQueue* queue, std::vector* result) const; private: void Release(); @@ -64,29 +64,31 @@ class Buffer { size_t size_; }; -Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context, - Buffer* result); +absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context, + Buffer* result); -Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, - CLContext* context, Buffer* result); +absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, + CLContext* context, Buffer* result); -Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context, - Buffer* result); +absl::Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context, + Buffer* result); template -Status Buffer::WriteData(CLCommandQueue* queue, const absl::Span data) { +absl::Status Buffer::WriteData(CLCommandQueue* queue, + const absl::Span data) { if (size_ != sizeof(T) * data.size()) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "absl::Span data size is different from buffer allocated size."); } RETURN_IF_ERROR(queue->EnqueueWriteBuffer(buffer_, size_, data.data())); - return OkStatus(); + return absl::OkStatus(); } template -Status Buffer::ReadData(CLCommandQueue* queue, std::vector* result) const { +absl::Status Buffer::ReadData(CLCommandQueue* queue, + std::vector* result) const { if (size_ % sizeof(T) != 0) { - return UnknownError("Wrong element size(typename T is not correct?"); + return absl::UnknownError("Wrong element size(typename T is not correct?"); } const int elements_count = size_ / sizeof(T); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc index 328cdaf0a6e..7b74840c5e6 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc @@ -56,8 +56,9 @@ void CLCommandQueue::Release() { } } -Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size, CLEvent* event) { +absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size, + CLEvent* event) { std::vector local(3); std::vector global(3); for (int i = 0; i < 3; ++i) { @@ -72,30 +73,31 @@ Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, *event = CLEvent(resulting_event); } if (error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat("Failed to clEnqueueNDRangeKernel - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError( + absl::StrCat("Failed to clEnqueueNDRangeKernel - ", + CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } -Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size) { +absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size) { return DispatchImplicit(kernel, grid, work_group_size, nullptr); } -Status CLCommandQueue::EnqueueEvent(CLEvent* event) { +absl::Status CLCommandQueue::EnqueueEvent(CLEvent* event) { cl_event resulting_event; const int error_code = clEnqueueMarker(queue_, &resulting_event); *event = CLEvent(resulting_event); if (error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat("Failed to clEnqueueMarker - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError(absl::StrCat("Failed to clEnqueueMarker - ", + CLErrorCodeToString(error_code))); } 
- return OkStatus(); + return absl::OkStatus(); } -Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region, - const void* data) { +absl::Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region, + const void* data) { const size_t origin[] = {0, 0, 0}; const size_t r[] = {static_cast(region.x), static_cast(region.y), @@ -103,16 +105,16 @@ Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region, auto error_code = clEnqueueWriteImage(queue_, memory, CL_TRUE, origin, r, 0, 0, data, 0, nullptr, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to upload data to GPU (clEnqueueWriteImage) - ", CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } -Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region, - void* data) { +absl::Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region, + void* data) { const size_t origin[] = {0, 0, 0}; const size_t r[] = {static_cast(region.x), static_cast(region.y), @@ -120,45 +122,47 @@ Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region, auto error_code = clEnqueueReadImage(queue_, memory, CL_TRUE, origin, r, 0, 0, data, 0, nullptr, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to read data from GPU (clEnqueueReadImage) - ", CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } -Status CLCommandQueue::EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, - const void* data) { +absl::Status CLCommandQueue::EnqueueWriteBuffer(cl_mem memory, + size_t size_in_bytes, + const void* data) { auto error_code = clEnqueueWriteBuffer( queue_, memory, CL_TRUE, 0, size_in_bytes, data, 0, nullptr, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to upload data to GPU (clEnqueueWriteBuffer) - ", CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } -Status CLCommandQueue::EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, - void* data) { +absl::Status CLCommandQueue::EnqueueReadBuffer(cl_mem memory, + size_t size_in_bytes, + void* data) { auto error_code = clEnqueueReadBuffer( queue_, memory, CL_TRUE, 0, size_in_bytes, data, 0, nullptr, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to read data from GPU (clEnqueueReadBuffer) - ", CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } -Status CLCommandQueue::WaitForCompletion() { +absl::Status CLCommandQueue::WaitForCompletion() { auto error_code = clFinish(queue_); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to clFinish - ", CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } ProfilingCommandQueue::ProfilingCommandQueue(cl_command_queue queue) @@ -187,14 +191,14 @@ void ProfilingCommandQueue::SetEventsLabel(const std::string& name) { void ProfilingCommandQueue::ResetMeasurements() { events_.clear(); } -Status ProfilingCommandQueue::DispatchImplicit(const CLKernel& kernel, - int3 grid, - int3 work_group_size) { +absl::Status ProfilingCommandQueue::DispatchImplicit(const CLKernel& kernel, + int3 grid, + int3 work_group_size) { events_.push_back(CLEvent()); RETURN_IF_ERROR(CLCommandQueue::DispatchImplicit( kernel, grid, work_group_size, &events_[events_.size() - 1])); 
events_.back().SetName(current_label_); - return OkStatus(); + return absl::OkStatus(); } ProfilingInfo ProfilingCommandQueue::GetProfilingInfo() const { @@ -208,7 +212,7 @@ ProfilingInfo ProfilingCommandQueue::GetProfilingInfo() const { return result; } -Status ProfilingCommandQueue::GetBestWorkGroupIndex( +absl::Status ProfilingCommandQueue::GetBestWorkGroupIndex( const CLKernel& kernel, const DeviceInfo& device_info, const int3& grid, const std::vector& work_group_sizes, int* index) { // Some Adreno 3xx can have wrong numbers for some events @@ -268,20 +272,22 @@ Status ProfilingCommandQueue::GetBestWorkGroupIndex( *index = minimum_index; - return OkStatus(); + return absl::OkStatus(); } -Status CreateCLCommandQueue(const CLDevice& device, const CLContext& context, - CLCommandQueue* result) { +absl::Status CreateCLCommandQueue(const CLDevice& device, + const CLContext& context, + CLCommandQueue* result) { int error_code; cl_command_queue queue = clCreateCommandQueue(context.context(), device.id(), 0, &error_code); if (!queue) { - return UnknownError(absl::StrCat("Failed to create a command queue - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError( + absl::StrCat("Failed to create a command queue - ", + CLErrorCodeToString(error_code))); } *result = CLCommandQueue(queue, true); - return OkStatus(); + return absl::OkStatus(); } double ProfilingCommandQueue::GetQueueExecutionTimeMs() const { @@ -300,19 +306,20 @@ double ProfilingCommandQueue::GetSumOfEventsTimeMs() const { return sum; } -Status CreateProfilingCommandQueue(const CLDevice& device, - const CLContext& context, - ProfilingCommandQueue* result) { +absl::Status CreateProfilingCommandQueue(const CLDevice& device, + const CLContext& context, + ProfilingCommandQueue* result) { int error_code; cl_command_queue queue = clCreateCommandQueue( context.context(), device.id(), CL_QUEUE_PROFILING_ENABLE, &error_code); if (!queue) { - return UnknownError(absl::StrCat("Failed to create a command queue - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError( + absl::StrCat("Failed to create a command queue - ", + CLErrorCodeToString(error_code))); } *result = ProfilingCommandQueue(queue); - return OkStatus(); + return absl::OkStatus(); } absl::Duration ProfilingInfo::GetTotalTime() const { diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h index 84ffeca67eb..178e3b21a1e 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h @@ -74,22 +74,23 @@ class CLCommandQueue { cl_command_queue queue() const { return queue_; } - virtual Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size); + virtual absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size); - Status EnqueueEvent(CLEvent* event); + absl::Status EnqueueEvent(CLEvent* event); - Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size, CLEvent* event); + absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size, CLEvent* event); - Status EnqueueWriteImage(cl_mem memory, int3 region, const void* data); - Status EnqueueReadImage(cl_mem memory, int3 region, void* data); + absl::Status EnqueueWriteImage(cl_mem memory, int3 region, const void* data); + absl::Status EnqueueReadImage(cl_mem memory, int3 region, void* data); - Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, - const void* data); - Status 
EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, void* data); + absl::Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, + const void* data); + absl::Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, + void* data); - Status WaitForCompletion(); + absl::Status WaitForCompletion(); protected: void Release(); @@ -109,14 +110,15 @@ class ProfilingCommandQueue : public CLCommandQueue { ProfilingCommandQueue(const ProfilingCommandQueue&) = delete; ProfilingCommandQueue& operator=(const ProfilingCommandQueue&) = delete; - Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size) override; + absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size) override; // will write index for fastest work_group among work_group_sizes - Status GetBestWorkGroupIndex(const CLKernel& kernel, - const DeviceInfo& device_info, const int3& grid, - const std::vector& work_group_sizes, - int* index); + absl::Status GetBestWorkGroupIndex(const CLKernel& kernel, + const DeviceInfo& device_info, + const int3& grid, + const std::vector& work_group_sizes, + int* index); // call ResetMeasurements() to start new seriese of measurements void ResetMeasurements(); @@ -139,12 +141,13 @@ class ProfilingCommandQueue : public CLCommandQueue { std::string current_label_; }; -Status CreateCLCommandQueue(const CLDevice& device, const CLContext& context, - CLCommandQueue* result); +absl::Status CreateCLCommandQueue(const CLDevice& device, + const CLContext& context, + CLCommandQueue* result); -Status CreateProfilingCommandQueue(const CLDevice& device, - const CLContext& context, - ProfilingCommandQueue* result); +absl::Status CreateProfilingCommandQueue(const CLDevice& device, + const CLContext& context, + ProfilingCommandQueue* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/cl_context.cc b/tensorflow/lite/delegates/gpu/cl/cl_context.cc index e9e0ddf724b..e697c78b692 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_context.cc @@ -43,19 +43,21 @@ std::vector GetSupportedImage2DFormats(cl_context context, return result; } -Status CreateCLContext(const CLDevice& device, - cl_context_properties* properties, CLContext* result) { +absl::Status CreateCLContext(const CLDevice& device, + cl_context_properties* properties, + CLContext* result) { int error_code; cl_device_id device_id = device.id(); cl_context context = clCreateContext(properties, 1, &device_id, nullptr, nullptr, &error_code); if (!context) { - return UnknownError(absl::StrCat("Failed to create a compute context - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError( + absl::StrCat("Failed to create a compute context - ", + CLErrorCodeToString(error_code))); } *result = CLContext(context, true); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -99,15 +101,16 @@ bool CLContext::IsFloatTexture2DSupported(int num_channels, DataType data_type, return false; } -Status CreateCLContext(const CLDevice& device, CLContext* result) { +absl::Status CreateCLContext(const CLDevice& device, CLContext* result) { return CreateCLContext(device, nullptr, result); } -Status CreateCLGLContext(const CLDevice& device, - cl_context_properties egl_context, - cl_context_properties egl_display, CLContext* result) { +absl::Status CreateCLGLContext(const CLDevice& device, + cl_context_properties egl_context, + cl_context_properties egl_display, + CLContext* result) { if 
(!device.SupportsExtension("cl_khr_gl_sharing")) { - return UnavailableError("Device doesn't support CL-GL sharing."); + return absl::UnavailableError("Device doesn't support CL-GL sharing."); } cl_context_properties platform = reinterpret_cast(device.platform()); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_context.h b/tensorflow/lite/delegates/gpu/cl/cl_context.h index 20ec35f2b60..11922bd3678 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_context.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_context.h @@ -51,10 +51,11 @@ class CLContext { bool has_ownership_ = false; }; -Status CreateCLContext(const CLDevice& device, CLContext* result); -Status CreateCLGLContext(const CLDevice& device, - cl_context_properties egl_context, - cl_context_properties egl_display, CLContext* result); +absl::Status CreateCLContext(const CLDevice& device, CLContext* result); +absl::Status CreateCLGLContext(const CLDevice& device, + cl_context_properties egl_context, + cl_context_properties egl_display, + CLContext* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.cc b/tensorflow/lite/delegates/gpu/cl/cl_device.cc index c47f86a2928..5380c9ee653 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_device.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_device.cc @@ -516,11 +516,11 @@ void CLDevice::DisableOneLayerTextureArray() { info_.adreno_info.support_one_layer_texture_array = false; } -Status CreateDefaultGPUDevice(CLDevice* result) { +absl::Status CreateDefaultGPUDevice(CLDevice* result) { cl_uint num_platforms; clGetPlatformIDs(0, nullptr, &num_platforms); if (num_platforms == 0) { - return UnknownError("No supported OpenCL platform."); + return absl::UnknownError("No supported OpenCL platform."); } std::vector platforms(num_platforms); clGetPlatformIDs(num_platforms, platforms.data(), nullptr); @@ -529,7 +529,7 @@ Status CreateDefaultGPUDevice(CLDevice* result) { cl_uint num_devices; clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices); if (num_devices == 0) { - return UnknownError("No GPU on current platform."); + return absl::UnknownError("No GPU on current platform."); } std::vector devices(num_devices); @@ -537,7 +537,7 @@ Status CreateDefaultGPUDevice(CLDevice* result) { nullptr); *result = CLDevice(devices[0], platform_id); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.h b/tensorflow/lite/delegates/gpu/cl/cl_device.h index 7b3493e3faa..cbc95d485b9 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_device.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_device.h @@ -191,7 +191,7 @@ class CLDevice { DeviceInfo info_; }; -Status CreateDefaultGPUDevice(CLDevice* result); +absl::Status CreateDefaultGPUDevice(CLDevice* result); template T GetDeviceInfo(cl_device_id id, cl_device_info info) { @@ -204,12 +204,12 @@ T GetDeviceInfo(cl_device_id id, cl_device_info info) { } template -Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) { +absl::Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) { cl_int error = clGetDeviceInfo(id, info, sizeof(T), result, nullptr); if (error != CL_SUCCESS) { - return InvalidArgumentError(CLErrorCodeToString(error)); + return absl::InvalidArgumentError(CLErrorCodeToString(error)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/cl_errors.h b/tensorflow/lite/delegates/gpu/cl/cl_errors.h index 8c16b2696d7..fb59766bd18 100644 --- 
a/tensorflow/lite/delegates/gpu/cl/cl_errors.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_errors.h @@ -27,11 +27,12 @@ namespace cl { // @return if error_code is success, then return OK status. Otherwise translates // error code into a message. -inline Status GetOpenCLError(cl_int error_code) { +inline absl::Status GetOpenCLError(cl_int error_code) { if (error_code == CL_SUCCESS) { - return OkStatus(); + return absl::OkStatus(); } - return InternalError("OpenCL error: " + CLErrorCodeToString(error_code)); + return absl::InternalError("OpenCL error: " + + CLErrorCodeToString(error_code)); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc b/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc index 27d4d36c68a..04bf95d870a 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc @@ -25,34 +25,34 @@ namespace gpu { namespace cl { namespace { -Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id, - int* result) { +absl::Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id, + int* result) { size_t max_work_group_size; cl_int error_code = clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE, sizeof(size_t), &max_work_group_size, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to get info CL_KERNEL_WORK_GROUP_SIZE ", CLErrorCodeToString(error_code))); } *result = static_cast(max_work_group_size); - return OkStatus(); + return absl::OkStatus(); } -Status GetKernelPrivateMemorySize(cl_kernel kernel, cl_device_id device_id, - int* result) { +absl::Status GetKernelPrivateMemorySize(cl_kernel kernel, + cl_device_id device_id, int* result) { cl_ulong private_mem_size; cl_int error_code = clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_PRIVATE_MEM_SIZE, sizeof(cl_ulong), &private_mem_size, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to get info CL_KERNEL_PRIVATE_MEM_SIZE ", CLErrorCodeToString(error_code))); } *result = static_cast(private_mem_size); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -82,17 +82,17 @@ CLKernel& CLKernel::operator=(CLKernel&& kernel) { CLKernel::~CLKernel() { Release(); } -Status CLKernel::ReInit() const { +absl::Status CLKernel::ReInit() const { clReleaseKernel(kernel_); cl_kernel* kern_ptr = const_cast(&kernel_); int error_code; *kern_ptr = clCreateKernel(program_, function_name_.c_str(), &error_code); if (!kernel_ || error_code != CL_SUCCESS) { *kern_ptr = nullptr; - return UnknownError(absl::StrCat("Failed to create ", function_name_, - CLErrorCodeToString(error_code))); + return absl::UnknownError(absl::StrCat("Failed to create ", function_name_, + CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } void CLKernel::Release() { @@ -103,16 +103,16 @@ void CLKernel::Release() { } } -Status CLKernel::CreateFromProgram(const CLProgram& program, - const std::string& function_name) { +absl::Status CLKernel::CreateFromProgram(const CLProgram& program, + const std::string& function_name) { int error_code; function_name_ = function_name; kernel_ = clCreateKernel(program.program(), function_name.c_str(), &error_code); if (!kernel_ || error_code != CL_SUCCESS) { kernel_ = nullptr; - return UnknownError(absl::StrCat("Failed to create ", function_name, - CLErrorCodeToString(error_code))); + return absl::UnknownError(absl::StrCat("Failed to create ", 
function_name, + CLErrorCodeToString(error_code))); } program_ = program.program(); @@ -122,64 +122,64 @@ Status CLKernel::CreateFromProgram(const CLProgram& program, &private_memory_size_)); RETURN_IF_ERROR(GetKernelMaxWorkGroupSize(kernel_, program.GetDeviceId(), &max_work_group_size_)); - return OkStatus(); + return absl::OkStatus(); } -Status CLKernel::SetMemory(int index, cl_mem memory) { +absl::Status CLKernel::SetMemory(int index, cl_mem memory) { return SetBytes(index, &memory, sizeof(cl_mem)); } -Status CLKernel::SetMemoryAuto(cl_mem memory) { +absl::Status CLKernel::SetMemoryAuto(cl_mem memory) { return SetBytesAuto(&memory, sizeof(cl_mem)); } -Status CLKernel::SetBytes(int index, const void* ptr, int length) const { +absl::Status CLKernel::SetBytes(int index, const void* ptr, int length) const { const int error_code = clSetKernelArg(kernel_, index, length, ptr); if (error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat("Failed to set kernel arguments - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ", + CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } -Status CLKernel::SetBytesAuto(const void* ptr, int length) { +absl::Status CLKernel::SetBytesAuto(const void* ptr, int length) { const int error_code = clSetKernelArg(kernel_, binding_counter_, length, ptr); if (error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat("Failed to set kernel arguments - ", - CLErrorCodeToString(error_code), - "(at index - ", binding_counter_, ")")); + return absl::UnknownError(absl::StrCat( + "Failed to set kernel arguments - ", CLErrorCodeToString(error_code), + "(at index - ", binding_counter_, ")")); } binding_counter_++; - return OkStatus(); + return absl::OkStatus(); } template <> -Status CLKernel::SetBytes(int index, const FLT& value) const { +absl::Status CLKernel::SetBytes(int index, const FLT& value) const { return SetBytes(index, value.GetData(), value.GetSize()); } template <> -Status CLKernel::SetBytes(int index, const FLT2& value) const { +absl::Status CLKernel::SetBytes(int index, const FLT2& value) const { return SetBytes(index, value.GetData(), value.GetSize()); } template <> -Status CLKernel::SetBytes(int index, const FLT4& value) const { +absl::Status CLKernel::SetBytes(int index, const FLT4& value) const { return SetBytes(index, value.GetData(), value.GetSize()); } template <> -Status CLKernel::SetBytesAuto(const FLT& value) { +absl::Status CLKernel::SetBytesAuto(const FLT& value) { return SetBytesAuto(value.GetData(), value.GetSize()); } template <> -Status CLKernel::SetBytesAuto(const FLT2& value) { +absl::Status CLKernel::SetBytesAuto(const FLT2& value) { return SetBytesAuto(value.GetData(), value.GetSize()); } template <> -Status CLKernel::SetBytesAuto(const FLT4& value) { +absl::Status CLKernel::SetBytesAuto(const FLT4& value) { return SetBytesAuto(value.GetData(), value.GetSize()); } diff --git a/tensorflow/lite/delegates/gpu/cl/cl_kernel.h b/tensorflow/lite/delegates/gpu/cl/cl_kernel.h index 3b63e43c967..b575684d2b4 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_kernel.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_kernel.h @@ -48,17 +48,17 @@ class CLKernel { cl_kernel kernel() const { return kernel_; } - Status CreateFromProgram(const CLProgram& program, - const std::string& function_name); + absl::Status CreateFromProgram(const CLProgram& program, + const std::string& function_name); - Status SetMemory(int index, cl_mem memory); - Status 
SetMemoryAuto(cl_mem memory); + absl::Status SetMemory(int index, cl_mem memory); + absl::Status SetMemoryAuto(cl_mem memory); template - Status SetBytes(int index, const T& value) const { + absl::Status SetBytes(int index, const T& value) const { return SetBytes(index, static_cast(&value), sizeof(T)); } template - Status SetBytesAuto(const T& value) { + absl::Status SetBytesAuto(const T& value) { return SetBytesAuto(static_cast(&value), sizeof(T)); } @@ -69,12 +69,12 @@ class CLKernel { // Do not use this function // workaround for Mali memory leak - Status ReInit() const; + absl::Status ReInit() const; private: void Release(); - Status SetBytes(int index, const void* ptr, int length) const; - Status SetBytesAuto(const void* ptr, int length); + absl::Status SetBytes(int index, const void* ptr, int length) const; + absl::Status SetBytesAuto(const void* ptr, int length); int private_memory_size_; int max_work_group_size_; @@ -87,22 +87,22 @@ class CLKernel { }; template <> -Status CLKernel::SetBytes(int index, const FLT& value) const; +absl::Status CLKernel::SetBytes(int index, const FLT& value) const; template <> -Status CLKernel::SetBytes(int index, const FLT2& value) const; +absl::Status CLKernel::SetBytes(int index, const FLT2& value) const; template <> -Status CLKernel::SetBytes(int index, const FLT4& value) const; +absl::Status CLKernel::SetBytes(int index, const FLT4& value) const; template <> -Status CLKernel::SetBytesAuto(const FLT& value); +absl::Status CLKernel::SetBytesAuto(const FLT& value); template <> -Status CLKernel::SetBytesAuto(const FLT2& value); +absl::Status CLKernel::SetBytesAuto(const FLT2& value); template <> -Status CLKernel::SetBytesAuto(const FLT4& value); +absl::Status CLKernel::SetBytesAuto(const FLT4& value); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/cl_program.cc b/tensorflow/lite/delegates/gpu/cl/cl_program.cc index 3592ad895ea..690bc598777 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_program.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_program.cc @@ -49,28 +49,29 @@ std::string GetProgramBuildInfo(cl_program program, cl_device_id id, return result; } -Status GetBinarySize(cl_program program, size_t* binary_size) { +absl::Status GetBinarySize(cl_program program, size_t* binary_size) { cl_int error_code = clGetProgramInfo(program, CL_PROGRAM_BINARY_SIZES, sizeof(size_t), binary_size, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat("Failed to get program binary size - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError( + absl::StrCat("Failed to get program binary size - ", + CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } -Status BuildProgram(cl_program program, const CLDevice& device, - const std::string& compiler_options) { +absl::Status BuildProgram(cl_program program, const CLDevice& device, + const std::string& compiler_options) { const int error_code = clBuildProgram( program, 0, nullptr, compiler_options.c_str(), nullptr, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat( + return absl::UnknownError(absl::StrCat( "Failed to build program executable - ", CLErrorCodeToString(error_code), GetProgramBuildInfo(program, device.id(), CL_PROGRAM_BUILD_LOG))); } - return OkStatus(); + return absl::OkStatus(); } std::string CompilerOptionToString(const CLDevice& device, @@ -133,7 +134,7 @@ void CLProgram::Release() { } } -Status CLProgram::GetBinary(std::vector* result) const { +absl::Status 
CLProgram::GetBinary(std::vector* result) const { size_t binary_size; RETURN_IF_ERROR(GetBinarySize(program_, &binary_size)); result->resize(result->size() + binary_size); @@ -141,35 +142,36 @@ Status CLProgram::GetBinary(std::vector* result) const { cl_int error_code = clGetProgramInfo(program_, CL_PROGRAM_BINARIES, binary_size, &binary_ptr, nullptr); if (error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat("Failed to get program binary - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError(absl::StrCat("Failed to get program binary - ", + CLErrorCodeToString(error_code))); } - return OkStatus(); + return absl::OkStatus(); } -Status CreateCLProgram(const std::string& code, - const std::string& compiler_options, - const CLContext& context, const CLDevice& device, - CLProgram* result) { +absl::Status CreateCLProgram(const std::string& code, + const std::string& compiler_options, + const CLContext& context, const CLDevice& device, + CLProgram* result) { int error_code; const char* source = code.c_str(); cl_program program = clCreateProgramWithSource(context.context(), 1, &source, nullptr, &error_code); if (!program || error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat("Failed to create compute program - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError( + absl::StrCat("Failed to create compute program - ", + CLErrorCodeToString(error_code))); } *result = CLProgram(program, device.id()); RETURN_IF_ERROR(BuildProgram(program, device, compiler_options)); - return OkStatus(); + return absl::OkStatus(); } -Status CreateCLProgramFromBinary(const CLContext& context, - const CLDevice& device, - absl::Span binary, - CLProgram* result) { +absl::Status CreateCLProgramFromBinary(const CLContext& context, + const CLDevice& device, + absl::Span binary, + CLProgram* result) { cl_int binary_status; cl_int error_code; cl_device_id devices_list[] = {device.id()}; @@ -179,13 +181,13 @@ Status CreateCLProgramFromBinary(const CLContext& context, context.context(), 1, devices_list, &binary_size, &binary_pointer, &binary_status, &error_code); if (binary_status != CL_SUCCESS) { - return UnknownError(absl::StrCat( + return absl::UnknownError(absl::StrCat( "Something wrong with binary after clCreateProgramWithBinary - ", binary_status)); } if (error_code != CL_SUCCESS) { - return UnknownError(absl::StrCat("Failed to create program - ", - CLErrorCodeToString(error_code))); + return absl::UnknownError(absl::StrCat("Failed to create program - ", + CLErrorCodeToString(error_code))); } *result = CLProgram(program, device.id()); return BuildProgram(program, device, ""); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_program.h b/tensorflow/lite/delegates/gpu/cl/cl_program.h index b6deb3beb95..fb2a7edb9c1 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_program.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_program.h @@ -68,7 +68,7 @@ class CLProgram { // was created using clCreateProgramWithBinary. 
cl_device_id GetDeviceId() const { return device_id_; } - Status GetBinary(std::vector* result) const; + absl::Status GetBinary(std::vector* result) const; private: void Release(); @@ -79,15 +79,15 @@ class CLProgram { cl_device_id device_id_ = nullptr; }; -Status CreateCLProgram(const std::string& code, - const std::string& compiler_options, - const CLContext& context, const CLDevice& device, - CLProgram* result); +absl::Status CreateCLProgram(const std::string& code, + const std::string& compiler_options, + const CLContext& context, const CLDevice& device, + CLProgram* result); -Status CreateCLProgramFromBinary(const CLContext& context, - const CLDevice& device, - absl::Span binary, - CLProgram* result); +absl::Status CreateCLProgramFromBinary(const CLContext& context, + const CLDevice& device, + absl::Span binary, + CLProgram* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/egl_sync.cc b/tensorflow/lite/delegates/gpu/cl/egl_sync.cc index 8493fbb049f..ddc373bce31 100644 --- a/tensorflow/lite/delegates/gpu/cl/egl_sync.cc +++ b/tensorflow/lite/delegates/gpu/cl/egl_sync.cc @@ -21,15 +21,15 @@ namespace tflite { namespace gpu { namespace cl { -Status EglSync::NewFence(EGLDisplay display, EglSync* sync) { +absl::Status EglSync::NewFence(EGLDisplay display, EglSync* sync) { EGLSyncKHR egl_sync; RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglCreateSyncKHR, &egl_sync, display, EGL_SYNC_FENCE_KHR, nullptr)); if (egl_sync == EGL_NO_SYNC_KHR) { - return InternalError("Returned empty KHR EGL sync"); + return absl::InternalError("Returned empty KHR EGL sync"); } *sync = EglSync(display, egl_sync); - return OkStatus(); + return absl::OkStatus(); } EglSync& EglSync::operator=(EglSync&& sync) { @@ -48,22 +48,23 @@ void EglSync::Invalidate() { } } -Status EglSync::ServerWait() { +absl::Status EglSync::ServerWait() { EGLint result; RETURN_IF_ERROR( TFLITE_GPU_CALL_EGL(eglWaitSyncKHR, &result, display_, sync_, 0)); - return result == EGL_TRUE ? OkStatus() : InternalError("eglWaitSync failed"); + return result == EGL_TRUE ? absl::OkStatus() + : absl::InternalError("eglWaitSync failed"); } -Status EglSync::ClientWait() { +absl::Status EglSync::ClientWait() { EGLint result; // TODO(akulik): make it active wait for better performance RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglClientWaitSyncKHR, &result, display_, sync_, EGL_SYNC_FLUSH_COMMANDS_BIT_KHR, EGL_FOREVER_KHR)); return result == EGL_CONDITION_SATISFIED_KHR - ? OkStatus() - : InternalError("eglClientWaitSync failed"); + ? absl::OkStatus() + : absl::InternalError("eglClientWaitSync failed"); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/egl_sync.h b/tensorflow/lite/delegates/gpu/cl/egl_sync.h index 27a551c5d59..d0943a797ee 100644 --- a/tensorflow/lite/delegates/gpu/cl/egl_sync.h +++ b/tensorflow/lite/delegates/gpu/cl/egl_sync.h @@ -32,7 +32,7 @@ class EglSync { // flushed. // // Depends on EGL_KHR_fence_sync extension. - static Status NewFence(EGLDisplay display, EglSync* sync); + static absl::Status NewFence(EGLDisplay display, EglSync* sync); // Creates invalid object. EglSync() : EglSync(EGL_NO_DISPLAY, EGL_NO_SYNC_KHR) {} @@ -50,10 +50,10 @@ class EglSync { // Causes GPU to block and wait until this sync has been signaled. // This call does not block and returns immediately. - Status ServerWait(); + absl::Status ServerWait(); // Causes CPU to block and wait until this sync has been signaled. 
- Status ClientWait(); + absl::Status ClientWait(); // Returns the EGLDisplay on which this instance was created. EGLDisplay display() const { return display_; } diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc index ca13e19f73f..01d034fb1f7 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.cc +++ b/tensorflow/lite/delegates/gpu/cl/environment.cc @@ -26,6 +26,7 @@ namespace tflite { namespace gpu { namespace cl { namespace { + std::string GetKernelOneLayerTextureArray() { return R"( @@ -43,12 +44,12 @@ __kernel void main_function(__write_only image2d_array_t dst) { // texture, we will get zeroes instead of actual values. // The same kernel will work, if we use texture array with more than one layer. // With help of this code we can detect this bug. -Status CheckKernelSupportOfOneLayerTextureArray(Environment* env, - bool* result) { +absl::Status CheckKernelSupportOfOneLayerTextureArray(Environment* env, + bool* result) { // No bug on Adreno 6xx if (env->device().GetInfo().adreno_info.gpu_version >= 600) { *result = true; - return OkStatus(); + return absl::OkStatus(); } CLKernel kernel; RETURN_IF_ERROR(env->program_cache()->GetOrCreateCLKernel( @@ -75,12 +76,12 @@ Status CheckKernelSupportOfOneLayerTextureArray(Environment* env, break; } } - return OkStatus(); + return absl::OkStatus(); } -Status CreateEnvironment(Environment* result, bool shared, - cl_context_properties egl_context, - cl_context_properties egl_display) { +absl::Status CreateEnvironment(Environment* result, bool shared, + cl_context_properties egl_context, + cl_context_properties egl_display) { CLDevice gpu; RETURN_IF_ERROR(CreateDefaultGPUDevice(&gpu)); @@ -107,8 +108,9 @@ Status CreateEnvironment(Environment* result, bool shared, } } - return OkStatus(); + return absl::OkStatus(); } + } // namespace Environment::Environment(CLDevice&& device, CLContext&& context, @@ -137,7 +139,7 @@ Environment& Environment::operator=(Environment&& environment) { return *this; } -Status Environment::Init() { +absl::Status Environment::Init() { if (device().IsAdreno() && device().SupportsTextureArray()) { bool supports_one_layer; RETURN_IF_ERROR( @@ -146,7 +148,7 @@ Status Environment::Init() { GetDevicePtr()->DisableOneLayerTextureArray(); } } - return OkStatus(); + return absl::OkStatus(); } void Environment::SetHighPerformance() const { @@ -266,7 +268,7 @@ TensorStorageType GetStorageTypeWithMinimalMemoryConsumption( return TensorStorageType::BUFFER; } -Status CreateEnvironment(Environment* result) { +absl::Status CreateEnvironment(Environment* result) { CLDevice gpu; RETURN_IF_ERROR(CreateDefaultGPUDevice(&gpu)); diff --git a/tensorflow/lite/delegates/gpu/cl/environment.h b/tensorflow/lite/delegates/gpu/cl/environment.h index 496d6957623..b40d22d3dd6 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.h +++ b/tensorflow/lite/delegates/gpu/cl/environment.h @@ -57,7 +57,7 @@ class Environment { std::vector GetSupportedStorages() const; bool IsSupported(TensorStorageType storage_type) const; - Status Init(); + absl::Status Init(); void SetHighPerformance() const; void SetDefaultPerformance() const; @@ -75,7 +75,7 @@ TensorStorageType GetFastestStorageType(const CLDevice& gpu); TensorStorageType GetStorageTypeWithMinimalMemoryConsumption( const CLDevice& gpu); -Status CreateEnvironment(Environment* result); +absl::Status CreateEnvironment(Environment* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc 
b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc index f4db12bf133..648b772d827 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc @@ -41,10 +41,11 @@ PFNEGLCREATESYNCPROC g_eglCreateSync = nullptr; } // namespace -Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, - EglSync* sync) { +absl::Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, + EglSync* sync) { if (!IsEglSyncFromClEventSupported()) { - return UnimplementedError("CreateEglSyncFromClEvent is not supported"); + return absl::UnimplementedError( + "CreateEglSyncFromClEvent is not supported"); } EGLSync egl_sync; const EGLAttrib attributes[] = {EGL_CL_EVENT_HANDLE, @@ -52,10 +53,10 @@ Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(g_eglCreateSync, &egl_sync, display, EGL_SYNC_CL_EVENT, attributes)); if (egl_sync == EGL_NO_SYNC) { - return InternalError("Returned empty EGL sync"); + return absl::InternalError("Returned empty EGL sync"); } *sync = EglSync(display, egl_sync); - return OkStatus(); + return absl::OkStatus(); } bool IsEglSyncFromClEventSupported() { @@ -73,52 +74,54 @@ bool IsEglSyncFromClEventSupported() { return supported; } -Status CreateClEventFromEglSync(cl_context context, const EglSync& egl_sync, - CLEvent* event) { +absl::Status CreateClEventFromEglSync(cl_context context, + const EglSync& egl_sync, CLEvent* event) { cl_int error_code; cl_event new_event = clCreateEventFromEGLSyncKHR( context, egl_sync.sync(), egl_sync.display(), &error_code); if (error_code != CL_SUCCESS) { - return InternalError( + return absl::InternalError( absl::StrCat("Unable to create CL sync from EGL sync. ", CLErrorCodeToString(error_code))); } *event = CLEvent(new_event); - return OkStatus(); + return absl::OkStatus(); } bool IsClEventFromEglSyncSupported(const CLDevice& device) { return device.SupportsExtension("cl_khr_egl_event"); } -Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, AccessType access_type, - CLContext* context, CLMemory* memory) { +absl::Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, + AccessType access_type, + CLContext* context, CLMemory* memory) { cl_int error_code; auto mem = clCreateFromGLBuffer(context->context(), ToClMemFlags(access_type), gl_ssbo_id, &error_code); if (error_code != CL_SUCCESS) { - return InternalError( + return absl::InternalError( absl::StrCat("Unable to acquire CL buffer from GL buffer. ", CLErrorCodeToString(error_code))); } *memory = CLMemory(mem, true); - return OkStatus(); + return absl::OkStatus(); } -Status CreateClMemoryFromGlTexture(GLenum texture_target, GLuint texture_id, - AccessType access_type, CLContext* context, - CLMemory* memory) { +absl::Status CreateClMemoryFromGlTexture(GLenum texture_target, + GLuint texture_id, + AccessType access_type, + CLContext* context, CLMemory* memory) { cl_int error_code; auto mem = clCreateFromGLTexture(context->context(), ToClMemFlags(access_type), texture_target, 0, texture_id, &error_code); if (error_code != CL_SUCCESS) { - return InternalError( + return absl::InternalError( absl::StrCat("Unable to create CL buffer from GL texture. 
", CLErrorCodeToString(error_code))); } *memory = CLMemory(mem, true); - return OkStatus(); + return absl::OkStatus(); } bool IsGlSharingSupported(const CLDevice& device) { @@ -128,19 +131,18 @@ bool IsGlSharingSupported(const CLDevice& device) { AcquiredGlObjects::~AcquiredGlObjects() { Release({}, nullptr).IgnoreError(); } -Status AcquiredGlObjects::Acquire(const std::vector& memory, - cl_command_queue queue, - const std::vector& wait_events, - CLEvent* acquire_event, - AcquiredGlObjects* objects) { +absl::Status AcquiredGlObjects::Acquire( + const std::vector& memory, cl_command_queue queue, + const std::vector& wait_events, CLEvent* acquire_event, + AcquiredGlObjects* objects) { if (!memory.empty()) { cl_event new_event; cl_int error_code = clEnqueueAcquireGLObjects( queue, memory.size(), memory.data(), wait_events.size(), wait_events.data(), acquire_event ? &new_event : nullptr); if (error_code != CL_SUCCESS) { - return InternalError(absl::StrCat("Unable to acquire GL object. ", - CLErrorCodeToString(error_code))); + return absl::InternalError(absl::StrCat("Unable to acquire GL object. ", + CLErrorCodeToString(error_code))); } if (acquire_event) { *acquire_event = CLEvent(new_event); @@ -148,19 +150,19 @@ Status AcquiredGlObjects::Acquire(const std::vector& memory, clFlush(queue); } *objects = AcquiredGlObjects(memory, queue); - return OkStatus(); + return absl::OkStatus(); } -Status AcquiredGlObjects::Release(const std::vector& wait_events, - CLEvent* release_event) { +absl::Status AcquiredGlObjects::Release( + const std::vector& wait_events, CLEvent* release_event) { if (queue_ && !memory_.empty()) { cl_event new_event; cl_int error_code = clEnqueueReleaseGLObjects( queue_, memory_.size(), memory_.data(), wait_events.size(), wait_events.data(), release_event ? &new_event : nullptr); if (error_code != CL_SUCCESS) { - return InternalError(absl::StrCat("Unable to release GL object. ", - CLErrorCodeToString(error_code))); + return absl::InternalError(absl::StrCat("Unable to release GL object. ", + CLErrorCodeToString(error_code))); } if (release_event) { *release_event = CLEvent(new_event); @@ -168,7 +170,7 @@ Status AcquiredGlObjects::Release(const std::vector& wait_events, clFlush(queue_); queue_ = nullptr; } - return OkStatus(); + return absl::OkStatus(); } GlInteropFabric::GlInteropFabric(EGLDisplay egl_display, @@ -192,9 +194,9 @@ void GlInteropFabric::UnregisterMemory(cl_mem memory) { } } -Status GlInteropFabric::Start() { +absl::Status GlInteropFabric::Start() { if (!is_enabled()) { - return OkStatus(); + return absl::OkStatus(); } // In GL-CL interoperability, we need to make sure GL finished processing of @@ -235,9 +237,9 @@ Status GlInteropFabric::Start() { nullptr, &gl_objects_); } -Status GlInteropFabric::Finish() { +absl::Status GlInteropFabric::Finish() { if (!is_enabled()) { - return OkStatus(); + return absl::OkStatus(); } RETURN_IF_ERROR(gl_objects_.Release({}, &outbound_event_)); @@ -258,7 +260,7 @@ Status GlInteropFabric::Finish() { // This slow sync is the only working solution right now. We have to debug why // above version is not working fast and reliable. 
outbound_event_.Wait(); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.h b/tensorflow/lite/delegates/gpu/cl/gl_interop.h index 597bee857c6..7ebc3e4bf4f 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.h +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.h @@ -39,8 +39,8 @@ namespace cl { // returned sync and could be safely destroyed. // // Depends on EGL 1.5. -Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, - EglSync* sync); +absl::Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, + EglSync* sync); // Returns true if 'CreateEglSyncFromClEvent' is supported. bool IsEglSyncFromClEventSupported(); @@ -48,20 +48,22 @@ bool IsEglSyncFromClEventSupported(); // Creates CL event from EGL sync. // Created event could only be consumed by AcquiredGlObject::Acquire call as // a 'wait_event'. -Status CreateClEventFromEglSync(cl_context context, const EglSync& egl_sync, - CLEvent* event); +absl::Status CreateClEventFromEglSync(cl_context context, + const EglSync& egl_sync, CLEvent* event); // Returns true if 'CreateClEventFromEglSync' is supported. bool IsClEventFromEglSyncSupported(const CLDevice& device); // Creates new CL memory object from OpenGL buffer. -Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, AccessType access_type, - CLContext* context, CLMemory* memory); +absl::Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, + AccessType access_type, + CLContext* context, CLMemory* memory); // Creates new CL memory object from OpenGL texture. -Status CreateClMemoryFromGlTexture(GLenum texture_target, GLuint texture_id, - AccessType access_type, CLContext* context, - CLMemory* memory); +absl::Status CreateClMemoryFromGlTexture(GLenum texture_target, + GLuint texture_id, + AccessType access_type, + CLContext* context, CLMemory* memory); // Returns true if GL objects could be shared with OpenCL context. bool IsGlSharingSupported(const CLDevice& device); @@ -81,16 +83,16 @@ class AcquiredGlObjects { // CreateClMemoryFromGlBuffer or CreateClMemoryFromGlTexture calls. // If 'acquire_event' is not nullptr, it will be signared once acquisition is // complete. - static Status Acquire(const std::vector& memory, - cl_command_queue queue, - const std::vector& wait_events, - CLEvent* acquire_event /* optional */, - AcquiredGlObjects* objects); + static absl::Status Acquire(const std::vector& memory, + cl_command_queue queue, + const std::vector& wait_events, + CLEvent* acquire_event /* optional */, + AcquiredGlObjects* objects); // Releases OpenCL memory back to OpenGL context. If 'release_event' is not // nullptr, it will be signalled once release is complete. - Status Release(const std::vector& wait_events, - CLEvent* release_event /* optional */); + absl::Status Release(const std::vector& wait_events, + CLEvent* release_event /* optional */); private: AcquiredGlObjects(const std::vector& memory, cl_command_queue queue) @@ -108,10 +110,10 @@ class GlInteropFabric { // Ensures proper GL->CL synchronization is in place before // GL objects that are mapped to CL objects are used. - Status Start(); + absl::Status Start(); // Puts appropriate CL->GL synchronization after all work is complete. - Status Finish(); + absl::Status Finish(); // Registers memory to be used from GL context. 
Such CL memory object must // be created with CreateClMemoryFromGlBuffer or CreateClMemoryFromGlTexture diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc index 8e2c3308a47..0e2d046eba2 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc +++ b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc @@ -87,8 +87,8 @@ class Delegate { } } - Status Prepare(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params) { + absl::Status Prepare(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { // Extract TFLite delegate execution plan from the context and convert it // into FlowGraph32. GraphFloat32 graph; @@ -98,7 +98,7 @@ class Delegate { NullTransformationReporter reporter; ModelTransformer transformer(&graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return InternalError("Graph general transformations failed"); + return absl::InternalError("Graph general transformations failed"); } InferenceEnvironmentOptions env_options; @@ -108,7 +108,7 @@ class Delegate { options_.serialized_binary_cache_data, options_.serialized_binary_cache_size}; InferenceEnvironmentProperties properties; - Status status = + absl::Status status = NewInferenceEnvironment(env_options, &environment_, &properties); if (!properties.is_opencl_available) { context->ReportError(context, @@ -200,7 +200,7 @@ class Delegate { return builder->Build(&runner_); } - Status SetInputsAndOutputs(TfLiteContext* context) { + absl::Status SetInputsAndOutputs(TfLiteContext* context) { int i = 0; for (auto index : input_indices_) { RETURN_IF_ERROR( @@ -211,10 +211,10 @@ class Delegate { RETURN_IF_ERROR( runner_->SetOutputObject(i++, GetTensorObject(index, context))); } - return OkStatus(); + return absl::OkStatus(); } - Status Invoke(TfLiteContext* context) { + absl::Status Invoke(TfLiteContext* context) { RETURN_IF_ERROR(SetInputsAndOutputs(context)); return runner_->Run(); } @@ -310,7 +310,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = gpu_delegate->Prepare(context, params); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Init: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return nullptr; } return gpu_delegate; @@ -335,7 +335,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = GetDelegate(node)->Invoke(context); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Invoke: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index 47998bf8c99..2ec911813e6 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -169,9 +169,9 @@ CLNode& CLNode::operator=(CLNode&& node) { return *this; } -Status InferenceContext::InitFromGraph(const CreateInferenceInfo& create_info, - const GraphFloat32& graph, - Environment* env) { +absl::Status InferenceContext::InitFromGraph( + const CreateInferenceInfo& create_info, const GraphFloat32& graph, + Environment* env) { CreationContext creation_context; creation_context.device = env->GetDevicePtr(); creation_context.context = &env->context(); @@ -206,15 +206,15 @@ Status InferenceContext::InitFromGraph(const CreateInferenceInfo& 
create_info, tuning_parameters.tuning_type = TuningType::FAST; } RETURN_IF_ERROR(Tune(tuning_parameters)); - return OkStatus(); + return absl::OkStatus(); } -Status InferenceContext::InitFromGraphWithTransforms( +absl::Status InferenceContext::InitFromGraphWithTransforms( const CreateInferenceInfo& create_info, GraphFloat32* graph, Environment* env) { RETURN_IF_ERROR(RunGraphTransforms(graph)); RETURN_IF_ERROR(InitFromGraph(create_info, *graph, env)); - return OkStatus(); + return absl::OkStatus(); } void InferenceContext::CopyInAndOutIds(const GraphFloat32& graph) { @@ -258,7 +258,7 @@ void InferenceContext::ReserveGraphTensors( tensor_reserver_.SetNext(max_id + 1); } -Status InferenceContext::ConvertOperations( +absl::Status InferenceContext::ConvertOperations( const CreationContext& creation_context, const GraphFloat32& graph, ModelHints hints) { std::vector graph_nodes = graph.nodes(); @@ -343,7 +343,7 @@ Status InferenceContext::ConvertOperations( } } - return OkStatus(); + return absl::OkStatus(); } void InferenceContext::Merge() { @@ -424,15 +424,15 @@ void InferenceContext::GetUsages( } } -Status InferenceContext::AllocateMemory(const CLDevice& device, - CLContext* context) { +absl::Status InferenceContext::AllocateMemory(const CLDevice& device, + CLContext* context) { RETURN_IF_ERROR(AllocateMemoryForBuffers(device, context)); RETURN_IF_ERROR(AllocateMemoryForStrongShapes(device, context)); - return OkStatus(); + return absl::OkStatus(); } -Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device, - CLContext* context) { +absl::Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device, + CLContext* context) { std::map buffer_usages; GetUsages( [](const TensorDescriptor& t) { return IsBufferBased(t.storage_type); }, @@ -480,11 +480,11 @@ Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device, created_tensors[tensor_index] = true; } } - return OkStatus(); + return absl::OkStatus(); } -Status InferenceContext::AllocateMemoryForStrongShapes(const CLDevice& device, - CLContext* context) { +absl::Status InferenceContext::AllocateMemoryForStrongShapes( + const CLDevice& device, CLContext* context) { std::map usages; GetUsages( [](const TensorDescriptor& t) { return !IsBufferBased(t.storage_type); }, @@ -517,7 +517,7 @@ Status InferenceContext::AllocateMemoryForStrongShapes(const CLDevice& device, } } } - return OkStatus(); + return absl::OkStatus(); } void InferenceContext::BindMemoryToOperations() { @@ -539,21 +539,22 @@ void InferenceContext::BindMemoryToOperations() { } } -Status InferenceContext::Compile(const CreationContext& creation_context) { +absl::Status InferenceContext::Compile( + const CreationContext& creation_context) { for (auto& node : nodes_) { RETURN_IF_ERROR(node.operations[0]->Compile(creation_context)); } - return OkStatus(); + return absl::OkStatus(); } -Status InferenceContext::Tune(const TuningParameters& tuning_parameters) { +absl::Status InferenceContext::Tune(const TuningParameters& tuning_parameters) { for (auto& node : nodes_) { RETURN_IF_ERROR(node.operations[0]->Tune(tuning_parameters)); } - return OkStatus(); + return absl::OkStatus(); } -Status InferenceContext::AddToQueue(CLCommandQueue* queue) { +absl::Status InferenceContext::AddToQueue(CLCommandQueue* queue) { if (need_manual_release_) { if (prev_enqueue_start_point_.is_valid()) { prev_enqueue_start_point_.Wait(); @@ -571,11 +572,11 @@ Status InferenceContext::AddToQueue(CLCommandQueue* queue) { if (need_flush_) { clFlush(queue->queue()); } - return 
OkStatus(); + return absl::OkStatus(); } -Status InferenceContext::Profile(ProfilingCommandQueue* queue, - ProfilingInfo* result) { +absl::Status InferenceContext::Profile(ProfilingCommandQueue* queue, + ProfilingInfo* result) { queue->ResetMeasurements(); for (auto& node : nodes_) { queue->SetEventsLabel(node.name); @@ -583,7 +584,7 @@ Status InferenceContext::Profile(ProfilingCommandQueue* queue, } RETURN_IF_ERROR(queue->WaitForCompletion()); *result = queue->GetProfilingInfo(); - return OkStatus(); + return absl::OkStatus(); } uint64_t InferenceContext::GetSizeOfMemoryAllocatedForIntermediateTensors() @@ -608,13 +609,15 @@ Tensor* InferenceContext::GetTensor(ValueId id) { } } -Status InferenceContext::SetInputTensor(ValueId id, const TensorFloat32& tensor, - CLCommandQueue* queue) { +absl::Status InferenceContext::SetInputTensor(ValueId id, + const TensorFloat32& tensor, + CLCommandQueue* queue) { return GetTensor(id)->WriteData(queue, tensor); } -Status InferenceContext::GetOutputTensor(ValueId id, CLCommandQueue* queue, - TensorFloat32* result) { +absl::Status InferenceContext::GetOutputTensor(ValueId id, + CLCommandQueue* queue, + TensorFloat32* result) { const auto& gpu_tensor = *GetTensor(id); const auto dst_shape = BHWC(gpu_tensor.Batch(), gpu_tensor.Height(), gpu_tensor.Width(), gpu_tensor.Channels()); @@ -624,17 +627,17 @@ Status InferenceContext::GetOutputTensor(ValueId id, CLCommandQueue* queue, return gpu_tensor.ReadData(queue, result); } -Status RunGraphTransforms(GraphFloat32* graph) { +absl::Status RunGraphTransforms(GraphFloat32* graph) { auto merge_padding_transform = NewMergePaddingWithAdd(); auto add_bias_transform = NewAddBias(); ModelTransformer transformer(graph, /*reporter=*/nullptr); if (!transformer.Apply("add_bias", add_bias_transform.get())) { - return InternalError("Invalid add_bias transform"); + return absl::InternalError("Invalid add_bias transform"); } if (!transformer.Apply("merge_padding", merge_padding_transform.get())) { - return InternalError("Invalid merge_padding transform"); + return absl::InternalError("Invalid merge_padding transform"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.h b/tensorflow/lite/delegates/gpu/cl/inference_context.h index 40b20e8806a..75365258e41 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.h +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.h @@ -65,53 +65,55 @@ class InferenceContext { TensorStorageType storage_type; ModelHints hints; }; - Status InitFromGraph(const CreateInferenceInfo& create_info, - const GraphFloat32& graph, Environment* env); + absl::Status InitFromGraph(const CreateInferenceInfo& create_info, + const GraphFloat32& graph, Environment* env); // Applies OpenCL-specific transformations to the graph before the // initialization. These transformations are either impossible or useless in // other backends. 
- Status InitFromGraphWithTransforms(const CreateInferenceInfo& create_info, - GraphFloat32* graph, Environment* env); + absl::Status InitFromGraphWithTransforms( + const CreateInferenceInfo& create_info, GraphFloat32* graph, + Environment* env); - Status AddToQueue(CLCommandQueue* queue); - Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result); + absl::Status AddToQueue(CLCommandQueue* queue); + absl::Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result); // for profiling and memory statistics uint64_t GetSizeOfMemoryAllocatedForIntermediateTensors() const; - Status SetInputTensor(ValueId id, const TensorFloat32& tensor, - CLCommandQueue* queue); + absl::Status SetInputTensor(ValueId id, const TensorFloat32& tensor, + CLCommandQueue* queue); // It will work only with input/output tensor ids. For all other ids we don't // have any guarantees. Tensor* GetTensor(ValueId id); - Status GetOutputTensor(ValueId id, CLCommandQueue* queue, - TensorFloat32* result); + absl::Status GetOutputTensor(ValueId id, CLCommandQueue* queue, + TensorFloat32* result); private: void CopyInAndOutIds(const GraphFloat32& graph); - Status ConvertOperations(const CreationContext& creation_context, - const GraphFloat32& graph, ModelHints hints); + absl::Status ConvertOperations(const CreationContext& creation_context, + const GraphFloat32& graph, ModelHints hints); void CreateLinks(); void ReserveGraphTensors(const CreateInferenceInfo& create_info, const CreationContext& creation_context, const GraphFloat32& graph); void Merge(); - Status AllocateMemory(const CLDevice& device, CLContext* context); + absl::Status AllocateMemory(const CLDevice& device, CLContext* context); - Status AllocateMemoryForBuffers(const CLDevice& device, CLContext* context); + absl::Status AllocateMemoryForBuffers(const CLDevice& device, + CLContext* context); - Status AllocateMemoryForStrongShapes(const CLDevice& device, - CLContext* context); + absl::Status AllocateMemoryForStrongShapes(const CLDevice& device, + CLContext* context); // utility function void GetUsages(const std::function& functor, std::map* usages); void BindMemoryToOperations(); - Status Compile(const CreationContext& creation_context); - Status Tune(const TuningParameters& tuning_parameters); + absl::Status Compile(const CreationContext& creation_context); + absl::Status Tune(const TuningParameters& tuning_parameters); // performance hacks bool need_flush_ = false; @@ -175,7 +177,7 @@ class InferenceContext { }; // Runs OpenCL specific transforms for the graph. 
-Status RunGraphTransforms(GraphFloat32* graph); +absl::Status RunGraphTransforms(GraphFloat32* graph); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc index b5c37c5987f..0c96f4316ec 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc @@ -143,17 +143,17 @@ std::string Add::GetArgsDeclaration() const { return args; } -Status Add::BindArguments(CLKernel* kernel) { +absl::Status Add::BindArguments(CLKernel* kernel) { for (int i = 1; i < src_depthes_.size(); ++i) { RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[i]->GetMemoryPtr())); } for (int i = 1; i < src_depthes_.size(); ++i) { RETURN_IF_ERROR(kernel->SetBytesAuto(src_[i]->GetWBatchedHSB())); } - return OkStatus(); + return absl::OkStatus(); } -Status Add::Compile(const CreationContext& creation_context) { +absl::Status Add::Compile(const CreationContext& creation_context) { const auto code = GetElementWiseCode(definition_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.h b/tensorflow/lite/delegates/gpu/cl/kernels/add.h index ac6243cc5e4..d47954748c7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/add.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.h @@ -36,7 +36,7 @@ class Add : public ElementwiseOperation { Add(const OperationDef& definition, const std::vector& channels, int dst_channels); - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Add(Add&& operation); @@ -47,7 +47,7 @@ class Add : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - Status BindArguments(CLKernel* kernel) override; + absl::Status BindArguments(CLKernel* kernel) override; private: std::string GetElementWiseCode( diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc index ad4b54853e1..deb0ebf67c4 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc @@ -21,17 +21,17 @@ namespace tflite { namespace gpu { namespace cl { -Status ExecuteGPUOperation(const std::vector& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, - const std::vector& dst_sizes, - const std::vector& dst_cpu) { +absl::Status ExecuteGPUOperation(const std::vector& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, + const std::vector& dst_sizes, + const std::vector& dst_cpu) { const OperationDef& op_def = operation->GetDefinition(); std::vector src(src_cpu.size()); for (int i = 0; i < src_cpu.size(); ++i) { auto src_shape = src_cpu[i].shape; if (src_shape.b != 1 && !op_def.IsBatchSupported()) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Layout doesn't have Batch dimension, but shape.b != 1"); } RETURN_IF_ERROR(CreateTensor(*creation_context.context, @@ -45,7 +45,7 @@ Status ExecuteGPUOperation(const std::vector& src_cpu, for (int i = 0; i < dst_cpu.size(); ++i) { auto dst_shape = dst_sizes[i]; if (dst_shape.b != 1 && !op_def.IsBatchSupported()) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Layout doesn't have Batch dimension, 
but shape.b != 1"); } RETURN_IF_ERROR(CreateTensor(*creation_context.context, @@ -64,22 +64,22 @@ Status ExecuteGPUOperation(const std::vector& src_cpu, dst_cpu[i]->data = std::vector(dst_sizes[i].DimensionsProduct(), 0); RETURN_IF_ERROR(dst[i].ReadData(creation_context.queue, dst_cpu[i])); } - return OkStatus(); + return absl::OkStatus(); } -Status ExecuteGPUOperation(const std::vector& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, const BHWC& dst_size, - TensorFloat32* result) { +absl::Status ExecuteGPUOperation(const std::vector& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, const BHWC& dst_size, + TensorFloat32* result) { return ExecuteGPUOperation( std::vector{src_cpu}, creation_context, operation, std::vector{dst_size}, std::vector{result}); } -Status ExecuteGPUOperation(const TensorFloat32& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, const BHWC& dst_size, - TensorFloat32* result) { +absl::Status ExecuteGPUOperation(const TensorFloat32& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, const BHWC& dst_size, + TensorFloat32* result) { return ExecuteGPUOperation(std::vector{src_cpu}, creation_context, operation, dst_size, result); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h index c127d1bacd3..4d3636d0384 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h @@ -51,21 +51,21 @@ class OpenCLOperationTest : public ::testing::Test { CreationContext creation_context_; }; -Status ExecuteGPUOperation(const TensorFloat32& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, const BHWC& dst_size, - TensorFloat32* result); +absl::Status ExecuteGPUOperation(const TensorFloat32& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, const BHWC& dst_size, + TensorFloat32* result); -Status ExecuteGPUOperation(const std::vector& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, const BHWC& dst_size, - TensorFloat32* result); +absl::Status ExecuteGPUOperation(const std::vector& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, const BHWC& dst_size, + TensorFloat32* result); -Status ExecuteGPUOperation(const std::vector& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, - const std::vector& dst_sizes, - const std::vector& dst_cpu); +absl::Status ExecuteGPUOperation(const std::vector& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, + const std::vector& dst_sizes, + const std::vector& dst_cpu); } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc index 141a19de6e1..ef7915afba5 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc @@ -96,7 +96,7 @@ ConcatXY& ConcatXY::operator=(ConcatXY&& operation) { return *this; } -Status ConcatXY::Compile(const CreationContext& creation_context) { +absl::Status ConcatXY::Compile(const CreationContext& creation_context) { const auto code = GetConcatKernelCode(definition_, tensors_count_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -104,7 +104,7 @@ Status ConcatXY::Compile(const CreationContext& 
creation_context) { *creation_context.device, &kernel_); } -Status ConcatXY::BindArguments() { +absl::Status ConcatXY::BindArguments() { kernel_.ResetBindingCounter(); for (int i = 0; i < tensors_count_; ++i) { RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr())); @@ -122,7 +122,7 @@ Status ConcatXY::BindArguments() { y_offset += attr_.axis == Axis::HEIGHT ? height : 0; } RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 ConcatXY::GetGridSize() const { @@ -140,12 +140,12 @@ int3 ConcatXY::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConcatXY::Tune(const TuningParameters& params) { +absl::Status ConcatXY::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status ConcatXY::AddToQueue(CLCommandQueue* queue) { +absl::Status ConcatXY::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h index 6bc0c87a51f..a170b593cf0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h @@ -31,10 +31,10 @@ class ConcatXY : public GPUOperation { ConcatXY(const OperationDef& definition, const ConcatAttributes& attr, int tensors_count) : GPUOperation(definition), attr_(attr), tensors_count_(tensors_count) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConcatXY(ConcatXY&& operation); @@ -43,7 +43,7 @@ class ConcatXY : public GPUOperation { ConcatXY& operator=(const ConcatXY&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; ConcatAttributes attr_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc index 039fac0d0e3..3a7ec1c0cb7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc @@ -25,8 +25,8 @@ limitations under the License. 
namespace tflite { namespace gpu { namespace cl { - namespace { + bool IsAllChannelsX4(const std::vector& channels) { for (int channel : channels) { if (channel % 4 != 0) { @@ -146,6 +146,7 @@ std::string GetConcatKernelCode( c += "}\n"; return c; } + } // namespace ConcatZ::ConcatZ(ConcatZ&& kernel) @@ -164,7 +165,7 @@ ConcatZ& ConcatZ::operator=(ConcatZ&& kernel) { return *this; } -Status ConcatZ::Compile(const CreationContext& creation_context) { +absl::Status ConcatZ::Compile(const CreationContext& creation_context) { const auto code = GetConcatKernelCode(definition_, channels_, linked_operations_); std::vector options; @@ -186,7 +187,7 @@ Status ConcatZ::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status ConcatZ::BindArguments() { +absl::Status ConcatZ::BindArguments() { kernel_.ResetBindingCounter(); for (int i = 0; i < channels_.size(); ++i) { RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr())); @@ -197,7 +198,7 @@ Status ConcatZ::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[i]->Slices())); } RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 ConcatZ::GetGridSize() const { @@ -207,12 +208,12 @@ int3 ConcatZ::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConcatZ::Tune(const TuningParameters& params) { +absl::Status ConcatZ::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status ConcatZ::AddToQueue(CLCommandQueue* queue) { +absl::Status ConcatZ::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h index 9fc0fcc1fdb..ec25f6e4ed9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h @@ -32,10 +32,10 @@ class ConcatZ : public GPUOperation { public: ConcatZ(const OperationDef& definition, const std::vector& channels) : GPUOperation(definition), channels_(channels) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConcatZ(ConcatZ&& kernel); @@ -44,7 +44,7 @@ class ConcatZ : public GPUOperation { ConcatZ& operator=(const ConcatZ&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; std::vector channels_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc index e6015357bfc..b79599d8e95 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc @@ -76,7 +76,7 @@ Conv3D& Conv3D::operator=(Conv3D&& operation) { return *this; } -Status Conv3D::Compile(const CreationContext& creation_context) { +absl::Status Conv3D::Compile(const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; const std::string code = @@ -92,7 +92,7 @@ Status Conv3D::Compile(const CreationContext& 
creation_context) { *creation_context.device, &kernel_); } -Status Conv3D::BindArguments() { +absl::Status Conv3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (conv_params_.AreWeightsBuffer()) { @@ -131,7 +131,7 @@ Status Conv3D::BindArguments() { IntegralDivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w))); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS())); - return OkStatus(); + return absl::OkStatus(); } int3 Conv3D::GetGridSize() const { @@ -154,12 +154,12 @@ int3 Conv3D::GetGridSize() const { conv_params_.work_group_size.z); } -Status Conv3D::Tune(const TuningParameters& params) { +absl::Status Conv3D::Tune(const TuningParameters& params) { if (conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP || conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_BY_THREADS) { - return OkStatus(); + return absl::OkStatus(); } if (conv_params_.work_group_launch_order[0] == 0 && conv_params_.work_group_launch_order[1] == 1 && @@ -168,10 +168,10 @@ Status Conv3D::Tune(const TuningParameters& params) { return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &conv_params_.work_group_size); } - return OkStatus(); + return absl::OkStatus(); } -Status Conv3D::AddToQueue(CLCommandQueue* queue) { +absl::Status Conv3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), conv_params_.work_group_size); @@ -903,9 +903,9 @@ Conv3D::ConvParams Conv3D::GuessBestParams( x_kernel_is_1, y_kernel_is_1, z_kernel_is_1); } -Status CreateConv3D(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution3DAttributes& attr, Conv3D* result) { +absl::Status CreateConv3D(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution3DAttributes& attr, Conv3D* result) { *result = Conv3D(definition, attr, *creation_context.device); return result->UploadData(attr.weights, attr.bias, creation_context.context); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h index 8fc48c4114a..00b1e868e5d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h @@ -39,9 +39,9 @@ namespace cl { class Conv3D : public GPUOperation { public: Conv3D() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Conv3D(Conv3D&& operation); @@ -75,21 +75,21 @@ class Conv3D : public GPUOperation { const CLDevice& device); template - Status UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + absl::Status UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - friend Status 
CreateConv3D(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution3DAttributes& attr, - Conv3D* result); + friend absl::Status CreateConv3D(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution3DAttributes& attr, + Conv3D* result); friend std::string GenerateConv3D( const OperationDef& op_def, const LinearStorage& biases, @@ -105,7 +105,7 @@ class Conv3D : public GPUOperation { int dst_slices, bool x_kernel_is_1, bool y_kernel_is_1, bool z_kernel_is_1) const; - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Texture2D weights_0_; @@ -125,9 +125,9 @@ class Conv3D : public GPUOperation { }; template -Status Conv3D::UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context) { +absl::Status Conv3D::UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context) { RETURN_IF_ERROR(UploadWeights(weights, context)); LinearStorageCreateInfo create_info; create_info.storage_type = conv_params_.AreWeightsBuffer() @@ -139,12 +139,12 @@ Status Conv3D::UploadData(const ::tflite::gpu::Tensor& weights, create_info.name = "biases"; create_info.aligned_size = weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_)); - return OkStatus(); + return absl::OkStatus(); } template -Status Conv3D::UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context) { +absl::Status Conv3D::UploadWeights( + const ::tflite::gpu::Tensor& weights, CLContext* context) { const int block_size = conv_params_.block_size.w; const int dst_slices = AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size); @@ -211,7 +211,7 @@ Status Conv3D::UploadWeights(const ::tflite::gpu::Tensor& weights, } } - return OkStatus(); + return absl::OkStatus(); } template @@ -271,9 +271,9 @@ void Conv3D::RearrangeWeightsData( } } -Status CreateConv3D(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution3DAttributes& attr, Conv3D* result); +absl::Status CreateConv3D(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution3DAttributes& attr, Conv3D* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc index 3a8c726021c..70bd1b5249f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc @@ -291,16 +291,16 @@ ConvBuffer1x1& ConvBuffer1x1::operator=(ConvBuffer1x1&& operation) { return *this; } -Status ConvBuffer1x1::Compile(const CreationContext& creation_context) { +absl::Status ConvBuffer1x1::Compile(const CreationContext& creation_context) { std::string code = GenerateConvBuffer1x1(definition_, conv_params_, linked_operations_); RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_)); - return OkStatus(); + return absl::OkStatus(); } -Status ConvBuffer1x1::BindArguments() { +absl::Status ConvBuffer1x1::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -313,7 +313,7 @@ Status ConvBuffer1x1::BindArguments() { src_width_elements * src_[0]->Height()); 
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_size)); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 ConvBuffer1x1::GetGridSize() const { @@ -328,13 +328,13 @@ int3 ConvBuffer1x1::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConvBuffer1x1::Tune(const TuningParameters& params) { +absl::Status ConvBuffer1x1::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &conv_params_.work_group_size); } -Status ConvBuffer1x1::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvBuffer1x1::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), conv_params_.work_group_size); @@ -351,12 +351,12 @@ bool IsConvBuffer1x1Supported(const OperationDef& definition, attr.padding.appended.w == 0 && attr.padding.appended.h == 0; } -Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvBuffer1x1* result, const BHWC* shape) { +absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvBuffer1x1* result, const BHWC* shape) { if (!IsConvBuffer1x1Supported(definition, attr)) { - return InvalidArgumentError("ConvBuffer1x1 doesn't supported"); + return absl::InvalidArgumentError("ConvBuffer1x1 isn't supported"); } const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); @@ -372,10 +372,10 @@ Status CreateConvBuffer1x1(const CreationContext& creation_context, return result->UploadData(attr.weights, attr.bias, creation_context.context); } -Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvBuffer1x1* result, const BHWC* shape) { +absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvBuffer1x1* result, const BHWC* shape) { const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); ConvBuffer1x1::ConvParams conv_params; @@ -392,11 +392,10 @@ Status CreateConvBuffer1x1(const CreationContext& creation_context, return result->UploadData(attr.weights, attr.bias, creation_context.context); } -Status CreateConvBuffer1x1Wino4x4To6x6(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvBuffer1x1* result, - const BHWC* shape) { +absl::Status CreateConvBuffer1x1Wino4x4To6x6( + const CreationContext& creation_context, const OperationDef& definition, + const Convolution2DAttributes& attr, ConvBuffer1x1* result, + const BHWC* shape) { const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); ConvBuffer1x1::ConvParams conv_params; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h index 54e99d29ec7..07da846107e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h @@ -45,10 +45,10 @@ class ConvBuffer1x1 : public GPUOperation { ConvBuffer1x1(const
ConvBuffer1x1&) = delete; ConvBuffer1x1& operator=(const ConvBuffer1x1&) = delete; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; struct ConvParams { int3 block_size = int3(1, 1, 1); @@ -64,33 +64,33 @@ class ConvBuffer1x1 : public GPUOperation { private: ConvBuffer1x1(const OperationDef& definition, const ConvParams& conv_params); - friend Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvBuffer1x1* result, const BHWC* shape); - friend Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvBuffer1x1* result, const BHWC* shape); - friend Status CreateConvBuffer1x1Wino4x4To6x6( + friend absl::Status CreateConvBuffer1x1( + const CreationContext& creation_context, const OperationDef& definition, + const Convolution2DAttributes& attr, ConvBuffer1x1* result, + const BHWC* shape); + friend absl::Status CreateConvBuffer1x1( + const CreationContext& creation_context, const OperationDef& definition, + const FullyConnectedAttributes& attr, ConvBuffer1x1* result, + const BHWC* shape); + friend absl::Status CreateConvBuffer1x1Wino4x4To6x6( const CreationContext& creation_context, const OperationDef& definition, const Convolution2DAttributes& attr, ConvBuffer1x1* result, const BHWC* shape); template - Status UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + absl::Status UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); template - Status UploadDataForWinograd4x4To6x6( + absl::Status UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -101,20 +101,20 @@ class ConvBuffer1x1 : public GPUOperation { }; template -Status ConvBuffer1x1::UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context) { +absl::Status ConvBuffer1x1::UploadData( + const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, CLContext* context) { RETURN_IF_ERROR(UploadWeights(weights, context)); LinearStorageCreateInfo create_info; create_info.storage_type = LinearStorageType::BUFFER; create_info.data_type = definition_.GetDataType(); create_info.aligned_size = weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_)); - return OkStatus(); + return absl::OkStatus(); } template -Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6( +absl::Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context) { ::tflite::gpu::Tensor wino_weights; @@ -132,7 +132,7 @@ Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6( } template -Status ConvBuffer1x1::UploadWeights( +absl::Status 
ConvBuffer1x1::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -162,21 +162,22 @@ Status ConvBuffer1x1::UploadWeights( bool IsConvBuffer1x1Supported(const OperationDef& definition, const Convolution2DAttributes& attr); -Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvBuffer1x1* result, const BHWC* shape = nullptr); +absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvBuffer1x1* result, + const BHWC* shape = nullptr); -Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvBuffer1x1* result, const BHWC* shape = nullptr); +absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvBuffer1x1* result, + const BHWC* shape = nullptr); -Status CreateConvBuffer1x1Wino4x4To6x6(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvBuffer1x1* result, - const BHWC* shape = nullptr); +absl::Status CreateConvBuffer1x1Wino4x4To6x6( + const CreationContext& creation_context, const OperationDef& definition, + const Convolution2DAttributes& attr, ConvBuffer1x1* result, + const BHWC* shape = nullptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc index ceb3b8985e8..07d2da9d641 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc @@ -219,7 +219,7 @@ ConvConstants& ConvConstants::operator=(ConvConstants&& kernel) { return *this; } -Status ConvConstants::Compile(const CreationContext& creation_context) { +absl::Status ConvConstants::Compile(const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; const auto code = GenerateConvolutionConstantCode( @@ -240,7 +240,7 @@ Status ConvConstants::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status ConvConstants::BindArguments() { +absl::Status ConvConstants::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -254,7 +254,7 @@ Status ConvConstants::BindArguments() { kernel_.SetBytesAuto(int2(dilation_.x * src_[0]->Batch(), dilation_.y))); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 ConvConstants::GetGridSize() const { @@ -263,12 +263,12 @@ int3 ConvConstants::GetGridSize() const { return int3(grid_x, grid_y, 1); } -Status ConvConstants::Tune(const TuningParameters& params) { +absl::Status ConvConstants::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status ConvConstants::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvConstants::AddToQueue(CLCommandQueue* queue) { 
RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -294,12 +294,12 @@ bool IsConvConstantsSupported(const CLDevice& device, return filters_buffer_size <= kConstantMaxSize && flt4_registers <= 8; } -Status CreateConvConstants(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvConstants* result) { +absl::Status CreateConvConstants(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvConstants* result) { if (!IsConvConstantsSupported(*creation_context.device, definition, attr)) { - return InvalidArgumentError("ConvConstants doesn't supported"); + return absl::InvalidArgumentError("ConvConstants isn't supported"); } *result = ConvConstants(definition, attr); RETURN_IF_ERROR( @@ -310,8 +310,7 @@ Status CreateConvConstants(const CreationContext& creation_context, create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h index b4830d20fd1..fc0e66b5e86 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h @@ -35,10 +35,10 @@ namespace cl { class ConvConstants : public GPUOperation { public: ConvConstants() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvConstants(ConvConstants&& kernel); @@ -47,10 +47,9 @@ class ConvConstants : public GPUOperation { ConvConstants& operator=(const ConvConstants&) = delete; private: - friend Status CreateConvConstants(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvConstants* result); + friend absl::Status CreateConvConstants( + const CreationContext& creation_context, const OperationDef& definition, + const Convolution2DAttributes& attr, ConvConstants* result); explicit ConvConstants(const OperationDef& definition, const Convolution2DAttributes& attr) : GPUOperation(definition), @@ -62,14 +61,14 @@ class ConvConstants : public GPUOperation { dst_channels_(attr.weights.shape.o) {} template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -87,7 +86,7 @@ class ConvConstants : public GPUOperation { }; template -Status ConvConstants::UploadWeights( +absl::Status ConvConstants::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); const int kernel_x = weights.shape.w; @@ -157,10 +156,10 @@ bool IsConvConstantsSupported(const CLDevice& device, const OperationDef& definition, const
Convolution2DAttributes& attr); -Status CreateConvConstants(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvConstants* result); +absl::Status CreateConvConstants(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvConstants* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc index c1860d6452f..bd4f53395f3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc @@ -173,7 +173,7 @@ ConvPowerVR& ConvPowerVR::operator=(ConvPowerVR&& operation) { return *this; } -Status ConvPowerVR::Compile(const CreationContext& creation_context) { +absl::Status ConvPowerVR::Compile(const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_padding_.x != 1; const std::string code = @@ -189,7 +189,7 @@ Status ConvPowerVR::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status ConvPowerVR::BindArguments() { +absl::Status ConvPowerVR::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -211,7 +211,7 @@ Status ConvPowerVR::BindArguments() { } RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 ConvPowerVR::GetGridSize() const { @@ -245,13 +245,13 @@ int3 ConvPowerVR::GetGridSize() const { } } -Status ConvPowerVR::Tune(const TuningParameters& params) { +absl::Status ConvPowerVR::Tune(const TuningParameters& params) { if (conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP || conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_BY_THREADS || conv_params_.fixed_work_group_size) { - return OkStatus(); + return absl::OkStatus(); } if (conv_params_.work_group_launch_order[0] == 0 && conv_params_.work_group_launch_order[1] == 1 && @@ -260,10 +260,10 @@ Status ConvPowerVR::Tune(const TuningParameters& params) { return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &conv_params_.work_group_size); } - return OkStatus(); + return absl::OkStatus(); } -Status ConvPowerVR::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvPowerVR::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), conv_params_.work_group_size); @@ -848,27 +848,26 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd( return params; } -Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvPowerVR* result, const BHWC* dst_shape) { +absl::Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvPowerVR* result, const BHWC* dst_shape) { *result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape); return result->UploadData(attr.weights, attr.bias, creation_context.context); } -Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvPowerVR* result, const BHWC* 
dst_shape) { +absl::Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvPowerVR* result, const BHWC* dst_shape) { *result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape); return result->UploadData(attr.weights, attr.bias, creation_context.context); } -Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvPowerVR* result, - const BHWC* dst_shape) { +absl::Status CreateConvPowerVRWino4x4To6x6( + const CreationContext& creation_context, const OperationDef& definition, + const Convolution2DAttributes& attr, ConvPowerVR* result, + const BHWC* dst_shape) { *result = ConvPowerVR(definition); result->conv_params_ = result->GuessBestParamsWinograd( *creation_context.device, definition, attr, dst_shape); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h index 44145c585da..954205f1ca3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h @@ -39,9 +39,9 @@ namespace cl { class ConvPowerVR : public GPUOperation { public: ConvPowerVR() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvPowerVR(ConvPowerVR&& operation); @@ -87,29 +87,31 @@ class ConvPowerVR : public GPUOperation { explicit ConvPowerVR(const OperationDef& definition); template - Status UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + absl::Status UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); template - Status UploadDataForWinograd4x4To6x6( + absl::Status UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); - friend Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvPowerVR* result, const BHWC* dst_shape); + friend absl::Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvPowerVR* result, + const BHWC* dst_shape); - friend Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvPowerVR* result, const BHWC* dst_shape); + friend absl::Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvPowerVR* result, + const BHWC* dst_shape); - friend Status CreateConvPowerVRWino4x4To6x6( + friend absl::Status CreateConvPowerVRWino4x4To6x6( const CreationContext& creation_context, const OperationDef& definition, const Convolution2DAttributes& attr, ConvPowerVR* result, const BHWC* dst_shape); @@ -138,7 +140,7 @@ class 
ConvPowerVR : public GPUOperation { bool different_weights_for_height, const BHWC* dst_shape = nullptr) const; - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -152,20 +154,20 @@ class ConvPowerVR : public GPUOperation { }; template -Status ConvPowerVR::UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context) { +absl::Status ConvPowerVR::UploadData( + const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, CLContext* context) { RETURN_IF_ERROR(UploadWeights(weights, context)); LinearStorageCreateInfo create_info; create_info.storage_type = LinearStorageType::BUFFER; create_info.data_type = conv_params_.weights_data_type; create_info.aligned_size = weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_)); - return OkStatus(); + return absl::OkStatus(); } template -Status ConvPowerVR::UploadDataForWinograd4x4To6x6( +absl::Status ConvPowerVR::UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context) { ::tflite::gpu::Tensor wino_weights; @@ -179,12 +181,12 @@ Status ConvPowerVR::UploadDataForWinograd4x4To6x6( bias.shape = Linear(weights.shape.o); bias.data.resize(weights.shape.o, 0.0f); RETURN_IF_ERROR(CreateLinearStorage(create_info, bias, context, &biases_)); - return OkStatus(); + return absl::OkStatus(); } template -Status ConvPowerVR::UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context) { +absl::Status ConvPowerVR::UploadWeights( + const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -210,21 +212,22 @@ Status ConvPowerVR::UploadWeights(const ::tflite::gpu::Tensor& weights, } } -Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvPowerVR* result, const BHWC* dst_shape = nullptr); +absl::Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvPowerVR* result, + const BHWC* dst_shape = nullptr); -Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvPowerVR* result, const BHWC* dst_shape = nullptr); +absl::Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvPowerVR* result, + const BHWC* dst_shape = nullptr); -Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvPowerVR* result, - const BHWC* dst_shape = nullptr); +absl::Status CreateConvPowerVRWino4x4To6x6( + const CreationContext& creation_context, const OperationDef& definition, + const Convolution2DAttributes& attr, ConvPowerVR* result, + const BHWC* dst_shape = nullptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc index 780d6646ea8..953f564c40a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc @@ -30,6 +30,7 @@ namespace tflite { namespace gpu { namespace cl { namespace { + std::string 
GenerateConvCode( const OperationDef& op_def, const int3& block_size, bool is1x1, bool adreno4xx_optimization, bool stride_correction, @@ -384,7 +385,7 @@ ConvTexture& ConvTexture::operator=(ConvTexture&& operation) { return *this; } -Status ConvTexture::Compile(const CreationContext& creation_context) { +absl::Status ConvTexture::Compile(const CreationContext& creation_context) { auto storage_type = definition_.GetPrimaryStorageType(); bool is1x1 = kernel_size_.x == 1 && kernel_size_.y == 1; bool adreno4xx_optimization = @@ -407,7 +408,7 @@ Status ConvTexture::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status ConvTexture::BindArguments() { +absl::Status ConvTexture::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_0_.GetMemoryPtr())); @@ -427,7 +428,7 @@ Status ConvTexture::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_)); RETURN_IF_ERROR( kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y))); - return OkStatus(); + return absl::OkStatus(); } int3 ConvTexture::GetGridSize() const { @@ -438,37 +439,36 @@ int3 ConvTexture::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConvTexture::Tune(const TuningParameters& params) { +absl::Status ConvTexture::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &work_group_size_); } -Status ConvTexture::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvTexture::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvTexture* result) { +absl::Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvTexture* result) { *result = ConvTexture(definition, attr); return result->UploadData(attr.weights, attr.bias, creation_context.context); } -Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvTexture* result) { +absl::Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvTexture* result) { *result = ConvTexture(definition); return result->UploadData(attr.weights, attr.bias, creation_context.context); } -Status CreateConvTextureWino4x4To6x6(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvTexture* result) { +absl::Status CreateConvTextureWino4x4To6x6( + const CreationContext& creation_context, const OperationDef& definition, + const Convolution2DAttributes& attr, ConvTexture* result) { *result = ConvTexture(definition); result->different_weights_for_height_ = true; result->block_size_ = {4, 1, 2}; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h index fb25f655057..b7fbac91cf2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h @@ -41,10 +41,10 @@ namespace cl { class ConvTexture : public GPUOperation { public: ConvTexture() = default; - Status 
AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvTexture(ConvTexture&& operation); @@ -53,16 +53,16 @@ class ConvTexture : public GPUOperation { ConvTexture& operator=(const ConvTexture&) = delete; private: - friend Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvTexture* result); - friend Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvTexture* result); + friend absl::Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvTexture* result); + friend absl::Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvTexture* result); - friend Status CreateConvTextureWino4x4To6x6( + friend absl::Status CreateConvTextureWino4x4To6x6( const CreationContext& creation_context, const OperationDef& definition, const Convolution2DAttributes& attr, ConvTexture* result); @@ -70,25 +70,25 @@ class ConvTexture : public GPUOperation { const Convolution2DAttributes& attr); explicit ConvTexture(const OperationDef& definition); template - Status UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + absl::Status UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); template - Status UploadDataForWinograd4x4To6x6( + absl::Status UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst_0, absl::Span dst_1, absl::Span dst_2, absl::Span dst_3); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Texture2D weights_0_; @@ -114,20 +114,20 @@ class ConvTexture : public GPUOperation { }; template -Status ConvTexture::UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context) { +absl::Status ConvTexture::UploadData( + const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, CLContext* context) { RETURN_IF_ERROR(UploadWeights(weights, context)); LinearStorageCreateInfo create_info; create_info.storage_type = LinearStorageType::TEXTURE_2D; create_info.data_type = definition_.GetDataType(); create_info.aligned_size = weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_)); - return OkStatus(); + return absl::OkStatus(); } template -Status ConvTexture::UploadDataForWinograd4x4To6x6( +absl::Status ConvTexture::UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context) { ::tflite::gpu::Tensor wino_weights; @@ -145,8 +145,8 @@ Status ConvTexture::UploadDataForWinograd4x4To6x6( } template -Status 
ConvTexture::UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context) { +absl::Status ConvTexture::UploadWeights( + const ::tflite::gpu::Tensor& weights, CLContext* context) { int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); dst_depth = AlignByN(dst_depth, block_size_.z); const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -246,20 +246,19 @@ void ConvTexture::RearrangeWeightsData( } } -Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvTexture* result); +absl::Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvTexture* result); -Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvTexture* result); +absl::Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvTexture* result); -Status CreateConvTextureWino4x4To6x6(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvTexture* result); +absl::Status CreateConvTextureWino4x4To6x6( + const CreationContext& creation_context, const OperationDef& definition, + const Convolution2DAttributes& attr, ConvTexture* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc index 947c39cd299..e3170f068e9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc @@ -35,12 +35,12 @@ namespace { class OpenClConverterImpl : public TensorObjectConverter { public: - virtual Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) = 0; + virtual absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) = 0; protected: - Status DispatchKernel(cl_mem input, cl_mem output) { + absl::Status DispatchKernel(cl_mem input, cl_mem output) { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(input)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(output)); @@ -119,9 +119,9 @@ class FromTensorConverter : public OpenClConverterImpl { })"); } - Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { auto params_kernel = output_def.object_def.data_layout == DataLayout::BHWC ? 
GetToBhwcKernel(input_def, output_def) : GetToDhwc4Kernel(input_def, output_def); @@ -157,11 +157,12 @@ __kernel void from_tensor()" + environment->device(), &kernel_); } - Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto output = absl::get_if(&output_obj); if (!output || !output->memobj) { - return InvalidArgumentError("Missing output in from_tensor converter"); + return absl::InvalidArgumentError( + "Missing output in from_tensor converter"); } auto input_texture = absl::get_if(&input_obj); if (input_texture && input_texture->memobj) { @@ -171,7 +172,7 @@ __kernel void from_tensor()" + if (input_buffer && input_buffer->memobj) { return DispatchKernel(input_buffer->memobj, output->memobj); } - return InvalidArgumentError("Missing input in from_tensor converter"); + return absl::InvalidArgumentError("Missing input in from_tensor converter"); } }; @@ -225,9 +226,9 @@ class ToTensorConverter : public OpenClConverterImpl { )"); } - Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { auto params_kernel = input_def.object_def.data_layout == DataLayout::BHWC ? GetFromBhwcKernel(input_def, output_def) : GetFromDhwc4Kernel(input_def, output_def); @@ -261,11 +262,11 @@ __kernel void to_tensor()" + &kernel_); } - Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto input = absl::get_if(&input_obj); if (!input || !input->memobj) { - return InvalidArgumentError("Missing input in to_tensor converter"); + return absl::InvalidArgumentError("Missing input in to_tensor converter"); } auto output_texture = absl::get_if(&output_obj); if (output_texture && output_texture->memobj) { @@ -275,7 +276,7 @@ __kernel void to_tensor()" + if (output_buffer && output_buffer->memobj) { return DispatchKernel(input->memobj, output_buffer->memobj); } - return InvalidArgumentError("Missing input in to_tensor converter"); + return absl::InvalidArgumentError("Missing input in to_tensor converter"); } }; @@ -318,18 +319,18 @@ class TrivialCopier : public OpenClConverterImpl { input.data_layout == output.data_layout; } - Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { dims_ = input_def.dimensions; data_type_ = input_def.object_def.data_type; queue_ = environment->queue(); region_ = CalculateTextureRegion(output_def); - return OkStatus(); + return absl::OkStatus(); } - Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto texture_input = absl::get_if(&input_obj); auto texture_output = absl::get_if(&output_obj); if (texture_input && texture_output) { @@ -340,12 +341,12 @@ class TrivialCopier : public OpenClConverterImpl { if (buffer_input && buffer_output) { return Copy(*buffer_input, *buffer_output); } - return InternalError("Unexpected object"); + return absl::InternalError("Unexpected object"); } - Status Copy(const 
OpenClBuffer& input, const OpenClBuffer& output) { + absl::Status Copy(const OpenClBuffer& input, const OpenClBuffer& output) { if (input.memobj == output.memobj) { - return OkStatus(); + return absl::OkStatus(); } return GetOpenCLError(clEnqueueCopyBuffer( queue_->queue(), input.memobj, output.memobj, 0, 0, @@ -353,9 +354,9 @@ class TrivialCopier : public OpenClConverterImpl { nullptr)); } - Status Copy(const OpenClTexture& input, const OpenClTexture& output) { + absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) { if (input.memobj == output.memobj) { - return OkStatus(); + return absl::OkStatus(); } size_t origin[3] = {0, 0, 0}; return GetOpenCLError( @@ -380,18 +381,18 @@ class CpuCopier : public OpenClConverterImpl { IsOpenClTextureOrBuffer(input.object_type))); } - Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { region_ = CalculateTextureRegion( input_def.object_def.object_type == ObjectType::CPU_MEMORY ? output_def : input_def); queue_ = environment->queue(); - return OkStatus(); + return absl::OkStatus(); } - Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto cpu_input = absl::get_if(&input_obj); auto cpu_output = absl::get_if(&output_obj); if (cpu_input) { @@ -419,7 +420,7 @@ class CpuCopier : public OpenClConverterImpl { buffer_input->memobj, cpu_output->size_bytes, cpu_output->data); } } - return InternalError("Unexpected object"); + return absl::InternalError("Unexpected object"); } private: @@ -442,7 +443,7 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder { ToTensorConverter::IsSupported(input_def, output_def)); } - Status MakeConverter( + absl::Status MakeConverter( const TensorObjectDef& input, const TensorObjectDef& output, std::unique_ptr* converter) final { std::unique_ptr impl; @@ -457,11 +458,11 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder { } else if (ToTensorConverter::IsSupported(input_def, output_def)) { impl = absl::make_unique(); } else { - return UnimplementedError("Unsupported conversion"); + return absl::UnimplementedError("Unsupported conversion"); } RETURN_IF_ERROR(impl->Init(input, output, environment_)); *converter = std::move(impl); - return OkStatus(); + return absl::OkStatus(); } Environment* environment_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc index 921a257aa7e..417fb63e820 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc @@ -368,7 +368,8 @@ ConvolutionTransposed& ConvolutionTransposed::operator=( return *this; } -Status ConvolutionTransposed::Compile(const CreationContext& creation_context) { +absl::Status ConvolutionTransposed::Compile( + const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, biases_, *creation_context.device, weights_are_buffer_, block_size_, linked_operations_); @@ -380,7 +381,7 @@ Status ConvolutionTransposed::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status ConvolutionTransposed::BindArguments() { 
+absl::Status ConvolutionTransposed::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (weights_are_buffer_) { @@ -399,7 +400,7 @@ Status ConvolutionTransposed::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 ConvolutionTransposed::GetGridSize() const { @@ -412,21 +413,21 @@ int3 ConvolutionTransposed::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConvolutionTransposed::Tune(const TuningParameters& params) { +absl::Status ConvolutionTransposed::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &work_group_size_); } -Status ConvolutionTransposed::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvolutionTransposed::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status CreateConvolutionTransposed(const CreationContext& creation_context, - const OperationDef& definition, - const ConvolutionTransposedAttributes& attr, - ConvolutionTransposed* result) { +absl::Status CreateConvolutionTransposed( + const CreationContext& creation_context, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr, + ConvolutionTransposed* result) { *result = ConvolutionTransposed(definition, attr, *creation_context.device); RETURN_IF_ERROR( result->UploadWeights(attr.weights, creation_context.context)); @@ -438,8 +439,7 @@ Status CreateConvolutionTransposed(const CreationContext& creation_context, create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h index 73fce020f5a..7545b9091e2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h @@ -38,10 +38,10 @@ namespace cl { class ConvolutionTransposed : public GPUOperation { public: ConvolutionTransposed() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed(ConvolutionTransposed&& operation); @@ -50,7 +50,7 @@ class ConvolutionTransposed : public GPUOperation { ConvolutionTransposed& operator=(const ConvolutionTransposed&) = delete; private: - friend Status CreateConvolutionTransposed( + friend absl::Status CreateConvolutionTransposed( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed* result); @@ -58,14 +58,14 @@ class ConvolutionTransposed : public GPUOperation { const ConvolutionTransposedAttributes& attr, const CLDevice& device); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + 
absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; LinearStorage biases_; @@ -88,7 +88,7 @@ class ConvolutionTransposed : public GPUOperation { }; template -Status ConvolutionTransposed::UploadWeights( +absl::Status ConvolutionTransposed::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size_.z); @@ -153,7 +153,7 @@ Status ConvolutionTransposed::UploadWeights( } } - return OkStatus(); + return absl::OkStatus(); } template @@ -208,10 +208,9 @@ void ConvolutionTransposed::RearrangeWeightsData( } } -Status CreateConvolutionTransposed(const CreationContext& creation_context, - const OperationDef& definition, - const ConvolutionTransposedAttributes& attr, - ConvolutionTransposed* result); +absl::Status CreateConvolutionTransposed( + const CreationContext& creation_context, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr, ConvolutionTransposed* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc index 147674b7eff..9d3f0b2639c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc @@ -396,7 +396,7 @@ ConvolutionTransposed3D& ConvolutionTransposed3D::operator=( return *this; } -Status ConvolutionTransposed3D::Compile( +absl::Status ConvolutionTransposed3D::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposed3DCode( definition_, biases_, *creation_context.device, weights_are_buffer_, @@ -417,7 +417,7 @@ Status ConvolutionTransposed3D::Compile( *creation_context.device, &kernel_); } -Status ConvolutionTransposed3D::BindArguments() { +absl::Status ConvolutionTransposed3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (weights_are_buffer_) { @@ -444,7 +444,7 @@ Status ConvolutionTransposed3D::BindArguments() { IntegralDivideRoundUp(dst_[0]->Slices(), block_size_.w))); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHDS())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHDS())); - return OkStatus(); + return absl::OkStatus(); } int3 ConvolutionTransposed3D::GetGridSize() const { @@ -459,18 +459,18 @@ int3 ConvolutionTransposed3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConvolutionTransposed3D::Tune(const TuningParameters& params) { +absl::Status ConvolutionTransposed3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &work_group_size_); } -Status ConvolutionTransposed3D::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvolutionTransposed3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status CreateConvolutionTransposed3D( +absl::Status CreateConvolutionTransposed3D( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposed3DAttributes& attr, ConvolutionTransposed3D* result) { @@ -485,8 +485,7 @@ Status CreateConvolutionTransposed3D( 
create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h index c3fbd87a240..763494efce6 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h @@ -38,10 +38,10 @@ namespace cl { class ConvolutionTransposed3D : public GPUOperation { public: ConvolutionTransposed3D() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed3D(ConvolutionTransposed3D&& operation); @@ -50,7 +50,7 @@ class ConvolutionTransposed3D : public GPUOperation { ConvolutionTransposed3D& operator=(const ConvolutionTransposed3D&) = delete; private: - friend Status CreateConvolutionTransposed3D( + friend absl::Status CreateConvolutionTransposed3D( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposed3DAttributes& attr, ConvolutionTransposed3D* result); @@ -58,14 +58,14 @@ class ConvolutionTransposed3D : public GPUOperation { const ConvolutionTransposed3DAttributes& attr, const CLDevice& device); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; LinearStorage biases_; @@ -88,7 +88,7 @@ class ConvolutionTransposed3D : public GPUOperation { }; template -Status ConvolutionTransposed3D::UploadWeights( +absl::Status ConvolutionTransposed3D::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size_.z); @@ -155,7 +155,7 @@ Status ConvolutionTransposed3D::UploadWeights( } } - return OkStatus(); + return absl::OkStatus(); } template @@ -214,7 +214,7 @@ void ConvolutionTransposed3D::RearrangeWeightsData( } } -Status CreateConvolutionTransposed3D( +absl::Status CreateConvolutionTransposed3D( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposed3DAttributes& attr, ConvolutionTransposed3D* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc index 7b19ac0ba38..4be593be57b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc @@ -304,12 +304,11 @@ ConvolutionTransposed3x3& ConvolutionTransposed3x3::operator=( return *this; } -Status ConvolutionTransposed3x3::Compile( +absl::Status ConvolutionTransposed3x3::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, biases_, linked_operations_, 
weights_upload_type_, padding_, work_group_launch_order_); - std::vector options; if (definition_.precision == CalculationsPrecision::F16 && creation_context.device->IsPowerVR()) { @@ -318,11 +317,10 @@ Status ConvolutionTransposed3x3::Compile( RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); - - return OkStatus(); + return absl::OkStatus(); } -Status ConvolutionTransposed3x3::BindArguments() { +absl::Status ConvolutionTransposed3x3::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -337,10 +335,7 @@ Status ConvolutionTransposed3x3::BindArguments() { padding_.x >= 1 ? (padding_.x - 1) / 2 : (padding_.x - 2) / 2; const int padding_y = padding_.y >= 1 ? (padding_.y - 1) / 2 : (padding_.y - 2) / 2; - RETURN_IF_ERROR( - kernel_.SetBytesAuto(int2(padding_x * src_[0]->Batch(), padding_y))); - - return OkStatus(); + return kernel_.SetBytesAuto(int2(padding_x * src_[0]->Batch(), padding_y)); } int3 ConvolutionTransposed3x3::GetGridSize() const { @@ -358,7 +353,7 @@ int3 ConvolutionTransposed3x3::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConvolutionTransposed3x3::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvolutionTransposed3x3::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -370,13 +365,13 @@ bool IsConvolutionTransposed3x3Supported( attr.stride.w == 2 && attr.stride.h == 2; } -Status CreateConvolutionTransposed3x3( +absl::Status CreateConvolutionTransposed3x3( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3* result) { if (!IsConvolutionTransposed3x3Supported(*creation_context.device, definition, attr)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "ConvolutionTransposed3x3 doesn't support this attributes"); } const int2 padding = int2(attr.padding.prepended.w, attr.padding.prepended.h); @@ -391,7 +386,7 @@ Status CreateConvolutionTransposed3x3( create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h index 9e12d884719..5da112e19c0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h @@ -37,8 +37,8 @@ namespace cl { class ConvolutionTransposed3x3 : public GPUOperation { public: ConvolutionTransposed3x3() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed3x3(ConvolutionTransposed3x3&& operation); @@ -56,19 +56,19 @@ class ConvolutionTransposed3x3 : public GPUOperation { private: ConvolutionTransposed3x3(const OperationDef& definition, const CLDevice& device, int2 padding); - friend Status CreateConvolutionTransposed3x3( + friend absl::Status 
CreateConvolutionTransposed3x3( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3* result); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; int2 padding_; @@ -82,7 +82,7 @@ class ConvolutionTransposed3x3 : public GPUOperation { }; template -Status ConvolutionTransposed3x3::UploadWeights( +absl::Status ConvolutionTransposed3x3::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); @@ -165,7 +165,7 @@ bool IsConvolutionTransposed3x3Supported( const CLDevice& device, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); -Status CreateConvolutionTransposed3x3( +absl::Status CreateConvolutionTransposed3x3( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc index 40838d28eed..b8e4b25443e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc @@ -221,19 +221,18 @@ ConvolutionTransposed3x3Thin& ConvolutionTransposed3x3Thin::operator=( return *this; } -Status ConvolutionTransposed3x3Thin::Compile( +absl::Status ConvolutionTransposed3x3Thin::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, biases_, IntegralDivideRoundUp(src_channels_, 4), IntegralDivideRoundUp(dst_channels_, 4), *creation_context.device, linked_operations_); - return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -Status ConvolutionTransposed3x3Thin::BindArguments() { +absl::Status ConvolutionTransposed3x3Thin::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -242,7 +241,7 @@ Status ConvolutionTransposed3x3Thin::BindArguments() { RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 ConvolutionTransposed3x3Thin::GetGridSize() const { @@ -252,12 +251,13 @@ int3 ConvolutionTransposed3x3Thin::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConvolutionTransposed3x3Thin::Tune(const TuningParameters& params) { +absl::Status ConvolutionTransposed3x3Thin::Tune( + const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status ConvolutionTransposed3x3Thin::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvolutionTransposed3x3Thin::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return 
queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -271,13 +271,13 @@ bool IsConvolutionTransposed3x3ThinSupported( attr.padding.appended.h == 1; } -Status CreateConvolutionTransposed3x3Thin( +absl::Status CreateConvolutionTransposed3x3Thin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3Thin* result) { if (!IsConvolutionTransposed3x3ThinSupported(*creation_context.device, attr)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "ConvolutionTransposed3x3Thin doesn't support this attributes"); } *result = ConvolutionTransposed3x3Thin(definition, attr); @@ -291,8 +291,7 @@ Status CreateConvolutionTransposed3x3Thin( create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h index f8d10d6c6b8..f2a0d586bd1 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h @@ -37,10 +37,10 @@ namespace cl { class ConvolutionTransposed3x3Thin : public GPUOperation { public: ConvolutionTransposed3x3Thin() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed3x3Thin(ConvolutionTransposed3x3Thin&& operation); @@ -51,7 +51,7 @@ class ConvolutionTransposed3x3Thin : public GPUOperation { delete; private: - friend Status CreateConvolutionTransposed3x3Thin( + friend absl::Status CreateConvolutionTransposed3x3Thin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3Thin* result); @@ -59,14 +59,14 @@ class ConvolutionTransposed3x3Thin : public GPUOperation { const OperationDef& definition, const ConvolutionTransposedAttributes& attr); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -80,7 +80,7 @@ class ConvolutionTransposed3x3Thin : public GPUOperation { }; template -Status ConvolutionTransposed3x3Thin::UploadWeights( +absl::Status ConvolutionTransposed3x3Thin::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(src_channels_, 4); const int dst_depth = IntegralDivideRoundUp(dst_channels_, 4); @@ -150,7 +150,7 @@ void ConvolutionTransposed3x3Thin::RearrangeWeightsData( bool IsConvolutionTransposed3x3ThinSupported( const CLDevice& device, const ConvolutionTransposedAttributes& attr); -Status CreateConvolutionTransposed3x3Thin( +absl::Status CreateConvolutionTransposed3x3Thin( const CreationContext& 
creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3Thin* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc index 1e36be17778..a558fe6cb3c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc @@ -301,7 +301,7 @@ ConvolutionTransposed4x4& ConvolutionTransposed4x4::operator=( return *this; } -Status ConvolutionTransposed4x4::Compile( +absl::Status ConvolutionTransposed4x4::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, biases_, linked_operations_, weights_upload_type_); @@ -314,11 +314,10 @@ Status ConvolutionTransposed4x4::Compile( RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); - - return OkStatus(); + return absl::OkStatus(); } -Status ConvolutionTransposed4x4::BindArguments() { +absl::Status ConvolutionTransposed4x4::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -329,8 +328,7 @@ Status ConvolutionTransposed4x4::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); const int32_t filters_offset = 4 * 16 * src_[0]->Slices(); RETURN_IF_ERROR(kernel_.SetBytesAuto(filters_offset)); - - return OkStatus(); + return absl::OkStatus(); } int3 ConvolutionTransposed4x4::GetGridSize() const { @@ -341,7 +339,7 @@ int3 ConvolutionTransposed4x4::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConvolutionTransposed4x4::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvolutionTransposed4x4::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -354,13 +352,13 @@ bool IsConvolutionTransposed4x4Supported( attr.padding.prepended.w == 1 && attr.padding.prepended.h == 1; } -Status CreateConvolutionTransposed4x4( +absl::Status CreateConvolutionTransposed4x4( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed4x4* result) { if (!IsConvolutionTransposed4x4Supported(*creation_context.device, definition, attr)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "ConvolutionTransposed4x4 doesn't support this attributes"); } *result = ConvolutionTransposed4x4(definition, *creation_context.device); @@ -373,7 +371,7 @@ Status CreateConvolutionTransposed4x4( create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h index 8d92542c908..7bf37c56119 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h @@ -37,8 +37,8 @@ namespace cl { class ConvolutionTransposed4x4 : public GPUOperation { public: ConvolutionTransposed4x4() = default; - Status 
AddToQueue(CLCommandQueue* queue) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed4x4(ConvolutionTransposed4x4&& operation); @@ -56,19 +56,19 @@ class ConvolutionTransposed4x4 : public GPUOperation { private: ConvolutionTransposed4x4(const OperationDef& definition, const CLDevice& device); - friend Status CreateConvolutionTransposed4x4( + friend absl::Status CreateConvolutionTransposed4x4( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed4x4* result); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -80,7 +80,7 @@ class ConvolutionTransposed4x4 : public GPUOperation { }; template -Status ConvolutionTransposed4x4::UploadWeights( +absl::Status ConvolutionTransposed4x4::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); @@ -150,7 +150,7 @@ bool IsConvolutionTransposed4x4Supported( const CLDevice& device, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); -Status CreateConvolutionTransposed4x4( +absl::Status CreateConvolutionTransposed4x4( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed4x4* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc index 03b9ab0eb6c..8ea40bedd7d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc @@ -184,7 +184,7 @@ ConvolutionTransposedThin& ConvolutionTransposedThin::operator=( return *this; } -Status ConvolutionTransposedThin::Compile( +absl::Status ConvolutionTransposedThin::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, IntegralDivideRoundUp(src_channels_, 4), dst_channels_, @@ -201,7 +201,7 @@ Status ConvolutionTransposedThin::Compile( *creation_context.device, &kernel_); } -Status ConvolutionTransposedThin::BindArguments() { +absl::Status ConvolutionTransposedThin::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_buf_.GetMemoryPtr())); @@ -210,7 +210,7 @@ Status ConvolutionTransposedThin::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(bias_value_)); - return OkStatus(); + return absl::OkStatus(); } int3 ConvolutionTransposedThin::GetGridSize() const { @@ -220,12 +220,12 @@ int3 ConvolutionTransposedThin::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ConvolutionTransposedThin::Tune(const TuningParameters& params) { +absl::Status 
ConvolutionTransposedThin::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status ConvolutionTransposedThin::AddToQueue(CLCommandQueue* queue) { +absl::Status ConvolutionTransposedThin::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -238,18 +238,18 @@ bool IsConvolutionTransposedThinSupported( attr.padding.appended.w == 0 && attr.padding.appended.h == 0; } -Status CreateConvolutionTransposedThin( +absl::Status CreateConvolutionTransposedThin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposedThin* result) { if (!IsConvolutionTransposedThinSupported(*creation_context.device, attr)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "ConvolutionTransposedThin doesn't support this attributes"); } *result = ConvolutionTransposedThin(definition, attr); RETURN_IF_ERROR( result->UploadWeights(attr.weights, creation_context.context)); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h index 0642a7c928b..573772965ae 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h @@ -38,10 +38,10 @@ namespace cl { class ConvolutionTransposedThin : public GPUOperation { public: ConvolutionTransposedThin() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposedThin(ConvolutionTransposedThin&& operation); @@ -51,21 +51,21 @@ class ConvolutionTransposedThin : public GPUOperation { delete; private: - friend Status CreateConvolutionTransposedThin( + friend absl::Status CreateConvolutionTransposedThin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposedThin* result); ConvolutionTransposedThin(const OperationDef& definition, const ConvolutionTransposedAttributes& attr); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Buffer weights_buf_; @@ -80,7 +80,7 @@ class ConvolutionTransposedThin : public GPUOperation { }; template -Status ConvolutionTransposedThin::UploadWeights( +absl::Status ConvolutionTransposedThin::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(src_channels_, 4); const int elements_count = @@ -136,7 +136,7 @@ void ConvolutionTransposedThin::RearrangeWeightsData( bool IsConvolutionTransposedThinSupported( const CLDevice& device, const ConvolutionTransposedAttributes& attr); -Status 
CreateConvolutionTransposedThin( +absl::Status CreateConvolutionTransposedThin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposedThin* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc index e7bf31b0d37..99bec18c7f8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc @@ -226,7 +226,8 @@ DepthWiseConvolution& DepthWiseConvolution::operator=( return *this; } -Status DepthWiseConvolution::Compile(const CreationContext& creation_context) { +absl::Status DepthWiseConvolution::Compile( + const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; const auto code = GenerateDepthWiseConvolutionCode( @@ -237,7 +238,7 @@ Status DepthWiseConvolution::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status DepthWiseConvolution::BindArguments() { +absl::Status DepthWiseConvolution::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_)); @@ -255,7 +256,7 @@ Status DepthWiseConvolution::BindArguments() { } RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 DepthWiseConvolution::GetGridSize() const { @@ -265,20 +266,20 @@ int3 DepthWiseConvolution::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status DepthWiseConvolution::Tune(const TuningParameters& params) { +absl::Status DepthWiseConvolution::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status DepthWiseConvolution::AddToQueue(CLCommandQueue* queue) { +absl::Status DepthWiseConvolution::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status CreateDepthWiseConvolution(const CreationContext& creation_context, - const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, - DepthWiseConvolution* result) { +absl::Status CreateDepthWiseConvolution( + const CreationContext& creation_context, const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr, + DepthWiseConvolution* result) { bool weights_are_buffer = creation_context.device->IsMali(); *result = DepthWiseConvolution(definition, attr, weights_are_buffer); RETURN_IF_ERROR( @@ -291,7 +292,7 @@ Status CreateDepthWiseConvolution(const CreationContext& creation_context, create_info.aligned_size = attr.weights.shape.o * attr.weights.shape.i; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h index 5915ed94502..8f3320ae57b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h @@ -38,10 +38,10 @@ namespace cl { class DepthWiseConvolution : public GPUOperation { public: DepthWiseConvolution() = default; - Status 
AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only DepthWiseConvolution(DepthWiseConvolution&& operation); @@ -50,7 +50,7 @@ class DepthWiseConvolution : public GPUOperation { DepthWiseConvolution& operator=(const DepthWiseConvolution&) = delete; private: - friend Status CreateDepthWiseConvolution( + friend absl::Status CreateDepthWiseConvolution( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr, DepthWiseConvolution* result); @@ -58,14 +58,14 @@ class DepthWiseConvolution : public GPUOperation { const DepthwiseConvolution2DAttributes& attr, bool weights_are_buffer); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; bool weights_are_buffer_; @@ -86,7 +86,7 @@ class DepthWiseConvolution : public GPUOperation { }; template -Status DepthWiseConvolution::UploadWeights( +absl::Status DepthWiseConvolution::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); @@ -130,7 +130,7 @@ Status DepthWiseConvolution::UploadWeights( weights_ = weights_tex2d_.GetMemoryPtr(); } - return OkStatus(); + return absl::OkStatus(); } template @@ -162,10 +162,9 @@ void DepthWiseConvolution::RearrangeWeightsData( } } -Status CreateDepthWiseConvolution(const CreationContext& creation_context, - const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, - DepthWiseConvolution* result); +absl::Status CreateDepthWiseConvolution( + const CreationContext& creation_context, const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr, DepthWiseConvolution* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc index e3297cb6814..57d30dd2734 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc @@ -256,7 +256,7 @@ DepthWiseConvolution3D& DepthWiseConvolution3D::operator=( return *this; } -Status DepthWiseConvolution3D::Compile( +absl::Status DepthWiseConvolution3D::Compile( const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; @@ -268,7 +268,7 @@ Status DepthWiseConvolution3D::Compile( *creation_context.device, &kernel_); } -Status DepthWiseConvolution3D::BindArguments() { +absl::Status DepthWiseConvolution3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (weights_are_buffer_) { @@ -295,7 +295,7 @@ Status DepthWiseConvolution3D::BindArguments() { } RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS())); - return OkStatus(); + 
return absl::OkStatus(); } int3 DepthWiseConvolution3D::GetGridSize() const { @@ -305,17 +305,17 @@ int3 DepthWiseConvolution3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status DepthWiseConvolution3D::Tune(const TuningParameters& params) { +absl::Status DepthWiseConvolution3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status DepthWiseConvolution3D::AddToQueue(CLCommandQueue* queue) { +absl::Status DepthWiseConvolution3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status CreateDepthWiseConvolution3D( +absl::Status CreateDepthWiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, DepthWiseConvolution3D* result) { @@ -330,7 +330,7 @@ Status CreateDepthWiseConvolution3D( create_info.aligned_size = attr.weights.shape.o * attr.weights.shape.i; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h index e3c565422af..78ca6862416 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h @@ -38,10 +38,10 @@ namespace cl { class DepthWiseConvolution3D : public GPUOperation { public: DepthWiseConvolution3D() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only DepthWiseConvolution3D(DepthWiseConvolution3D&& operation); @@ -50,7 +50,7 @@ class DepthWiseConvolution3D : public GPUOperation { DepthWiseConvolution3D& operator=(const DepthWiseConvolution3D&) = delete; private: - friend Status CreateDepthWiseConvolution3D( + friend absl::Status CreateDepthWiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, DepthWiseConvolution3D* result); @@ -58,14 +58,14 @@ class DepthWiseConvolution3D : public GPUOperation { const DepthwiseConvolution3DAttributes& attr, const CLDevice& device); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Texture2D weights_tex2d_; @@ -85,7 +85,7 @@ class DepthWiseConvolution3D : public GPUOperation { }; template -Status DepthWiseConvolution3D::UploadWeights( +absl::Status DepthWiseConvolution3D::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_slices = IntegralDivideRoundUp(dst_channels, 4); @@ -123,7 +123,7 @@ Status DepthWiseConvolution3D::UploadWeights( gpu_data.data(), context, &weights_tex2d_)); } } - return 
OkStatus(); + return absl::OkStatus(); } template @@ -158,7 +158,7 @@ void DepthWiseConvolution3D::RearrangeWeightsData( } } -Status CreateDepthWiseConvolution3D( +absl::Status CreateDepthWiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, DepthWiseConvolution3D* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc index 704df26f2ba..3324adada3b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc @@ -297,7 +297,8 @@ DepthWiseConv3x3& DepthWiseConv3x3::operator=(DepthWiseConv3x3&& operation) { return *this; } -Status DepthWiseConv3x3::Compile(const CreationContext& creation_context) { +absl::Status DepthWiseConv3x3::Compile( + const CreationContext& creation_context) { std::string code = GenerateDepthWiseConvCode( definition_, linked_operations_, *creation_context.device, weights_are_buffer_, local_mem_uploads_); @@ -311,15 +312,14 @@ Status DepthWiseConv3x3::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status DepthWiseConv3x3::BindArguments() { +absl::Status DepthWiseConv3x3::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_)); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - - return OkStatus(); + return absl::OkStatus(); } int3 DepthWiseConv3x3::GetGridSize() const { @@ -329,15 +329,15 @@ int3 DepthWiseConv3x3::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status DepthWiseConv3x3::Tune(const TuningParameters& params) { +absl::Status DepthWiseConv3x3::Tune(const TuningParameters& params) { if (local_mem_uploads_) { - return OkStatus(); + return absl::OkStatus(); } RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status DepthWiseConv3x3::AddToQueue(CLCommandQueue* queue) { +absl::Status DepthWiseConv3x3::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -351,12 +351,11 @@ bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr) { attr.padding.appended.h == 1; } -Status CreateDepthWiseConv3x3(const CreationContext& creation_context, - const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, - DepthWiseConv3x3* result) { +absl::Status CreateDepthWiseConv3x3( + const CreationContext& creation_context, const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result) { if (!IsDepthWiseConv3x3Supported(attr)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "DepthWiseConv3x3 doesn't support this attributes"); } bool weights_are_buffer = @@ -364,9 +363,8 @@ Status CreateDepthWiseConv3x3(const CreationContext& creation_context, bool local_mem_uploads = weights_are_buffer && creation_context.device->IsPowerVR(); *result = DepthWiseConv3x3(definition, weights_are_buffer, local_mem_uploads); - RETURN_IF_ERROR(result->UploadWeightsAndBiases(attr.weights, attr.bias, - creation_context.context)); - return OkStatus(); + return 
result->UploadWeightsAndBiases(attr.weights, attr.bias, + creation_context.context); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h index 1630557afc9..936ab773229 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h @@ -38,10 +38,10 @@ namespace cl { class DepthWiseConv3x3 : public GPUOperation { public: DepthWiseConv3x3() = default; - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only DepthWiseConv3x3(DepthWiseConv3x3&& operation); @@ -53,11 +53,11 @@ class DepthWiseConv3x3 : public GPUOperation { explicit DepthWiseConv3x3(const OperationDef& definition, bool weights_are_buffer, bool local_mem_uploads); template - Status UploadWeightsAndBiases(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + absl::Status UploadWeightsAndBiases( + const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, CLContext* context); - friend Status CreateDepthWiseConv3x3( + friend absl::Status CreateDepthWiseConv3x3( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result); @@ -66,7 +66,7 @@ class DepthWiseConv3x3 : public GPUOperation { const ::tflite::gpu::Tensor& weights, const ::tflite::gpu::Tensor& biases, absl::Span dst); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; bool weights_are_buffer_; @@ -80,7 +80,7 @@ class DepthWiseConv3x3 : public GPUOperation { }; template -Status DepthWiseConv3x3::UploadWeightsAndBiases( +absl::Status DepthWiseConv3x3::UploadWeightsAndBiases( const ::tflite::gpu::Tensor& weights, const ::tflite::gpu::Tensor& biases, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -122,7 +122,7 @@ Status DepthWiseConv3x3::UploadWeightsAndBiases( weights_ = weights_tex2d_.GetMemoryPtr(); } - return OkStatus(); + return absl::OkStatus(); } template @@ -160,10 +160,9 @@ void DepthWiseConv3x3::RearrangeWeightsAndBiasesData( bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr); -Status CreateDepthWiseConv3x3(const CreationContext& creation_context, - const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, - DepthWiseConv3x3* result); +absl::Status CreateDepthWiseConv3x3( + const CreationContext& creation_context, const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc index 7c394a45669..e435bccef03 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc @@ -203,14 +203,14 @@ std::string ElementwiseTwoInput::GetArgsDeclaration() const { return args; } -Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) { +absl::Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) { if (use_scalar_para_) { 
RETURN_IF_ERROR(kernel->SetBytesAuto(scalar_para_)); } else { RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr())); RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB())); } - return OkStatus(); + return absl::OkStatus(); } ElementwiseTwoInput CreateElementwiseTwoInput( diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h index 8bf33b0c128..4c85fee6071 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h @@ -75,7 +75,7 @@ class ElementwiseTwoInput : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - Status BindArguments(CLKernel* kernel) override; + absl::Status BindArguments(CLKernel* kernel) override; inline void SetScalarPara(FLT scalar) { scalar_para_ = scalar; use_scalar_para_ = true; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc index 44a3e97554c..f93648f82fc 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc @@ -113,7 +113,7 @@ FullyConnected& FullyConnected::operator=(FullyConnected&& kernel) { return *this; } -Status FullyConnected::Compile(const CreationContext& creation_context) { +absl::Status FullyConnected::Compile(const CreationContext& creation_context) { int wg_width = 32; int wg_height = 4; int work_items; @@ -134,10 +134,10 @@ Status FullyConnected::Compile(const CreationContext& creation_context) { } work_items = work_group_size_.x * work_group_size_.y * work_group_size_.z; } while (work_items > kernel_.GetMaxWorkGroupSize()); - return OkStatus(); + return absl::OkStatus(); } -Status FullyConnected::AddToQueue(CLCommandQueue* queue) { +absl::Status FullyConnected::AddToQueue(CLCommandQueue* queue) { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -146,15 +146,14 @@ Status FullyConnected::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR( kernel_.SetBytesAuto(int2(src_[0]->Slices(), dst_[0]->Slices()))); - return queue->DispatchImplicit(kernel_, {dst_[0]->Slices(), 1, 1}, work_group_size_); } -Status CreateFullyConnected(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - FullyConnected* result) { +absl::Status CreateFullyConnected(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + FullyConnected* result) { *result = FullyConnected(definition); RETURN_IF_ERROR( result->UploadWeights(attr.weights, creation_context.context)); @@ -165,7 +164,7 @@ Status CreateFullyConnected(const CreationContext& creation_context, create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h index 83ac279a71b..bc7cbd32fb0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h +++ 
b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h @@ -37,9 +37,9 @@ namespace cl { class FullyConnected : public GPUOperation { public: FullyConnected() = default; - Status AddToQueue(CLCommandQueue* queue) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only FullyConnected(FullyConnected&& kernel); @@ -49,14 +49,13 @@ class FullyConnected : public GPUOperation { private: explicit FullyConnected(const OperationDef& definition); - friend Status CreateFullyConnected(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - FullyConnected* result); + friend absl::Status CreateFullyConnected( + const CreationContext& creation_context, const OperationDef& definition, + const FullyConnectedAttributes& attr, FullyConnected* result); template - Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeights(const ::tflite::gpu::Tensor& weights, @@ -69,7 +68,7 @@ class FullyConnected : public GPUOperation { }; template -Status FullyConnected::UploadWeights( +absl::Status FullyConnected::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); @@ -123,10 +122,10 @@ void FullyConnected::RearrangeWeights( } } -Status CreateFullyConnected(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - FullyConnected* result); +absl::Status CreateFullyConnected(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + FullyConnected* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc index 4972bb9f737..9f4c9871123 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc @@ -154,7 +154,7 @@ ElementwiseOperation& ElementwiseOperation::operator=( return *this; } -Status ElementwiseOperation::BindArguments() { +absl::Status ElementwiseOperation::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArguments(&kernel_)); @@ -162,7 +162,7 @@ Status ElementwiseOperation::BindArguments() { RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 ElementwiseOperation::GetGridSize() const { @@ -172,19 +172,20 @@ int3 ElementwiseOperation::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status ElementwiseOperation::Compile(const CreationContext& creation_context) { +absl::Status ElementwiseOperation::Compile( + const CreationContext& creation_context) { const auto code = GetElementWiseCode(definition_, *this, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -Status 
ElementwiseOperation::AddToQueue(CLCommandQueue* queue) { +absl::Status ElementwiseOperation::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status ElementwiseOperation::Tune(const TuningParameters& params) { +absl::Status ElementwiseOperation::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } @@ -209,12 +210,12 @@ std::string PostProcess(const std::vector& linked_ops, return code; } -Status BindArgs(CLKernel* kernel, - const std::vector& linked_ops) { +absl::Status BindArgs(CLKernel* kernel, + const std::vector& linked_ops) { for (auto linked_op : linked_ops) { RETURN_IF_ERROR(linked_op->BindArguments(kernel)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h index 4507f0eb81d..17817682bce 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h @@ -96,11 +96,15 @@ class GPUOperation { void SetSrc(Tensor* ptr, int index = 0); void SetDst(Tensor* ptr, int index = 0); - virtual Status AddToQueue(CLCommandQueue* queue) { return OkStatus(); } - virtual Status Tune(const TuningParameters& params) { return OkStatus(); } + virtual absl::Status AddToQueue(CLCommandQueue* queue) { + return absl::OkStatus(); + } + virtual absl::Status Tune(const TuningParameters& params) { + return absl::OkStatus(); + } - virtual Status Compile(const CreationContext& creation_context) { - return OkStatus(); + virtual absl::Status Compile(const CreationContext& creation_context) { + return absl::OkStatus(); } const OperationDef& GetDefinition() const { return definition_; } @@ -127,10 +131,10 @@ class ElementwiseOperation : public GPUOperation { : GPUOperation(definition) {} virtual ~ElementwiseOperation() {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only ElementwiseOperation(ElementwiseOperation&& operation); @@ -150,10 +154,12 @@ class ElementwiseOperation : public GPUOperation { virtual std::string GetCoreCode(const LinkingContext& context) const = 0; virtual std::string GetArgsDeclaration() const { return ""; } - virtual Status BindArguments(CLKernel* kernel) { return OkStatus(); } + virtual absl::Status BindArguments(CLKernel* kernel) { + return absl::OkStatus(); + } protected: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; int3 work_group_size_ = int3(8, 4, 1); @@ -171,8 +177,8 @@ std::string PostProcess(const std::vector& linked_ops, // Binds arguments to given kernel for elementwise operations in // linked_ops. // Every ElementwiseOperation can bind her arguments. 
-Status BindArgs(CLKernel* kernel, - const std::vector& linked_ops); +absl::Status BindArgs(CLKernel* kernel, + const std::vector& linked_ops); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc index f2e53a06908..77eea07f278 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc @@ -121,14 +121,14 @@ LSTM& LSTM::operator=(LSTM&& kernel) { return *this; } -Status LSTM::Compile(const CreationContext& creation_context) { +absl::Status LSTM::Compile(const CreationContext& creation_context) { const auto code = GetLSTMCode(definition_, *creation_context.device); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -Status LSTM::BindArguments() { +absl::Status LSTM::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr())); @@ -137,8 +137,7 @@ Status LSTM::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch())); - - return OkStatus(); + return absl::OkStatus(); } int3 LSTM::GetGridSize() const { @@ -148,12 +147,12 @@ int3 LSTM::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status LSTM::Tune(const TuningParameters& params) { +absl::Status LSTM::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status LSTM::AddToQueue(CLCommandQueue* queue) { +absl::Status LSTM::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h index 3e84887cdc2..27b072ed001 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h @@ -28,9 +28,9 @@ namespace cl { class LSTM : public GPUOperation { public: explicit LSTM(const OperationDef& definition); - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only LSTM(LSTM&& kernel); @@ -39,7 +39,7 @@ class LSTM : public GPUOperation { LSTM& operator=(const LSTM&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc index 194daee5f1e..56109fc713b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc @@ -218,7 +218,7 @@ MaxUnpooling& MaxUnpooling::operator=(MaxUnpooling&& kernel) { return *this; } -Status MaxUnpooling::Compile(const CreationContext& creation_context) { +absl::Status MaxUnpooling::Compile(const CreationContext& creation_context) { const auto code = GetMaxUnpoolingKernelCode( definition_, 
*creation_context.device, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -226,7 +226,7 @@ Status MaxUnpooling::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status MaxUnpooling::BindArguments() { +absl::Status MaxUnpooling::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr())); @@ -237,8 +237,7 @@ Status MaxUnpooling::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_)); - - return OkStatus(); + return absl::OkStatus(); } int3 MaxUnpooling::GetGridSize() const { @@ -248,12 +247,12 @@ int3 MaxUnpooling::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status MaxUnpooling::Tune(const TuningParameters& params) { +absl::Status MaxUnpooling::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status MaxUnpooling::AddToQueue(CLCommandQueue* queue) { +absl::Status MaxUnpooling::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -291,7 +290,7 @@ MaxUnpooling3D& MaxUnpooling3D::operator=(MaxUnpooling3D&& kernel) { return *this; } -Status MaxUnpooling3D::Compile(const CreationContext& creation_context) { +absl::Status MaxUnpooling3D::Compile(const CreationContext& creation_context) { const auto code = GetMaxUnpooling3DKernelCode( definition_, *creation_context.device, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -299,7 +298,7 @@ Status MaxUnpooling3D::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status MaxUnpooling3D::BindArguments() { +absl::Status MaxUnpooling3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr())); @@ -316,8 +315,7 @@ Status MaxUnpooling3D::BindArguments() { kernel_.SetBytesAuto(int4(padding_.x, padding_.y, padding_.z, 1))); RETURN_IF_ERROR( kernel_.SetBytesAuto(int4(stride_.x, stride_.y, stride_.z, 1))); - - return OkStatus(); + return absl::OkStatus(); } int3 MaxUnpooling3D::GetGridSize() const { @@ -327,12 +325,12 @@ int3 MaxUnpooling3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status MaxUnpooling3D::Tune(const TuningParameters& params) { +absl::Status MaxUnpooling3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status MaxUnpooling3D::AddToQueue(CLCommandQueue* queue) { +absl::Status MaxUnpooling3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h index c7479acb728..19184ee1e89 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h @@ -29,10 +29,10 @@ class MaxUnpooling : public GPUOperation { public: MaxUnpooling(const OperationDef& definition, const MaxUnpooling2DAttributes& attr); - Status AddToQueue(CLCommandQueue* queue) override; - 
Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only MaxUnpooling(MaxUnpooling&& kernel); @@ -41,7 +41,7 @@ class MaxUnpooling : public GPUOperation { MaxUnpooling& operator=(const MaxUnpooling&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; int2 stride_; @@ -59,10 +59,10 @@ class MaxUnpooling3D : public GPUOperation { public: MaxUnpooling3D(const OperationDef& definition, const MaxUnpooling3DAttributes& attr); - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only MaxUnpooling3D(MaxUnpooling3D&& kernel); @@ -71,7 +71,7 @@ class MaxUnpooling3D : public GPUOperation { MaxUnpooling3D& operator=(const MaxUnpooling3D&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; int3 stride_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc index 9dd0546c059..f79a30e33dd 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc @@ -103,7 +103,7 @@ Mean& Mean::operator=(Mean&& operation) { return *this; } -Status Mean::Compile(const CreationContext& creation_context) { +absl::Status Mean::Compile(const CreationContext& creation_context) { if (creation_context.device->IsAdreno3xx()) { work_group_size_ = int3(16, 8, 1); } @@ -114,7 +114,7 @@ Status Mean::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status Mean::BindArguments() { +absl::Status Mean::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -124,7 +124,7 @@ Status Mean::BindArguments() { const double size_0 = work_group_size_.x * work_group_size_.y; const double size_1 = total_size / size_0; RETURN_IF_ERROR(kernel_.SetBytesAuto(float2(1.0 / size_1, 1.0 / size_0))); - return OkStatus(); + return absl::OkStatus(); } int3 Mean::GetGridSize() const { @@ -134,7 +134,7 @@ int3 Mean::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Mean::AddToQueue(CLCommandQueue* queue) { +absl::Status Mean::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean.h b/tensorflow/lite/delegates/gpu/cl/kernels/mean.h index 0c0d3fff81c..4525551b5f2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean.h @@ -30,9 +30,9 @@ class Mean : public GPUOperation { public: Mean() = default; explicit Mean(const OperationDef& definition) : GPUOperation(definition) {} - Status AddToQueue(CLCommandQueue* queue) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const 
CreationContext& creation_context) override; // Move only Mean(Mean&& operation); @@ -41,7 +41,7 @@ class Mean : public GPUOperation { Mean& operator=(const Mean&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc index 45f48246078..fde0712a412 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc @@ -89,7 +89,7 @@ std::string MultiplyAdd::GetArgsDeclaration() const { return args; } -Status MultiplyAdd::BindArguments(CLKernel* kernel) { +absl::Status MultiplyAdd::BindArguments(CLKernel* kernel) { if (use_mul_vec_) { RETURN_IF_ERROR(kernel->SetMemoryAuto(mul_vec_.GetMemoryPtr())); } @@ -102,12 +102,12 @@ Status MultiplyAdd::BindArguments(CLKernel* kernel) { if (scalar_add_.Active()) { RETURN_IF_ERROR(kernel->SetBytesAuto(scalar_add_)); } - return OkStatus(); + return absl::OkStatus(); } -Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr, - CalculationsPrecision scalar_precision, - CLContext* context) { +absl::Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr, + CalculationsPrecision scalar_precision, + CLContext* context) { auto mul = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>( &attr.param); auto mul_scalar = absl::get_if<float>(&attr.param); @@ -116,12 +116,12 @@ Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr, } else { scalar_mul_ = FLT(scalar_precision, *mul_scalar); } - return OkStatus(); + return absl::OkStatus(); } -Status MultiplyAdd::UploadAdd(const AddAttributes& attr, - CalculationsPrecision scalar_precision, - CLContext* context) { +absl::Status MultiplyAdd::UploadAdd(const AddAttributes& attr, + CalculationsPrecision scalar_precision, + CLContext* context) { auto add = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>( &attr.param); auto add_scalar = absl::get_if<float>(&attr.param); @@ -130,12 +130,13 @@ Status MultiplyAdd::UploadAdd(const AddAttributes& attr, } else { scalar_add_ = FLT(scalar_precision, *add_scalar); } - return OkStatus(); + return absl::OkStatus(); } -Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& attr, MultiplyAdd* result) { +absl::Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& attr, + MultiplyAdd* result) { const auto scalar_precision = creation_context.device->IsPowerVR() ? CalculationsPrecision::F32 : definition.precision; @@ -143,12 +144,12 @@ Status CreateMultiplyAdd(const CreationContext& creation_context, RETURN_IF_ERROR( result->UploadMul(attr, scalar_precision, creation_context.context)); result->SetLinkIndex(0); - return OkStatus(); + return absl::OkStatus(); } -Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const AddAttributes& attr, MultiplyAdd* result) { +absl::Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const AddAttributes& attr, MultiplyAdd* result) { const auto scalar_precision = creation_context.device->IsPowerVR() ?
CalculationsPrecision::F32 : definition.precision; @@ -156,13 +157,14 @@ Status CreateMultiplyAdd(const CreationContext& creation_context, RETURN_IF_ERROR( result->UploadAdd(attr, scalar_precision, creation_context.context)); result->SetLinkIndex(0); - return OkStatus(); + return absl::OkStatus(); } -Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& mul_attr, - const AddAttributes& add_attr, MultiplyAdd* result) { +absl::Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& mul_attr, + const AddAttributes& add_attr, + MultiplyAdd* result) { const auto scalar_precision = creation_context.device->IsPowerVR() ? CalculationsPrecision::F32 : definition.precision; @@ -172,7 +174,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context, RETURN_IF_ERROR( result->UploadAdd(add_attr, scalar_precision, creation_context.context)); result->SetLinkIndex(0); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h index 83bb6e11216..4047a7e5c1b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h @@ -40,40 +40,42 @@ class MultiplyAdd : public ElementwiseOperation { MultiplyAdd(const MultiplyAdd&) = delete; MultiplyAdd& operator=(const MultiplyAdd&) = delete; - Status UploadMul(const MultiplyAttributes& attr, - CalculationsPrecision scalar_precision, CLContext* context); - Status UploadAdd(const AddAttributes& attr, - CalculationsPrecision scalar_precision, CLContext* context); + absl::Status UploadMul(const MultiplyAttributes& attr, + CalculationsPrecision scalar_precision, + CLContext* context); + absl::Status UploadAdd(const AddAttributes& attr, + CalculationsPrecision scalar_precision, + CLContext* context); template - Status UploadMul(const ::tflite::gpu::Tensor& mul, - CLContext* context); + absl::Status UploadMul(const ::tflite::gpu::Tensor& mul, + CLContext* context); template - Status UploadAdd(const ::tflite::gpu::Tensor& add, - CLContext* context); + absl::Status UploadAdd(const ::tflite::gpu::Tensor& add, + CLContext* context); void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - Status BindArguments(CLKernel* kernel) override; + absl::Status BindArguments(CLKernel* kernel) override; - friend Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& attr, - MultiplyAdd* result); + friend absl::Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& attr, + MultiplyAdd* result); - friend Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const AddAttributes& attr, - MultiplyAdd* result); + friend absl::Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const AddAttributes& attr, + MultiplyAdd* result); - friend Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& mul_attr, - const AddAttributes& add_attr, - MultiplyAdd* result); + friend absl::Status CreateMultiplyAdd(const CreationContext& 
creation_context, + const OperationDef& definition, + const MultiplyAttributes& mul_attr, + const AddAttributes& add_attr, + MultiplyAdd* result); private: explicit MultiplyAdd(const OperationDef& definition) @@ -89,41 +91,43 @@ class MultiplyAdd : public ElementwiseOperation { FLT scalar_add_; }; -Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& attr, MultiplyAdd* result); +absl::Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& attr, + MultiplyAdd* result); -Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const AddAttributes& attr, MultiplyAdd* result); +absl::Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const AddAttributes& attr, MultiplyAdd* result); -Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& mul_attr, - const AddAttributes& add_attr, MultiplyAdd* result); +absl::Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& mul_attr, + const AddAttributes& add_attr, + MultiplyAdd* result); template -Status MultiplyAdd::UploadMul(const ::tflite::gpu::Tensor& mul, - CLContext* context) { +absl::Status MultiplyAdd::UploadMul(const ::tflite::gpu::Tensor& mul, + CLContext* context) { LinearStorageCreateInfo create_info; create_info.storage_type = DeduceLinearStorageType(definition_.GetPrimaryStorageType()); create_info.data_type = definition_.GetDataType(); RETURN_IF_ERROR(CreateLinearStorage(create_info, mul, context, &mul_vec_)); use_mul_vec_ = true; - return OkStatus(); + return absl::OkStatus(); } template -Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor& add, - CLContext* context) { +absl::Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor& add, + CLContext* context) { LinearStorageCreateInfo create_info; create_info.storage_type = DeduceLinearStorageType(definition_.GetPrimaryStorageType()); create_info.data_type = definition_.GetDataType(); RETURN_IF_ERROR(CreateLinearStorage(create_info, add, context, &add_vec_)); use_add_vec_ = true; - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc index 1443f5958db..48edcb448a1 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc @@ -169,7 +169,7 @@ Padding& Padding::operator=(Padding&& kernel) { return *this; } -Status Padding::Compile(const CreationContext& creation_context) { +absl::Status Padding::Compile(const CreationContext& creation_context) { const auto code = GetPaddingCode(definition_, linked_operations_, attributes_); return creation_context.cache->GetOrCreateCLKernel( @@ -177,7 +177,7 @@ Status Padding::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status Padding::BindArguments() { +absl::Status Padding::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -187,7 +187,7 @@ Status Padding::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); const auto& prep = attributes_.prepended; 
RETURN_IF_ERROR(kernel_.SetBytesAuto(int4(prep.w, prep.h, prep.c, prep.b))); - return OkStatus(); + return absl::OkStatus(); } int3 Padding::GetGridSize() const { @@ -197,12 +197,12 @@ int3 Padding::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Padding::Tune(const TuningParameters& params) { +absl::Status Padding::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status Padding::AddToQueue(CLCommandQueue* queue) { +absl::Status Padding::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.h b/tensorflow/lite/delegates/gpu/cl/kernels/padding.h index 38e78d4a461..ddf9f9583be 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.h @@ -28,10 +28,10 @@ namespace cl { class Padding : public GPUOperation { public: Padding(const OperationDef& definition, const PadAttributes& attr); - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Padding(Padding&& kernel); @@ -40,7 +40,7 @@ class Padding : public GPUOperation { Padding& operator=(const Padding&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; PadAttributes attributes_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc index 17705782f93..fb985461c02 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc @@ -408,7 +408,7 @@ Pooling& Pooling::operator=(Pooling&& kernel) { return *this; } -Status Pooling::Compile(const CreationContext& creation_context) { +absl::Status Pooling::Compile(const CreationContext& creation_context) { std::string code; const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; @@ -423,7 +423,7 @@ Status Pooling::Compile(const CreationContext& creation_context) { linked_operations_, output_indices_); break; default: - return InvalidArgumentError( + return absl::InvalidArgumentError( "You should create another kernel with this params"); break; } @@ -432,7 +432,7 @@ Status Pooling::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status Pooling::BindArguments() { +absl::Status Pooling::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -447,7 +447,7 @@ Status Pooling::BindArguments() { kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y))); RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_)); - return OkStatus(); + return absl::OkStatus(); } int3 Pooling::GetGridSize() const { @@ -457,12 +457,12 @@ int3 Pooling::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Pooling::Tune(const TuningParameters& params) { +absl::Status Pooling::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), 
&work_group_size_); } -Status Pooling::AddToQueue(CLCommandQueue* queue) { +absl::Status Pooling::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -506,7 +506,7 @@ Pooling3D& Pooling3D::operator=(Pooling3D&& kernel) { return *this; } -Status Pooling3D::Compile(const CreationContext& creation_context) { +absl::Status Pooling3D::Compile(const CreationContext& creation_context) { std::string code; const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; @@ -521,7 +521,7 @@ Status Pooling3D::Compile(const CreationContext& creation_context) { linked_operations_, output_indices_); break; default: - return InvalidArgumentError( + return absl::InvalidArgumentError( "You should create another kernel with this params"); break; } @@ -530,7 +530,7 @@ Status Pooling3D::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status Pooling3D::BindArguments() { +absl::Status Pooling3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -550,7 +550,7 @@ Status Pooling3D::BindArguments() { RETURN_IF_ERROR( kernel_.SetBytesAuto(int4(stride_.x, stride_.y, stride_.z, 1))); - return OkStatus(); + return absl::OkStatus(); } int3 Pooling3D::GetGridSize() const { @@ -560,12 +560,12 @@ int3 Pooling3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Pooling3D::Tune(const TuningParameters& params) { +absl::Status Pooling3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status Pooling3D::AddToQueue(CLCommandQueue* queue) { +absl::Status Pooling3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h index eaeb188f19e..09d2d5260f7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h @@ -30,10 +30,10 @@ namespace cl { class Pooling : public GPUOperation { public: Pooling(const OperationDef& definition, const Pooling2DAttributes& attr); - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Pooling(Pooling&& kernel); @@ -42,7 +42,7 @@ class Pooling : public GPUOperation { Pooling& operator=(const Pooling&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; int2 stride_; @@ -62,10 +62,10 @@ Pooling CreatePooling(const OperationDef& definition, class Pooling3D : public GPUOperation { public: Pooling3D(const OperationDef& definition, const Pooling3DAttributes& attr); - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status 
Compile(const CreationContext& creation_context) override; // Move only Pooling3D(Pooling3D&& kernel); @@ -74,7 +74,7 @@ class Pooling3D : public GPUOperation { Pooling3D& operator=(const Pooling3D&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; int3 stride_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc index 8aa357b91b4..1879d390ad6 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc @@ -73,21 +73,21 @@ std::string PReLU::GetArgsDeclaration() const { return args; } -Status PReLU::BindArguments(CLKernel* kernel) { +absl::Status PReLU::BindArguments(CLKernel* kernel) { RETURN_IF_ERROR(kernel->SetMemoryAuto(alpha_.GetMemoryPtr())); if (clip_.Active()) { RETURN_IF_ERROR(kernel->SetBytesAuto(clip_)); } - return OkStatus(); + return absl::OkStatus(); } -Status CreatePReLU(const CreationContext& creation_context, - const OperationDef& definition, const PReLUAttributes& attr, - PReLU* result) { +absl::Status CreatePReLU(const CreationContext& creation_context, + const OperationDef& definition, + const PReLUAttributes& attr, PReLU* result) { auto alpha = absl::get_if<::tflite::gpu::Tensor>( &attr.alpha); if (!alpha) { - return InvalidArgumentError("Alpha is missing"); + return absl::InvalidArgumentError("Alpha is missing"); } const auto scalar_precision = creation_context.device->IsPowerVR() ? CalculationsPrecision::F32 @@ -95,7 +95,7 @@ Status CreatePReLU(const CreationContext& creation_context, *result = PReLU(definition, attr, scalar_precision); RETURN_IF_ERROR(result->UploadParameters(*alpha, creation_context.context)); result->SetLinkIndex(0); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h index 0feb387e644..4ba0a92158f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h @@ -44,30 +44,30 @@ class PReLU : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - Status BindArguments(CLKernel* kernel) override; + absl::Status BindArguments(CLKernel* kernel) override; - friend Status CreatePReLU(const CreationContext& creation_context, - const OperationDef& definition, - const PReLUAttributes& attr, PReLU* result); + friend absl::Status CreatePReLU(const CreationContext& creation_context, + const OperationDef& definition, + const PReLUAttributes& attr, PReLU* result); private: PReLU(const OperationDef& definition, const PReLUAttributes& attr, CalculationsPrecision scalar_precision); template - Status UploadParameters(const ::tflite::gpu::Tensor& parameters, - CLContext* context); + absl::Status UploadParameters( + const ::tflite::gpu::Tensor& parameters, CLContext* context); FLT clip_; LinearStorage alpha_; }; -Status CreatePReLU(const CreationContext& creation_context, - const OperationDef& definition, const PReLUAttributes& attr, - PReLU* result); +absl::Status CreatePReLU(const CreationContext& creation_context, + const OperationDef& definition, + const PReLUAttributes& attr, PReLU* result); template -Status PReLU::UploadParameters( +absl::Status PReLU::UploadParameters( const ::tflite::gpu::Tensor& parameters, CLContext* context) { LinearStorageCreateInfo create_info; 
create_info.storage_type = @@ -75,7 +75,7 @@ Status PReLU::UploadParameters( create_info.data_type = definition_.GetPrimaryDataType(); RETURN_IF_ERROR( CreateLinearStorage(create_info, parameters, context, &alpha_)); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc index f7751fac6ff..e0346a66ff9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc @@ -92,17 +92,17 @@ std::string QuantizeAndDequantize::GetArgsDeclaration() const { scale_.GetDeclaration()); } -Status QuantizeAndDequantize::BindArguments(CLKernel* kernel) { +absl::Status QuantizeAndDequantize::BindArguments(CLKernel* kernel) { RETURN_IF_ERROR(kernel->SetBytesAuto(min_)); RETURN_IF_ERROR(kernel->SetBytesAuto(max_)); RETURN_IF_ERROR(kernel->SetBytesAuto(scale_)); - return OkStatus(); + return absl::OkStatus(); } -Status CreateQuantizeAndDequantize(const CreationContext& creation_context, - const OperationDef& definition, - const QuantizeAndDequantizeAttributes& attr, - QuantizeAndDequantize* result) { +absl::Status CreateQuantizeAndDequantize( + const CreationContext& creation_context, const OperationDef& definition, + const QuantizeAndDequantizeAttributes& attr, + QuantizeAndDequantize* result) { const auto scalar_precision = creation_context.device->IsPowerVR() ? CalculationsPrecision::F32 : definition.precision; @@ -120,7 +120,7 @@ Status CreateQuantizeAndDequantize(const CreationContext& creation_context, *result = QuantizeAndDequantize(definition, attr, scalar_precision); } result->SetLinkIndex(0); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h index 07fa8f21773..41c295e881d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h @@ -57,9 +57,9 @@ class QuantizeAndDequantize : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - Status BindArguments(CLKernel* kernel) override; + absl::Status BindArguments(CLKernel* kernel) override; - friend Status CreateQuantizeAndDequantize( + friend absl::Status CreateQuantizeAndDequantize( const CreationContext& creation_context, const OperationDef& definition, const QuantizeAndDequantizeAttributes& attr, QuantizeAndDequantize* result); @@ -70,27 +70,26 @@ class QuantizeAndDequantize : public ElementwiseOperation { CalculationsPrecision scalar_precision); template - Status UploadParameters(const ::tflite::gpu::Tensor& parameters, - CLContext* context); + absl::Status UploadParameters( + const ::tflite::gpu::Tensor& parameters, CLContext* context); FLT min_; FLT max_; FLT scale_; }; -Status CreateQuantizeAndDequantize(const CreationContext& creation_context, - const OperationDef& definition, - const QuantizeAndDequantizeAttributes& attr, - QuantizeAndDequantize* result); +absl::Status CreateQuantizeAndDequantize( + const CreationContext& creation_context, const OperationDef& definition, + const QuantizeAndDequantizeAttributes& attr, QuantizeAndDequantize* result); template -Status QuantizeAndDequantize::UploadParameters( 
+absl::Status QuantizeAndDequantize::UploadParameters( const ::tflite::gpu::Tensor& parameters, CLContext* context) { LinearStorageCreateInfo create_info; create_info.storage_type = DeduceLinearStorageType(definition_.GetPrimaryStorageType()); create_info.data_type = definition_.GetPrimaryDataType(); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/relu.cc b/tensorflow/lite/delegates/gpu/cl/kernels/relu.cc index ce903972c35..a96db2aa45e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/relu.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/relu.cc @@ -80,14 +80,14 @@ std::string ReLU::GetArgsDeclaration() const { return args; } -Status ReLU::BindArguments(CLKernel* kernel) { +absl::Status ReLU::BindArguments(CLKernel* kernel) { if (alpha_.Active()) { RETURN_IF_ERROR(kernel->SetBytesAuto(alpha_)); } if (clip_.Active()) { RETURN_IF_ERROR(kernel->SetBytesAuto(clip_)); } - return OkStatus(); + return absl::OkStatus(); } ReLU CreateReLU(const CreationContext& creation_context, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/relu.h b/tensorflow/lite/delegates/gpu/cl/kernels/relu.h index c4fb68588d3..c8260a33faf 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/relu.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/relu.h @@ -37,7 +37,7 @@ class ReLU : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - Status BindArguments(CLKernel* kernel) override; + absl::Status BindArguments(CLKernel* kernel) override; friend ReLU CreateReLU(const CreationContext& creation_context, const OperationDef& definition, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc index 3bb3cdd5d22..e1589e9d682 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc @@ -156,7 +156,7 @@ Reshape& Reshape::operator=(Reshape&& operation) { return *this; } -Status Reshape::Compile(const CreationContext& creation_context) { +absl::Status Reshape::Compile(const CreationContext& creation_context) { const auto code = definition_.IsBatchSupported() ? 
GetReshapeBatchedCode(definition_, linked_operations_) : GetReshapeCode(definition_, linked_operations_); @@ -165,7 +165,7 @@ Status Reshape::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status Reshape::BindArguments() { +absl::Status Reshape::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -174,8 +174,7 @@ Status Reshape::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Channels())); - - return OkStatus(); + return absl::OkStatus(); } int3 Reshape::GetGridSize() const { @@ -185,12 +184,12 @@ int3 Reshape::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Reshape::Tune(const TuningParameters& params) { +absl::Status Reshape::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status Reshape::AddToQueue(CLCommandQueue* queue) { +absl::Status Reshape::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h index 2117ef05907..e11c066ebd3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h @@ -29,10 +29,10 @@ class Reshape : public GPUOperation { public: explicit Reshape(const OperationDef& definition) : GPUOperation(definition), work_group_size_(8, 4, 1) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Reshape(Reshape&& operation); @@ -41,7 +41,7 @@ class Reshape : public GPUOperation { Reshape& operator=(const Reshape&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc index 3741a02aa5b..de6813e741f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc @@ -120,7 +120,7 @@ Reshapex4& Reshapex4::operator=(Reshapex4&& operation) { return *this; } -Status Reshapex4::Compile(const CreationContext& creation_context) { +absl::Status Reshapex4::Compile(const CreationContext& creation_context) { const auto code = definition_.IsBatchSupported() ? 
GetReshapeBatchedCode(definition_, linked_operations_) : GetReshapeCode(definition_, linked_operations_); @@ -129,15 +129,14 @@ Status Reshapex4::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status Reshapex4::BindArguments() { +absl::Status Reshapex4::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - - return OkStatus(); + return absl::OkStatus(); } int3 Reshapex4::GetGridSize() const { @@ -147,12 +146,12 @@ int3 Reshapex4::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Reshapex4::Tune(const TuningParameters& params) { +absl::Status Reshapex4::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status Reshapex4::AddToQueue(CLCommandQueue* queue) { +absl::Status Reshapex4::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h index 656e299b547..d61224a7367 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h @@ -30,10 +30,10 @@ class Reshapex4 : public GPUOperation { public: explicit Reshapex4(const OperationDef& definition) : GPUOperation(definition), work_group_size_(8, 4, 1) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Reshapex4(Reshapex4&& operation); @@ -42,7 +42,7 @@ class Reshapex4 : public GPUOperation { Reshapex4& operator=(const Reshapex4&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc index bd109020004..5d578fe6e09 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc @@ -209,7 +209,7 @@ Resize& Resize::operator=(Resize&& operation) { return *this; } -Status Resize::Compile(const CreationContext& creation_context) { +absl::Status Resize::Compile(const CreationContext& creation_context) { const auto code = GetResizeCode(definition_, attr_.type, attr_.half_pixel_centers, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -217,7 +217,7 @@ Status Resize::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status Resize::BindArguments() { +absl::Status Resize::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -230,7 +230,7 @@ Status Resize::BindArguments() { float2(CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_), 
CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor)); - return OkStatus(); + return absl::OkStatus(); } int3 Resize::GetGridSize() const { @@ -240,12 +240,12 @@ int3 Resize::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Resize::AddToQueue(CLCommandQueue* queue) { +absl::Status Resize::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status Resize::Tune(const TuningParameters& params) { +absl::Status Resize::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } @@ -271,7 +271,7 @@ Resize3D& Resize3D::operator=(Resize3D&& operation) { return *this; } -Status Resize3D::Compile(const CreationContext& creation_context) { +absl::Status Resize3D::Compile(const CreationContext& creation_context) { const auto code = GetResize3DCode(definition_, attr_.type, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -279,7 +279,7 @@ Status Resize3D::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status Resize3D::BindArguments() { +absl::Status Resize3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -296,7 +296,7 @@ Status Resize3D::BindArguments() { CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_), CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_), 1.0f); RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor)); - return OkStatus(); + return absl::OkStatus(); } int3 Resize3D::GetGridSize() const { @@ -306,12 +306,12 @@ int3 Resize3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Resize3D::AddToQueue(CLCommandQueue* queue) { +absl::Status Resize3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status Resize3D::Tune(const TuningParameters& params) { +absl::Status Resize3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h index a80f9a98382..04459e12ff9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h @@ -27,10 +27,10 @@ namespace cl { class Resize : public GPUOperation { public: - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Resize(Resize&& operation); @@ -45,7 +45,7 @@ class Resize : public GPUOperation { Resize(const OperationDef& definition, const Resize2DAttributes& attr) : GPUOperation(definition), attr_(attr) {} - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Resize2DAttributes attr_; @@ -58,10 +58,10 @@ Resize CreateResize(const OperationDef& definition, class Resize3D : public GPUOperation { public: - Status AddToQueue(CLCommandQueue* 
queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Resize3D(Resize3D&& operation); @@ -76,7 +76,7 @@ class Resize3D : public GPUOperation { Resize3D(const OperationDef& definition, const Resize3DAttributes& attr) : GPUOperation(definition), attr_(attr) {} - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; Resize3DAttributes attr_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc index 350abf7f64e..0f9fcb03097 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc @@ -79,14 +79,14 @@ Softmax& Softmax::operator=(Softmax&& kernel) { return *this; } -Status Softmax::Compile(const CreationContext& creation_context) { +absl::Status Softmax::Compile(const CreationContext& creation_context) { const auto code = GetSoftmaxKernelCode(definition_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -Status Softmax::BindArguments() { +absl::Status Softmax::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -94,7 +94,7 @@ Status Softmax::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR( kernel_.SetBytesAuto(GetMaskForLastPlane(src_[0]->Channels()))); - return OkStatus(); + return absl::OkStatus(); } int3 Softmax::GetGridSize() const { @@ -104,12 +104,12 @@ int3 Softmax::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Softmax::Tune(const TuningParameters& params) { +absl::Status Softmax::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status Softmax::AddToQueue(CLCommandQueue* queue) { +absl::Status Softmax::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h index b8b7846e8de..703a40a4e89 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h @@ -30,10 +30,10 @@ class Softmax : public GPUOperation { public: Softmax() = default; explicit Softmax(const OperationDef& definition) : GPUOperation(definition) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Softmax(Softmax&& kernel); @@ -44,7 +44,7 @@ class Softmax : public GPUOperation { friend Softmax CreateSoftmax(); private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; int3 work_group_size_ = int3(8, 4, 1); diff --git 
a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc index 168dc6ce4a9..09e6c978026 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc @@ -115,14 +115,14 @@ Softmax1x1& Softmax1x1::operator=(Softmax1x1&& kernel) { return *this; } -Status Softmax1x1::Compile(const CreationContext& creation_context) { +absl::Status Softmax1x1::Compile(const CreationContext& creation_context) { const auto code = GetSoftmaxKernelCode(definition_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -Status Softmax1x1::AddToQueue(CLCommandQueue* queue) { +absl::Status Softmax1x1::AddToQueue(CLCommandQueue* queue) { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h index 0fd5325a863..0d28145ca03 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h @@ -30,9 +30,9 @@ class Softmax1x1 : public GPUOperation { Softmax1x1() = default; explicit Softmax1x1(const OperationDef& definition) : GPUOperation(definition) {} - Status AddToQueue(CLCommandQueue* queue) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Softmax1x1(Softmax1x1&& kernel); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc index db6882ce4f4..b763684516a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc @@ -96,14 +96,14 @@ SpaceToDepth& SpaceToDepth::operator=(SpaceToDepth&& operation) { return *this; } -Status SpaceToDepth::Compile(const CreationContext& creation_context) { +absl::Status SpaceToDepth::Compile(const CreationContext& creation_context) { const auto code = GetSpaceToDepthCode(definition_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -Status SpaceToDepth::BindArguments() { +absl::Status SpaceToDepth::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -121,12 +121,12 @@ int3 SpaceToDepth::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status SpaceToDepth::Tune(const TuningParameters& params) { +absl::Status SpaceToDepth::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status SpaceToDepth::AddToQueue(CLCommandQueue* queue) { +absl::Status SpaceToDepth::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h index 3d316569fcb..9dd257a4c4d 100644 --- 
a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h @@ -30,9 +30,9 @@ class SpaceToDepth : public GPUOperation { public: SpaceToDepth(const OperationDef& op_def, const SpaceToDepthAttributes& attr) : GPUOperation(op_def), attr_(attr), work_group_size_(8, 4, 1) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; + absl::Status Compile(const CreationContext& creation_context) override; SpaceToDepth(SpaceToDepth&& operation); SpaceToDepth& operator=(SpaceToDepth&& operation); @@ -40,7 +40,7 @@ class SpaceToDepth : public GPUOperation { SpaceToDepth& operator=(const SpaceToDepth&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; SpaceToDepthAttributes attr_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc index 4f5cf9b26c7..19f1b185d3c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc @@ -166,7 +166,7 @@ StridedSlice& StridedSlice::operator=(StridedSlice&& operation) { return *this; } -Status StridedSlice::Compile(const CreationContext& creation_context) { +absl::Status StridedSlice::Compile(const CreationContext& creation_context) { const auto code = GetStridedSliceCode(definition_, Is4Aligned(attributes_), linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -174,7 +174,7 @@ Status StridedSlice::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -Status StridedSlice::BindArguments() { +absl::Status StridedSlice::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -187,7 +187,7 @@ Status StridedSlice::BindArguments() { attributes_.strides.c, attributes_.strides.b))); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - return OkStatus(); + return absl::OkStatus(); } int3 StridedSlice::GetGridSize() const { @@ -197,12 +197,12 @@ int3 StridedSlice::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status StridedSlice::Tune(const TuningParameters& params) { +absl::Status StridedSlice::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status StridedSlice::AddToQueue(CLCommandQueue* queue) { +absl::Status StridedSlice::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h index f30f6777134..ee6f18fdacb 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h @@ -27,10 +27,10 @@ namespace cl { class StridedSlice : public GPUOperation { public: StridedSlice(const OperationDef& definition, const SliceAttributes& attr); - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) 
override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only StridedSlice(StridedSlice&& operation); @@ -39,7 +39,7 @@ class StridedSlice : public GPUOperation { StridedSlice& operator=(const StridedSlice&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; SliceAttributes attributes_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc index cab9b728866..66a272fa2da 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc @@ -125,14 +125,14 @@ Transpose& Transpose::operator=(Transpose&& operation) { return *this; } -Status Transpose::Compile(const CreationContext& creation_context) { +absl::Status Transpose::Compile(const CreationContext& creation_context) { const auto code = GetTransposeCode(definition_, attr_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -Status Transpose::BindArguments() { +absl::Status Transpose::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -141,8 +141,7 @@ Status Transpose::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Channels())); - - return OkStatus(); + return absl::OkStatus(); } int3 Transpose::GetGridSize() const { @@ -152,12 +151,12 @@ int3 Transpose::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Transpose::Tune(const TuningParameters& params) { +absl::Status Transpose::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -Status Transpose::AddToQueue(CLCommandQueue* queue) { +absl::Status Transpose::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h index 22c155a79ba..61038b1e0ca 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h @@ -28,9 +28,9 @@ class Transpose : public GPUOperation { public: Transpose(const OperationDef& definition, const TransposeAttributes& attr) : GPUOperation(definition), attr_(attr), work_group_size_(8, 4, 1) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Transpose(Transpose&& operation); @@ -39,7 +39,7 @@ class Transpose : public GPUOperation { Transpose& operator=(const Transpose&) = delete; private: - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; TransposeAttributes attr_; 
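The hunks above and below all apply the same mechanical pattern: the unqualified Status, OkStatus(), and error constructors such as InvalidArgumentError and InternalError become their absl::-qualified counterparts, while the RETURN_IF_ERROR-based control flow is left untouched. A minimal C++ sketch of that pattern follows; ExampleOp is a hypothetical class, RETURN_IF_ERROR here is a simplified stand-in for the macro these kernels use, and only the absl/status calls are the real Abseil API.

#include "absl/status/status.h"

// Simplified stand-in for the RETURN_IF_ERROR macro used throughout these
// kernels: evaluate an expression and propagate any non-OK status.
#define RETURN_IF_ERROR(expr)            \
  do {                                   \
    const absl::Status _status = (expr); \
    if (!_status.ok()) return _status;   \
  } while (false)

// Hypothetical operation, only to illustrate the migrated signatures.
class ExampleOp {
 public:
  // Previously returned the unqualified Status and OkStatus().
  absl::Status BindArguments() {
    if (!counter_reset_) {
      return absl::InvalidArgumentError("Binding counter was not reset.");
    }
    return absl::OkStatus();
  }

  absl::Status AddToQueue() {
    RETURN_IF_ERROR(BindArguments());  // Failures propagate unchanged.
    return absl::OkStatus();
  }

 private:
  bool counter_reset_ = true;
};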
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc index 9bb89874c3d..81a8fc690c4 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc @@ -381,7 +381,7 @@ Winograd4x4To36& Winograd4x4To36::operator=(Winograd4x4To36&& operation) { return *this; } -Status Winograd4x4To36::Compile(const CreationContext& creation_context) { +absl::Status Winograd4x4To36::Compile(const CreationContext& creation_context) { std::vector options; if (creation_context.device->IsAdreno()) { options.push_back(CompilerOptions::ADRENO_MORE_WAVES); @@ -397,10 +397,10 @@ Status Winograd4x4To36::Compile(const CreationContext& creation_context) { code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); work_group_size_ = SelectBestWorkGroup(); - return OkStatus(); + return absl::OkStatus(); } -Status Winograd4x4To36::UploadBt(CLContext* context) { +absl::Status Winograd4x4To36::UploadBt(CLContext* context) { ::tflite::gpu::Tensor bt_aligned; bt_aligned.shape = Linear(6 * 8); bt_aligned.data.resize(6 * 8); @@ -427,7 +427,7 @@ int3 Winograd4x4To36::SelectBestWorkGroup() { return GetFirstSuitableWorkGroup(wgs, kernel_.GetMaxWorkGroupSize()); } -Status Winograd4x4To36::BindArguments() { +absl::Status Winograd4x4To36::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(bt_.GetMemoryPtr())); @@ -444,8 +444,7 @@ Status Winograd4x4To36::BindArguments() { kernel_.SetBytesAuto(int2(-padding_.prepended.w, -padding_.prepended.h))); RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_total)); RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_x)); - - return OkStatus(); + return absl::OkStatus(); } int3 Winograd4x4To36::GetGridSize() const { @@ -455,7 +454,7 @@ int3 Winograd4x4To36::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Winograd4x4To36::Tune(const TuningParameters& params) { +absl::Status Winograd4x4To36::Tune(const TuningParameters& params) { switch (params.tuning_type) { case TuningType::EXHAUSTIVE: RETURN_IF_ERROR(BindArguments()); @@ -464,19 +463,19 @@ Status Winograd4x4To36::Tune(const TuningParameters& params) { case TuningType::FAST: default: work_group_size_ = SelectBestWorkGroup(); - return OkStatus(); + return absl::OkStatus(); } } -Status Winograd4x4To36::AddToQueue(CLCommandQueue* queue) { +absl::Status Winograd4x4To36::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status CreateWinograd4x4To36(const CreationContext& creation_context, - const OperationDef& definition, - const Padding2D& padding, - Winograd4x4To36* result) { +absl::Status CreateWinograd4x4To36(const CreationContext& creation_context, + const OperationDef& definition, + const Padding2D& padding, + Winograd4x4To36* result) { *result = Winograd4x4To36(definition, padding); return result->UploadBt(creation_context.context); } @@ -499,7 +498,7 @@ Winograd36To4x4& Winograd36To4x4::operator=(Winograd36To4x4&& operation) { return *this; } -Status Winograd36To4x4::Compile(const CreationContext& creation_context) { +absl::Status Winograd36To4x4::Compile(const CreationContext& creation_context) { std::vector options; if (definition_.precision == CalculationsPrecision::F16 && creation_context.device->IsPowerVR()) { @@ -511,10 +510,10 @@ Status Winograd36To4x4::Compile(const 
CreationContext& creation_context) { code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); work_group_size_ = SelectBestWorkGroup(); - return OkStatus(); + return absl::OkStatus(); } -Status Winograd36To4x4::UploadAt(CLContext* context) { +absl::Status Winograd36To4x4::UploadAt(CLContext* context) { ::tflite::gpu::Tensor at_aligned; at_aligned.shape = Linear(4 * 8); at_aligned.data.resize(4 * 8); @@ -541,7 +540,7 @@ int3 Winograd36To4x4::SelectBestWorkGroup() { return GetFirstSuitableWorkGroup(wgs, kernel_.GetMaxWorkGroupSize()); } -Status Winograd36To4x4::BindArguments() { +absl::Status Winograd36To4x4::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(at_.GetMemoryPtr())); @@ -552,8 +551,7 @@ Status Winograd36To4x4::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); const int tiles_x = IntegralDivideRoundUp(dst_[0]->Width(), 4); RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_x)); - - return OkStatus(); + return absl::OkStatus(); } int3 Winograd36To4x4::GetGridSize() const { @@ -565,7 +563,7 @@ int3 Winograd36To4x4::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -Status Winograd36To4x4::Tune(const TuningParameters& params) { +absl::Status Winograd36To4x4::Tune(const TuningParameters& params) { switch (params.tuning_type) { case TuningType::EXHAUSTIVE: RETURN_IF_ERROR(BindArguments()); @@ -574,16 +572,16 @@ Status Winograd36To4x4::Tune(const TuningParameters& params) { case TuningType::FAST: default: work_group_size_ = SelectBestWorkGroup(); - return OkStatus(); + return absl::OkStatus(); } } -Status Winograd36To4x4::AddToQueue(CLCommandQueue* queue) { +absl::Status Winograd36To4x4::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -Status CreateWinograd36To4x4( +absl::Status CreateWinograd36To4x4( const CreationContext& creation_context, const OperationDef& definition, const ::tflite::gpu::Tensor& biases, Winograd36To4x4* result) { @@ -594,7 +592,6 @@ Status CreateWinograd36To4x4( create_info.name = "biases"; RETURN_IF_ERROR(CreateLinearStorage( create_info, biases, creation_context.context, &result->biases_)); - return result->UploadAt(creation_context.context); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h index f6b80b67f32..5a0444c4be5 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h @@ -36,9 +36,9 @@ class Winograd4x4To36 : public GPUOperation { Winograd4x4To36() = default; Winograd4x4To36(const OperationDef& definition, const Padding2D& padding) : GPUOperation(definition), padding_(padding) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Winograd4x4To36(Winograd4x4To36&& operation); @@ -47,17 +47,16 @@ class Winograd4x4To36 : public GPUOperation { Winograd4x4To36& operator=(const Winograd4x4To36&) = delete; private: - friend Status CreateWinograd4x4To36(const CreationContext& creation_context, - const OperationDef& definition, - 
const Padding2D& padding, - Winograd4x4To36* result); + friend absl::Status CreateWinograd4x4To36( + const CreationContext& creation_context, const OperationDef& definition, + const Padding2D& padding, Winograd4x4To36* result); - Status UploadBt(CLContext* context); + absl::Status UploadBt(CLContext* context); // Must be called after kernel compilation int3 SelectBestWorkGroup(); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; LinearStorage bt_; @@ -67,18 +66,19 @@ class Winograd4x4To36 : public GPUOperation { int3 work_group_size_ = int3(128, 1, 1); }; -Status CreateWinograd4x4To36(const CreationContext& creation_context, - const OperationDef& definition, - const Padding2D& padding, Winograd4x4To36* result); +absl::Status CreateWinograd4x4To36(const CreationContext& creation_context, + const OperationDef& definition, + const Padding2D& padding, + Winograd4x4To36* result); class Winograd36To4x4 : public GPUOperation { public: Winograd36To4x4() = default; explicit Winograd36To4x4(const OperationDef& definition) : GPUOperation(definition) {} - Status AddToQueue(CLCommandQueue* queue) override; - Status Tune(const TuningParameters& params) override; - Status Compile(const CreationContext& creation_context) override; + absl::Status AddToQueue(CLCommandQueue* queue) override; + absl::Status Tune(const TuningParameters& params) override; + absl::Status Compile(const CreationContext& creation_context) override; // Move only Winograd36To4x4(Winograd36To4x4&& operation); @@ -87,17 +87,17 @@ class Winograd36To4x4 : public GPUOperation { Winograd36To4x4& operator=(const Winograd36To4x4&) = delete; private: - friend Status CreateWinograd36To4x4( + friend absl::Status CreateWinograd36To4x4( const CreationContext& creation_context, const OperationDef& definition, const ::tflite::gpu::Tensor& biases, Winograd36To4x4* result); - Status UploadAt(CLContext* context); + absl::Status UploadAt(CLContext* context); // Must be called after kernel compilation int3 SelectBestWorkGroup(); - Status BindArguments(); + absl::Status BindArguments(); int3 GetGridSize() const; LinearStorage at_; @@ -107,7 +107,7 @@ class Winograd36To4x4 : public GPUOperation { int3 work_group_size_ = int3(128, 1, 1); }; -Status CreateWinograd36To4x4( +absl::Status CreateWinograd36To4x4( const CreationContext& creation_context, const OperationDef& definition, const ::tflite::gpu::Tensor& biases, Winograd36To4x4* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc index 7a2e54840b9..683116091b8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc @@ -75,9 +75,10 @@ std::vector GenerateWorkGroupSizesXY128Linear( return work_groups; } -Status GetBestWorkGroupAlignedToGrid(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - int3* best_work_group) { +absl::Status GetBestWorkGroupAlignedToGrid(const TuningParameters& params, + const CLKernel& kernel, + const int3& grid, + int3* best_work_group) { std::vector work_groups; RETURN_IF_ERROR(GenerateWorkGroupSizesAlignedToGrid( grid, params.info->max_work_group_sizes, kernel.GetMaxWorkGroupSize(), @@ -86,7 +87,7 @@ Status GetBestWorkGroupAlignedToGrid(const TuningParameters& params, RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex( kernel, *params.info, grid, work_groups, &best_work_group_index)); *best_work_group = 
work_groups[best_work_group_index]; - return OkStatus(); + return absl::OkStatus(); } int GetPenalty(int grid_size, int group_size) { @@ -202,30 +203,31 @@ int3 GetWorkGroupConv(const int3& grid, int max_size, int max_z_size) { return int3(wg_x, wg_y, wg_z); } -Status GetBestWorkGroupXY128(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - WorkGroupSizeAlignment z_alignment, - int3* best_work_group) { +absl::Status GetBestWorkGroupXY128(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + WorkGroupSizeAlignment z_alignment, + int3* best_work_group) { std::vector work_groups = GenerateWorkGroupSizesXY128( grid, kernel.GetMaxWorkGroupSize(), z_alignment); int best_work_group_index; RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex( kernel, *params.info, grid, work_groups, &best_work_group_index)); *best_work_group = work_groups[best_work_group_index]; - return OkStatus(); + return absl::OkStatus(); } -Status GetBestWorkGroupXY128Linear(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - WorkGroupSizeAlignment z_alignment, - int3* best_work_group) { +absl::Status GetBestWorkGroupXY128Linear(const TuningParameters& params, + const CLKernel& kernel, + const int3& grid, + WorkGroupSizeAlignment z_alignment, + int3* best_work_group) { std::vector work_groups = GenerateWorkGroupSizesXY128Linear( grid, kernel.GetMaxWorkGroupSize(), z_alignment); int best_work_group_index; RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex( kernel, *params.info, grid, work_groups, &best_work_group_index)); *best_work_group = work_groups[best_work_group_index]; - return OkStatus(); + return absl::OkStatus(); } bool XY128RequiresMoreWorkGroupsThenXY128Linear(int width, int height) { @@ -244,24 +246,25 @@ bool XY128RequiresMoreWorkGroupsThenXY128Linear(int width, int height) { return !have_equal_work_groups; } -Status GetBestWorkGroup(const TuningParameters& params, const CLKernel& kernel, - const int3& grid, int3* best_work_group) { +absl::Status GetBestWorkGroup(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + int3* best_work_group) { switch (params.tuning_type) { case TuningType::FAST: *best_work_group = GetWorkGroup(grid, kernel.GetMaxWorkGroupSize()); - return OkStatus(); + return absl::OkStatus(); case TuningType::EXHAUSTIVE: return GetBestWorkGroupAlignedToGrid(params, kernel, grid, best_work_group); default: *best_work_group = {8, 4, 1}; - return OkStatus(); + return absl::OkStatus(); } } -Status GetBestWorkGroupConv(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - int3* best_work_group) { +absl::Status GetBestWorkGroupConv(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + int3* best_work_group) { switch (params.tuning_type) { case TuningType::FAST: { int max_z_size = 16; @@ -271,14 +274,14 @@ Status GetBestWorkGroupConv(const TuningParameters& params, max_z_size = std::min(max_z_size, params.info->max_work_group_sizes.z); *best_work_group = GetWorkGroupConv(grid, kernel.GetMaxWorkGroupSize(), max_z_size); - return OkStatus(); + return absl::OkStatus(); } case TuningType::EXHAUSTIVE: return GetBestWorkGroupAlignedToGrid(params, kernel, grid, best_work_group); default: *best_work_group = {8, 4, 1}; - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h index 4b9801e6009..7cc60f4723f 100644 --- 
a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h @@ -31,16 +31,17 @@ namespace cl { // Here and later you can find XY128, this is because 128 is SIMD width of A6xx // And XY128 means that work_group_size.x * work_group_size.y % 128 = 0 // We need it to correctly work with constants uploading on A6xx -Status GetBestWorkGroupXY128(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - WorkGroupSizeAlignment z_alignment, - int3* best_work_group); - -Status GetBestWorkGroupXY128Linear(const TuningParameters& params, +absl::Status GetBestWorkGroupXY128(const TuningParameters& params, const CLKernel& kernel, const int3& grid, WorkGroupSizeAlignment z_alignment, int3* best_work_group); +absl::Status GetBestWorkGroupXY128Linear(const TuningParameters& params, + const CLKernel& kernel, + const int3& grid, + WorkGroupSizeAlignment z_alignment, + int3* best_work_group); + int3 GetWorkGroupXY128ConvLinear(const int3& grid); int3 GetWorkGroupXY128Simple(const int3& grid); @@ -48,12 +49,13 @@ int3 GetWorkGroupXY128Conv(const int3& grid); bool XY128RequiresMoreWorkGroupsThenXY128Linear(int width, int height); -Status GetBestWorkGroup(const TuningParameters& params, const CLKernel& kernel, - const int3& grid, int3* best_work_group); +absl::Status GetBestWorkGroup(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + int3* best_work_group); -Status GetBestWorkGroupConv(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - int3* best_work_group); +absl::Status GetBestWorkGroupConv(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + int3* best_work_group); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc index cd7fe729c7d..4fb21d0ec6a 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + namespace tflite { namespace gpu { namespace cl { @@ -73,29 +75,31 @@ LinearStorageType DeduceLinearStorageType( } } -Status CreateBufferLinearStorage(int size, DataType data_type, void* data, - CLContext* context, LinearStorage* result) { +absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data, + CLContext* context, + LinearStorage* result) { const int float4_size = data_type == DataType::FLOAT32 ? 
sizeof(float4) : sizeof(half4); *result = LinearStorage(size, LinearStorageType::BUFFER, data_type); RETURN_IF_ERROR(CreateReadOnlyBuffer(float4_size * size, data, context, &result->buffer_storage_)); result->memory_ = result->buffer_storage_.GetMemoryPtr(); - return OkStatus(); + return absl::OkStatus(); } -Status CreateTextureLinearStorage(int size, DataType data_type, void* data, - CLContext* context, LinearStorage* result) { +absl::Status CreateTextureLinearStorage(int size, DataType data_type, + void* data, CLContext* context, + LinearStorage* result) { *result = LinearStorage(size, LinearStorageType::TEXTURE_2D, data_type); RETURN_IF_ERROR(CreateTexture2DRGBA(data_type, size, 1, data, context, &result->texture_storage_)); result->memory_ = result->texture_storage_.GetMemoryPtr(); - return OkStatus(); + return absl::OkStatus(); } -Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, - int size, void* data, CLContext* context, - LinearStorage* result) { +absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, + int size, void* data, CLContext* context, + LinearStorage* result) { if (creation_info.storage_type == LinearStorageType::BUFFER) { return CreateBufferLinearStorage(size, creation_info.data_type, data, context, result); diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.h b/tensorflow/lite/delegates/gpu/cl/linear_storage.h index 3d3d9d5222f..93aecd57854 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.h +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.h @@ -64,12 +64,12 @@ class LinearStorage { std::string GetDeclaration() const; private: - friend Status CreateTextureLinearStorage(int size, DataType data_type, - void* data, CLContext* context, - LinearStorage* result); - friend Status CreateBufferLinearStorage(int size, DataType data_type, - void* data, CLContext* context, - LinearStorage* result); + friend absl::Status CreateTextureLinearStorage(int size, DataType data_type, + void* data, CLContext* context, + LinearStorage* result); + friend absl::Status CreateBufferLinearStorage(int size, DataType data_type, + void* data, CLContext* context, + LinearStorage* result); LinearStorage(int depth, LinearStorageType storage_type, DataType data_type); @@ -83,20 +83,22 @@ class LinearStorage { DataType data_type_; }; -Status CreateBufferLinearStorage(int size, DataType data_type, void* data, - CLContext* context, LinearStorage* result); +absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data, + CLContext* context, + LinearStorage* result); -Status CreateTextureLinearStorage(int size, DataType data_type, void* data, - CLContext* context, LinearStorage* result); +absl::Status CreateTextureLinearStorage(int size, DataType data_type, + void* data, CLContext* context, + LinearStorage* result); -Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, - int size, void* data, CLContext* context, - LinearStorage* result); +absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, + int size, void* data, CLContext* context, + LinearStorage* result); template <DataType T> -Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, - const ::tflite::gpu::Tensor<Linear, T>& tensor, - CLContext* context, LinearStorage* result) { +absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, + const ::tflite::gpu::Tensor<Linear, T>& tensor, + CLContext* context, LinearStorage* result) { int size = creation_info.aligned_size != 0 ?
creation_info.aligned_size : tensor.shape.v; const int depth = IntegralDivideRoundUp(size, 4); @@ -112,7 +114,7 @@ Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, context, result)); } result->SetName(creation_info.name); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc index 3b471ce816c..be551bc9973 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc @@ -31,11 +31,11 @@ namespace cl { function = reinterpret_cast(dlsym(libopencl, #function)); \ } -Status LoadOpenCL() { +absl::Status LoadOpenCL() { void* libopencl = dlopen("libOpenCL.so", RTLD_NOW | RTLD_LOCAL); if (libopencl) { LoadOpenCLFunctions(libopencl, false); - return OkStatus(); + return absl::OkStatus(); } else { // Pixel phone? libopencl = dlopen("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); @@ -45,9 +45,9 @@ Status LoadOpenCL() { reinterpret_cast(dlsym(libopencl, "enableOpenCL")); enableOpenCL(); LoadOpenCLFunctions(libopencl, true); - return OkStatus(); + return absl::OkStatus(); } else { - return UnknownError( + return absl::UnknownError( absl::StrCat("OpenCL library not loaded - ", dlerror())); } } diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h index 16ae24437a3..2201b4c1e5d 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h @@ -27,7 +27,7 @@ namespace tflite { namespace gpu { namespace cl { -Status LoadOpenCL(); +absl::Status LoadOpenCL(); void LoadOpenCLFunctions(void *libopencl, bool is_pixel); typedef cl_int(CL_API_CALL *PFN_clGetPlatformIDs)( diff --git a/tensorflow/lite/delegates/gpu/cl/program_cache.cc b/tensorflow/lite/delegates/gpu/cl/program_cache.cc index e6735b448de..285aa06d99b 100644 --- a/tensorflow/lite/delegates/gpu/cl/program_cache.cc +++ b/tensorflow/lite/delegates/gpu/cl/program_cache.cc @@ -56,7 +56,7 @@ ProgramCache& ProgramCache::operator=(ProgramCache&& program_cache) { return *this; } -Status ProgramCache::GetOrCreateCLKernel( +absl::Status ProgramCache::GetOrCreateCLKernel( const std::string& code, const std::string& function_name, const std::vector& compiler_options, const CLContext& context, const CLDevice& device, CLKernel* result) { @@ -64,32 +64,31 @@ Status ProgramCache::GetOrCreateCLKernel( ProgramDescriptor desc{code, options, use_fingerprints_}; auto it = programs_.find(desc); if (it != programs_.end()) { - RETURN_IF_ERROR(result->CreateFromProgram(it->second, function_name)); - return OkStatus(); + return result->CreateFromProgram(it->second, function_name); } CLProgram program; RETURN_IF_ERROR(CreateCLProgram(code, options, context, device, &program)); RETURN_IF_ERROR(result->CreateFromProgram(program, function_name)); programs_.insert(std::make_pair(std::move(desc), std::move(program))); - return OkStatus(); + return absl::OkStatus(); } -Status ProgramCache::GetOrCreateCLKernel(const std::string& code, - const std::string& function_name, - const CLContext& context, - const CLDevice& device, - CLKernel* result) { +absl::Status ProgramCache::GetOrCreateCLKernel(const std::string& code, + const std::string& function_name, + const CLContext& context, + const CLDevice& device, + CLKernel* result) { return GetOrCreateCLKernel(code, function_name, {}, context, device, result); } -Status ProgramCache::AddSerializedCache( +absl::Status 
ProgramCache::AddSerializedCache( const CLContext& context, const CLDevice& device, absl::Span serialized_cache) { flatbuffers::Verifier verifier(serialized_cache.data(), serialized_cache.size()); if (!data::VerifyCompiledCacheBuffer(verifier)) { - return InvalidArgumentError("Serialized model is corrupted."); + return absl::InvalidArgumentError("Serialized model is corrupted."); } auto model = data::GetCompiledCache(serialized_cache.data()); @@ -97,7 +96,7 @@ Status ProgramCache::AddSerializedCache( model->driver_version()->size()); if (device.GetPlatformVersion() != platform_version) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "OpenCL driver changed, cache invalid, should be regenerated"); } @@ -116,10 +115,10 @@ Status ProgramCache::AddSerializedCache( programs_.insert(std::make_pair(std::move(desc), std::move(program))); } } - return OkStatus(); + return absl::OkStatus(); } -Status ProgramCache::GetSerializedCache( +absl::Status ProgramCache::GetSerializedCache( const CLDevice& device, std::vector* serialized_cache) const { ::flatbuffers::FlatBufferBuilder builder; std::vector> serialized_programs; @@ -140,9 +139,9 @@ Status ProgramCache::GetSerializedCache( data::FinishCompiledCacheBuffer(builder, cache_builder.Finish()); size_t next_element = serialized_cache->size(); serialized_cache->resize(serialized_cache->size() + builder.GetSize()); - memcpy(&(*serialized_cache)[next_element], builder.GetBufferPointer(), - builder.GetSize()); - return OkStatus(); + std::memcpy(&(*serialized_cache)[next_element], builder.GetBufferPointer(), + builder.GetSize()); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/program_cache.h b/tensorflow/lite/delegates/gpu/cl/program_cache.h index b8d019d3d47..21f9583a59a 100644 --- a/tensorflow/lite/delegates/gpu/cl/program_cache.h +++ b/tensorflow/lite/delegates/gpu/cl/program_cache.h @@ -41,20 +41,21 @@ class ProgramCache { ProgramCache(const ProgramCache&) = delete; ProgramCache& operator=(const ProgramCache&) = delete; - Status GetOrCreateCLKernel( + absl::Status GetOrCreateCLKernel( const std::string& code, const std::string& function_name, const std::vector& compiler_options, const CLContext& context, const CLDevice& device, CLKernel* result); - Status GetOrCreateCLKernel(const std::string& code, - const std::string& function_name, - const CLContext& context, const CLDevice& device, - CLKernel* result); + absl::Status GetOrCreateCLKernel(const std::string& code, + const std::string& function_name, + const CLContext& context, + const CLDevice& device, CLKernel* result); - Status AddSerializedCache(const CLContext& context, const CLDevice& device, - absl::Span serialized_cache); - Status GetSerializedCache(const CLDevice& device, - std::vector* serialized_cache) const; + absl::Status AddSerializedCache(const CLContext& context, + const CLDevice& device, + absl::Span serialized_cache); + absl::Status GetSerializedCache(const CLDevice& device, + std::vector* serialized_cache) const; private: struct ProgramDescriptor { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc index a420373f50a..d2d775f819f 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc @@ -29,11 +29,12 @@ namespace gpu { namespace cl { namespace { -Status SelectConvolutionAdreno(const Convolution2DAttributes& attr, - const BHWC& 
dst_shape, - const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - std::unique_ptr* ptr) { +absl::Status SelectConvolutionAdreno(const Convolution2DAttributes& attr, + const BHWC& dst_shape, + const CreationContext& creation_context, + const OperationDef& op_def, + ModelHints hints, + std::unique_ptr* ptr) { if (IsConvConstantsSupported(*creation_context.device, op_def, attr)) { ConvConstants conv; RETURN_IF_ERROR(CreateConvConstants(creation_context, op_def, attr, &conv)); @@ -43,28 +44,24 @@ Status SelectConvolutionAdreno(const Convolution2DAttributes& attr, RETURN_IF_ERROR(CreateConvTexture(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); } - - return OkStatus(); + return absl::OkStatus(); } -Status SelectConvolutionWinogradAdreno(const Convolution2DAttributes& attr, - const BHWC& dst_shape, - const CreationContext& creation_context, - const OperationDef& op_def, - ModelHints hints, - std::unique_ptr* ptr) { +absl::Status SelectConvolutionWinogradAdreno( + const Convolution2DAttributes& attr, const BHWC& dst_shape, + const CreationContext& creation_context, const OperationDef& op_def, + ModelHints hints, std::unique_ptr* ptr) { ConvTexture conv; RETURN_IF_ERROR( CreateConvTextureWino4x4To6x6(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); - - return OkStatus(); + return absl::OkStatus(); } -Status SelectConvolutionNVidia(const Convolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectConvolutionNVidia(const Convolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { if (IsConvConstantsSupported(*creation_context.device, op_def, attr)) { ConvConstants conv; RETURN_IF_ERROR(CreateConvConstants(creation_context, op_def, attr, &conv)); @@ -74,24 +71,24 @@ Status SelectConvolutionNVidia(const Convolution2DAttributes& attr, RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); } - return OkStatus(); + return absl::OkStatus(); } -Status SelectConvolutionPowerVR(const Convolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectConvolutionPowerVR(const Convolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { ConvPowerVR conv; RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); - return OkStatus(); + return absl::OkStatus(); } -Status SelectConvolutionMali(const Convolution2DAttributes& attr, - const BHWC& dst_shape, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectConvolutionMali(const Convolution2DAttributes& attr, + const BHWC& dst_shape, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER && IsConvBuffer1x1Supported(op_def, attr)) { ConvBuffer1x1 conv; @@ -104,14 +101,13 @@ Status SelectConvolutionMali(const Convolution2DAttributes& attr, CreateConvPowerVR(creation_context, op_def, attr, &conv, &dst_shape)); *ptr = absl::make_unique(std::move(conv)); } - return OkStatus(); + return absl::OkStatus(); } -Status SelectConvolutionWinogradMali(const 
Convolution2DAttributes& attr, - const BHWC& dst_shape, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectConvolutionWinogradMali( + const Convolution2DAttributes& attr, const BHWC& dst_shape, + const CreationContext& creation_context, const OperationDef& op_def, + std::unique_ptr* ptr) { if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) { ConvBuffer1x1 conv; RETURN_IF_ERROR(CreateConvBuffer1x1Wino4x4To6x6(creation_context, op_def, @@ -123,17 +119,16 @@ Status SelectConvolutionWinogradMali(const Convolution2DAttributes& attr, attr, &conv, &dst_shape)); *ptr = absl::make_unique(std::move(conv)); } - - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status SelectConvolution(const Convolution2DAttributes& attr, - const BHWC& dst_shape, - const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - std::unique_ptr* ptr) { +absl::Status SelectConvolution(const Convolution2DAttributes& attr, + const BHWC& dst_shape, + const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + std::unique_ptr* ptr) { switch (creation_context.device->vendor()) { case Vendor::QUALCOMM: return SelectConvolutionAdreno(attr, dst_shape, creation_context, op_def, @@ -152,12 +147,10 @@ Status SelectConvolution(const Convolution2DAttributes& attr, } } -Status SelectConvolutionForWinograd(const Convolution2DAttributes& attr, - const BHWC& dst_shape, - const CreationContext& creation_context, - const OperationDef& op_def, - ModelHints hints, - std::unique_ptr* ptr) { +absl::Status SelectConvolutionForWinograd( + const Convolution2DAttributes& attr, const BHWC& dst_shape, + const CreationContext& creation_context, const OperationDef& op_def, + ModelHints hints, std::unique_ptr* ptr) { switch (creation_context.device->vendor()) { case Vendor::QUALCOMM: return SelectConvolutionWinogradAdreno(attr, dst_shape, creation_context, @@ -169,7 +162,7 @@ Status SelectConvolutionForWinograd(const Convolution2DAttributes& attr, RETURN_IF_ERROR( CreateConvPowerVRWino4x4To6x6(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); - return OkStatus(); + return absl::OkStatus(); } case Vendor::MALI: return SelectConvolutionWinogradMali(attr, dst_shape, creation_context, diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h index dc0657ec47c..94723527ad5 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h @@ -28,18 +28,16 @@ namespace tflite { namespace gpu { namespace cl { -Status SelectConvolution(const Convolution2DAttributes& attr, - const BHWC& dst_shape, - const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - std::unique_ptr* ptr); +absl::Status SelectConvolution(const Convolution2DAttributes& attr, + const BHWC& dst_shape, + const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + std::unique_ptr* ptr); -Status SelectConvolutionForWinograd(const Convolution2DAttributes& attr, - const BHWC& dst_shape, - const CreationContext& creation_context, - const OperationDef& op_def, - ModelHints hints, - std::unique_ptr* ptr); +absl::Status SelectConvolutionForWinograd( + const Convolution2DAttributes& attr, const BHWC& dst_shape, + const CreationContext& creation_context, const OperationDef& 
op_def, + ModelHints hints, std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc index 8dd0ef6b3cb..12e99b57aa7 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc @@ -28,7 +28,7 @@ namespace gpu { namespace cl { namespace { -Status SelectConvolutionTransposedAdreno( +absl::Status SelectConvolutionTransposedAdreno( const ConvolutionTransposedAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, std::unique_ptr* ptr) { @@ -49,10 +49,10 @@ Status SelectConvolutionTransposedAdreno( CreateConvolutionTransposed(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); } - return OkStatus(); + return absl::OkStatus(); } -Status SelectConvolutionTransposedPowerVR( +absl::Status SelectConvolutionTransposedPowerVR( const ConvolutionTransposedAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, std::unique_ptr* ptr) { @@ -85,10 +85,10 @@ Status SelectConvolutionTransposedPowerVR( CreateConvolutionTransposed(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); } - return OkStatus(); + return absl::OkStatus(); } -Status SelectConvolutionTransposedMali( +absl::Status SelectConvolutionTransposedMali( const ConvolutionTransposedAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, std::unique_ptr* ptr) { @@ -96,14 +96,15 @@ Status SelectConvolutionTransposedMali( RETURN_IF_ERROR( CreateConvolutionTransposed(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); - return OkStatus(); + return absl::OkStatus(); } + } // namespace -Status SelectConvolutionTransposed(const ConvolutionTransposedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectConvolutionTransposed( + const ConvolutionTransposedAttributes& attr, + const CreationContext& creation_context, const OperationDef& op_def, + std::unique_ptr* ptr) { switch (creation_context.device->vendor()) { case Vendor::QUALCOMM: return SelectConvolutionTransposedAdreno(attr, creation_context, op_def, diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h index 50f5e5baad5..ff37c1024ad 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h @@ -26,10 +26,10 @@ namespace tflite { namespace gpu { namespace cl { -Status SelectConvolutionTransposed(const ConvolutionTransposedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr); +absl::Status SelectConvolutionTransposed( + const ConvolutionTransposedAttributes& attr, + const CreationContext& creation_context, const OperationDef& op_def, + std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc index 9fe7aa9732e..e2a941870db 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc +++ 
b/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc @@ -28,12 +28,13 @@ namespace tflite { namespace gpu { namespace cl { -Status SelectDefault(const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - const std::vector>*>& inputs, - const std::vector>*>& outputs, - const Node& node, GPUOperationsSubgraph* gpu_subgraph) { - return UnimplementedError( +absl::Status SelectDefault(const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + const std::vector>*>& inputs, + const std::vector>*>& outputs, + const Node& node, + GPUOperationsSubgraph* gpu_subgraph) { + return absl::UnimplementedError( absl::StrCat("No selector for ", node.operation.type)); } diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h index b4b996cc4fb..05e33501cd4 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h @@ -29,11 +29,12 @@ namespace tflite { namespace gpu { namespace cl { -Status SelectDefault(const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - const std::vector>*>& inputs, - const std::vector>*>& outputs, - const Node& node, GPUOperationsSubgraph* gpu_subgraph); +absl::Status SelectDefault(const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + const std::vector>*>& inputs, + const std::vector>*>& outputs, + const Node& node, + GPUOperationsSubgraph* gpu_subgraph); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc index 85afa3fff43..0098117dea1 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc @@ -26,10 +26,10 @@ namespace gpu { namespace cl { namespace { -Status SelectDWConvolutionAdreno(const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectDWConvolutionAdreno( + const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, const OperationDef& op_def, + std::unique_ptr* ptr) { if (!op_def.IsBatchSupported() && IsDepthWiseConv3x3Supported(attr)) { DepthWiseConv3x3 dw_conv; RETURN_IF_ERROR( @@ -41,13 +41,13 @@ Status SelectDWConvolutionAdreno(const DepthwiseConvolution2DAttributes& attr, CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); *ptr = absl::make_unique(std::move(dw_conv)); } - return OkStatus(); + return absl::OkStatus(); } -Status SelectDWConvolutionPowerVR(const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectDWConvolutionPowerVR( + const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, const OperationDef& op_def, + std::unique_ptr* ptr) { if (!op_def.IsBatchSupported() && IsDepthWiseConv3x3Supported(attr)) { DepthWiseConv3x3 dw_conv; RETURN_IF_ERROR( @@ -59,13 +59,13 @@ Status SelectDWConvolutionPowerVR(const DepthwiseConvolution2DAttributes& attr, CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); *ptr = absl::make_unique(std::move(dw_conv)); } - return OkStatus(); + return absl::OkStatus(); } -Status 
SelectDWConvolutionMali(const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectDWConvolutionMali( + const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, const OperationDef& op_def, + std::unique_ptr* ptr) { const auto storage_type = op_def.src_tensors[0].storage_type; bool buffer_type = storage_type == TensorStorageType::BUFFER || storage_type == TensorStorageType::IMAGE_BUFFER; @@ -83,14 +83,14 @@ Status SelectDWConvolutionMali(const DepthwiseConvolution2DAttributes& attr, CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); *ptr = absl::make_unique(std::move(dw_conv)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { switch (creation_context.device->vendor()) { case Vendor::QUALCOMM: return SelectDWConvolutionAdreno(attr, creation_context, op_def, ptr); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h index c15f2946495..7f7cc6da604 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h @@ -26,10 +26,10 @@ namespace tflite { namespace gpu { namespace cl { -Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr); +absl::Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc index 05d28b412ad..2a04a04460d 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc @@ -27,10 +27,11 @@ namespace tflite { namespace gpu { namespace cl { -Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, int batch_size, - std::unique_ptr* ptr) { +absl::Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + int batch_size, + std::unique_ptr* ptr) { if (op_def.IsBatchSupported()) { ConvTexture conv; RETURN_IF_ERROR(CreateConvTexture(creation_context, op_def, attr, &conv)); @@ -41,13 +42,13 @@ Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr, CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique(std::move(fc)); } - return OkStatus(); + return absl::OkStatus(); } -Status SelectFullyConnectedPowerVR(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, int batch_size, - std::unique_ptr* ptr) { +absl::Status SelectFullyConnectedPowerVR( + const FullyConnectedAttributes& attr, + const 
CreationContext& creation_context, const OperationDef& op_def, + int batch_size, std::unique_ptr* ptr) { if (op_def.IsBatchSupported()) { ConvPowerVR conv; RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv)); @@ -58,13 +59,14 @@ Status SelectFullyConnectedPowerVR(const FullyConnectedAttributes& attr, CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique(std::move(fc)); } - return OkStatus(); + return absl::OkStatus(); } -Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, int batch_size, - std::unique_ptr* ptr) { +absl::Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + int batch_size, + std::unique_ptr* ptr) { if (op_def.IsBatchSupported()) { if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) { ConvBuffer1x1 conv; @@ -82,13 +84,13 @@ Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr, CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique(std::move(fc)); } - return OkStatus(); + return absl::OkStatus(); } -Status SelectFullyConnected(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, int batch_size, - std::unique_ptr* ptr) { +absl::Status SelectFullyConnected(const FullyConnectedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, int batch_size, + std::unique_ptr* ptr) { switch (creation_context.device->vendor()) { case Vendor::QUALCOMM: return SelectFullyConnectedAdreno(attr, creation_context, op_def, diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h index 023020b6041..4ae44490996 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h @@ -26,10 +26,10 @@ namespace tflite { namespace gpu { namespace cl { -Status SelectFullyConnected(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, int batch_size, - std::unique_ptr* ptr); +absl::Status SelectFullyConnected(const FullyConnectedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, int batch_size, + std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc index 2fcb90fc8d1..b0996aa53ea 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc @@ -36,6 +36,7 @@ namespace tflite { namespace gpu { namespace cl { namespace { + bool IsWidthBroadcastedForSecondInput( const std::vector>*>& inputs) { return inputs.size() == 2 && @@ -74,14 +75,14 @@ bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr, return suitable_attributes && recommended_channels && recommended_hw; } -Status WinogradFromNode(const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - const BHWC& input_shape, const BHWC& output_shape, - const Convolution2DAttributes& attr, - GPUOperationsSubgraph* gpu_subgraph) { +absl::Status WinogradFromNode(const CreationContext& creation_context, + const 
OperationDef& op_def, ModelHints hints, + const BHWC& input_shape, const BHWC& output_shape, + const Convolution2DAttributes& attr, + GPUOperationsSubgraph* gpu_subgraph) { if (!IsSuitableForWinograd4x4To6x6(attr, *creation_context.device, output_shape)) { - return UnimplementedError("No implementation for this case."); + return absl::UnimplementedError("No implementation for this case."); } const int tiles_x = IntegralDivideRoundUp(output_shape.w, 4); @@ -140,18 +141,16 @@ Status WinogradFromNode(const CreationContext& creation_context, } RETURN_IF_ERROR(SelectWinograd36To4x4(creation_context, winograd_down_def, bias_copy, &winograd_down.operation)); - - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status GPUOperationFromNode(const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - const std::vector>*>& inputs, - const std::vector>*>& outputs, - const Node& node, - GPUOperationsSubgraph* gpu_subgraph) { +absl::Status GPUOperationFromNode( + const CreationContext& creation_context, const OperationDef& op_def, + ModelHints hints, const std::vector>*>& inputs, + const std::vector>*>& outputs, const Node& node, + GPUOperationsSubgraph* gpu_subgraph) { std::unique_ptr* gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph); auto op_type = OperationTypeFromString(node.operation.type); @@ -183,7 +182,7 @@ Status GPUOperationFromNode(const CreationContext& creation_context, } SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op); } - return OkStatus(); + return absl::OkStatus(); } } case OperationType::CONCAT: { @@ -202,7 +201,7 @@ Status GPUOperationFromNode(const CreationContext& creation_context, if (WinogradFromNode(creation_context, op_def, hints, input_shape, output_shape, attr, gpu_subgraph) .ok()) { - return OkStatus(); + return absl::OkStatus(); } else { gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph); return SelectConvolution(attr, output_shape, creation_context, op_def, @@ -228,13 +227,13 @@ Status GPUOperationFromNode(const CreationContext& creation_context, } case OperationType::LSTM: { SelectLSTM(op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::MAX_UNPOOLING_2D: { auto attr = absl::any_cast(node.operation.attributes); SelectMaxUnpooling(attr, op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::MEAN: { auto attr = absl::any_cast(node.operation.attributes); @@ -256,24 +255,24 @@ Status GPUOperationFromNode(const CreationContext& creation_context, CreateElementwiseTwoInput(op_def, op_type, broadcast); *gpu_op = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } else { - return UnimplementedError( + return absl::UnimplementedError( "No support of multiply with more than 2 inputs"); } - return OkStatus(); + return absl::OkStatus(); } } case OperationType::PAD: { auto attr = absl::any_cast(node.operation.attributes); SelectPadding(attr, op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::POOLING_2D: { auto attr = absl::any_cast(node.operation.attributes); SelectPooling(attr, op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::PRELU: { auto attr = absl::any_cast(node.operation.attributes); @@ -288,13 +287,13 @@ Status GPUOperationFromNode(const CreationContext& creation_context, case OperationType::RELU: { auto attr = absl::any_cast(node.operation.attributes); SelectReLU(creation_context, attr, op_def, gpu_op); - return OkStatus(); + 
return absl::OkStatus(); } case OperationType::RESHAPE: { const int src_channels = inputs[0]->tensor.shape.c; auto attr = absl::any_cast(node.operation.attributes); SelectReshape(src_channels, attr.new_shape.c, op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::RESIZE: { auto attr = absl::any_cast(node.operation.attributes); @@ -303,23 +302,23 @@ Status GPUOperationFromNode(const CreationContext& creation_context, case OperationType::SLICE: { auto attr = absl::any_cast(node.operation.attributes); SelectStridedSlice(attr, op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::SOFTMAX: { SelectSoftmax(inputs[0]->tensor.shape, op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::SPACE_TO_DEPTH: { auto attr = absl::any_cast(node.operation.attributes); SelectSpaceToDepth(attr, op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::TRANSPOSE: { auto attr = absl::any_cast(node.operation.attributes); SelectTranspose(attr, op_def, gpu_op); - return OkStatus(); + return absl::OkStatus(); } case OperationType::ABS: case OperationType::COS: @@ -335,7 +334,7 @@ Status GPUOperationFromNode(const CreationContext& creation_context, ElementwiseOneInput operation = CreateElementwiseOneInput(op_def, op_type); *gpu_op = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } case OperationType::DIV: case OperationType::MAXIMUM: @@ -352,7 +351,7 @@ Status GPUOperationFromNode(const CreationContext& creation_context, ElementwiseTwoInput operation = CreateElementwiseTwoInput( creation_context, op_def, op_type, broadcast, attr); *gpu_op = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } default: return SelectDefault(creation_context, op_def, hints, inputs, outputs, diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h index bcb46c1e0c4..dd09c16dad0 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h @@ -29,12 +29,11 @@ namespace tflite { namespace gpu { namespace cl { -Status GPUOperationFromNode(const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - const std::vector>*>& inputs, - const std::vector>*>& outputs, - const Node& node, - GPUOperationsSubgraph* gpu_subgraph); +absl::Status GPUOperationFromNode( + const CreationContext& creation_context, const OperationDef& op_def, + ModelHints hints, const std::vector>*>& inputs, + const std::vector>*>& outputs, const Node& node, + GPUOperationsSubgraph* gpu_subgraph); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc index ff26a3be601..44a88165e4c 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc @@ -59,14 +59,14 @@ void SelectReLU(const CreationContext& creation_context, *ptr = absl::make_unique(std::move(relu)); } -Status SelectPReLU(const PReLUAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectPReLU(const PReLUAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { PReLU 
operation; RETURN_IF_ERROR(CreatePReLU(creation_context, op_def, attr, &operation)); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def, @@ -88,31 +88,32 @@ void SelectAdd(const OperationDef& op_def, const std::vector& channels, *ptr = absl::make_unique(std::move(operation)); } -Status SelectResize(const Resize2DAttributes& attr, const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectResize(const Resize2DAttributes& attr, + const OperationDef& op_def, + std::unique_ptr* ptr) { Resize operation = CreateResize(op_def, attr); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } -Status SelectConcat(const ConcatAttributes& attr, - const std::vector& channels, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectConcat(const ConcatAttributes& attr, + const std::vector& channels, + const OperationDef& op_def, + std::unique_ptr* ptr) { switch (attr.axis) { case Axis::CHANNELS: { ConcatZ operation = CreateConcatZ(op_def, channels); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } case Axis::WIDTH: case Axis::HEIGHT: { ConcatXY operation = CreateConcatXY(op_def, attr, channels.size()); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } default: - return UnimplementedError("No concat for this axis."); + return absl::UnimplementedError("No concat for this axis."); } } @@ -147,36 +148,36 @@ void SelectStridedSlice(const SliceAttributes& attr, const OperationDef& op_def, *ptr = absl::make_unique(std::move(operation)); } -Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def, + std::unique_ptr* ptr) { if (attr.dims != std::set({Axis::HEIGHT, Axis::WIDTH})) { - return UnimplementedError("Mean operation supports only HW plane"); + return absl::UnimplementedError("Mean operation supports only HW plane"); } Mean operation = CreateMean(op_def); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } -Status SelectMultiplyScalar(const MultiplyAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectMultiplyScalar(const MultiplyAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { MultiplyAdd operation; RETURN_IF_ERROR( CreateMultiplyAdd(creation_context, op_def, attr, &operation)); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } -Status SelectBroadcastAdd(const AddAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectBroadcastAdd(const AddAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { MultiplyAdd operation; RETURN_IF_ERROR( CreateMultiplyAdd(creation_context, op_def, attr, &operation)); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } void SelectSoftmax(const BHWC& shape, const OperationDef& op_def, @@ -197,18 +198,18 @@ void SelectTranspose(const TransposeAttributes& attr, *ptr = absl::make_unique(std::move(operation)); } -Status 
SelectWinograd4x4To36(const CreationContext& creation_context, - const Padding2D& padding, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectWinograd4x4To36(const CreationContext& creation_context, + const Padding2D& padding, + const OperationDef& op_def, + std::unique_ptr* ptr) { Winograd4x4To36 operation; RETURN_IF_ERROR( CreateWinograd4x4To36(creation_context, op_def, padding, &operation)); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } -Status SelectWinograd36To4x4( +absl::Status SelectWinograd36To4x4( const CreationContext& creation_context, const OperationDef& op_def, const ::tflite::gpu::Tensor& biases, std::unique_ptr* ptr) { @@ -216,18 +217,18 @@ Status SelectWinograd36To4x4( RETURN_IF_ERROR( CreateWinograd36To4x4(creation_context, op_def, biases, &operation)); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } -Status SelectQuantizeAndDequantize(const QuantizeAndDequantizeAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +absl::Status SelectQuantizeAndDequantize( + const QuantizeAndDequantizeAttributes& attr, + const CreationContext& creation_context, const OperationDef& op_def, + std::unique_ptr* ptr) { QuantizeAndDequantize operation; RETURN_IF_ERROR( CreateQuantizeAndDequantize(creation_context, op_def, attr, &operation)); *ptr = absl::make_unique(std::move(operation)); - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h index d9a5365fc9e..118701fe9b0 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h @@ -33,10 +33,10 @@ void SelectReLU(const CreationContext& creation_context, const ReLUAttributes& attr, const OperationDef& op_def, std::unique_ptr* ptr); -Status SelectPReLU(const PReLUAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr); +absl::Status SelectPReLU(const PReLUAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr); void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def, std::unique_ptr* ptr); @@ -48,13 +48,14 @@ void SelectMaxUnpooling(const MaxUnpooling2DAttributes& attr, void SelectAdd(const OperationDef& op_def, const std::vector& channels, int dst_channels, std::unique_ptr* ptr); -Status SelectResize(const Resize2DAttributes& attr, const OperationDef& op_def, - std::unique_ptr* ptr); +absl::Status SelectResize(const Resize2DAttributes& attr, + const OperationDef& op_def, + std::unique_ptr* ptr); -Status SelectConcat(const ConcatAttributes& attr, - const std::vector& channels, - const OperationDef& op_def, - std::unique_ptr* ptr); +absl::Status SelectConcat(const ConcatAttributes& attr, + const std::vector& channels, + const OperationDef& op_def, + std::unique_ptr* ptr); void SelectReshape(int src_channels, int dst_channels, const OperationDef& op_def, @@ -66,18 +67,18 @@ void SelectPadding(const PadAttributes& attr, const OperationDef& op_def, void SelectStridedSlice(const SliceAttributes& attr, const OperationDef& op_def, std::unique_ptr* ptr); -Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def, - std::unique_ptr* ptr); +absl::Status SelectMean(const 
MeanAttributes& attr, const OperationDef& op_def, + std::unique_ptr<GPUOperation>* ptr); -Status SelectMultiplyScalar(const MultiplyAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr<GPUOperation>* ptr); +absl::Status SelectMultiplyScalar(const MultiplyAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr<GPUOperation>* ptr); -Status SelectBroadcastAdd(const AddAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr<GPUOperation>* ptr); +absl::Status SelectBroadcastAdd(const AddAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr<GPUOperation>* ptr); void SelectSoftmax(const BHWC& shape, const OperationDef& op_def, std::unique_ptr<GPUOperation>* ptr); @@ -90,20 +91,20 @@ void SelectTranspose(const TransposeAttributes& attr, const OperationDef& op_def, std::unique_ptr<GPUOperation>* ptr); -Status SelectWinograd4x4To36(const CreationContext& creation_context, - const Padding2D& padding, - const OperationDef& op_def, - std::unique_ptr<GPUOperation>* ptr); +absl::Status SelectWinograd4x4To36(const CreationContext& creation_context, + const Padding2D& padding, + const OperationDef& op_def, + std::unique_ptr<GPUOperation>* ptr); -Status SelectWinograd36To4x4( +absl::Status SelectWinograd36To4x4( const CreationContext& creation_context, const OperationDef& op_def, const ::tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases, std::unique_ptr<GPUOperation>* ptr); -Status SelectQuantizeAndDequantize(const QuantizeAndDequantizeAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr<GPUOperation>* ptr); +absl::Status SelectQuantizeAndDequantize( + const QuantizeAndDequantizeAttributes& attr, + const CreationContext& creation_context, const OperationDef& op_def, + std::unique_ptr<GPUOperation>* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc index 26eb3ad3538..f6201fa92ca 100644 --- a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc +++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc @@ -24,6 +24,7 @@ limitations under the License.
namespace tflite { namespace gpu { namespace cl { + bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, const BHWDC& shape, const TensorDescriptor& descriptor) { diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc index e9de22c6dc0..308e1b69205 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc @@ -27,9 +27,10 @@ namespace tflite { namespace gpu { namespace cl { namespace { -Status CreateImageBufferFromBuffer(const CLContext& context, cl_mem memory, - enum DataType data_type, int width, - cl_mem* result) { + +absl::Status CreateImageBufferFromBuffer(const CLContext& context, + cl_mem memory, enum DataType data_type, + int width, cl_mem* result) { cl_image_format format; cl_image_desc desc; std::memset(&desc, 0, sizeof(desc)); @@ -44,16 +45,17 @@ Status CreateImageBufferFromBuffer(const CLContext& context, cl_mem memory, *result = clCreateImage(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error); if (error != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to create Texture2D (clCreateImage)", CLErrorCodeToString(error))); } - return OkStatus(); + return absl::OkStatus(); } -Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWDC& shape, const TensorDescriptor& descriptor, - cl_mem memory, Tensor* result) { +absl::Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWDC& shape, + const TensorDescriptor& descriptor, cl_mem memory, + Tensor* result) { const bool memory_owner = memory == nullptr; if (memory_owner) { CLMemory mem; @@ -72,8 +74,9 @@ Status CreateTensor(const CLContext& context, const CLDevice& device, } else { *result = Tensor(memory, memory_owner, shape, descriptor); } - return OkStatus(); + return absl::OkStatus(); } + } // namespace Tensor::Tensor(cl_mem memory, bool memory_owner, const BHWC& shape, @@ -156,41 +159,48 @@ int3 Tensor::GetFullTensorRegion() const { } } -Status Tensor::IsValid(const BHWC& shape) const { +absl::Status Tensor::IsValid(const BHWC& shape) const { if (shape.b != shape_.b) { - return InvalidArgumentError("Shape batch does not match tensor batch"); + return absl::InvalidArgumentError( + "Shape batch does not match tensor batch"); } if (shape.w != shape_.w) { - return InvalidArgumentError("Shape width does not match tensor width"); + return absl::InvalidArgumentError( + "Shape width does not match tensor width"); } if (shape.h != shape_.h) { - return InvalidArgumentError("Shape height does not match tensor height"); + return absl::InvalidArgumentError( + "Shape height does not match tensor height"); } if (shape.c != shape_.c) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Shape channels does not match tensor channels"); } - return OkStatus(); + return absl::OkStatus(); } -Status Tensor::IsValid(const BHWDC& shape) const { +absl::Status Tensor::IsValid(const BHWDC& shape) const { if (shape.b != shape_.b) { - return InvalidArgumentError("Shape batch does not match tensor batch"); + return absl::InvalidArgumentError( + "Shape batch does not match tensor batch"); } if (shape.w != shape_.w) { - return InvalidArgumentError("Shape width does not match tensor width"); + return absl::InvalidArgumentError( + "Shape width does not match tensor width"); } if (shape.h != shape_.h) { - return InvalidArgumentError("Shape height does not match tensor height"); + return absl::InvalidArgumentError( + "Shape 
height does not match tensor height"); } if (shape.d != shape_.d) { - return InvalidArgumentError("Shape depth does not match tensor depth"); + return absl::InvalidArgumentError( + "Shape depth does not match tensor depth"); } if (shape.c != shape_.c) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Shape channels does not match tensor channels"); } - return OkStatus(); + return absl::OkStatus(); } int Tensor::GetChannelsAlignment() const { @@ -230,8 +240,8 @@ cl_mem Tensor::GetMemoryPtr() const { cl_mem Tensor::GetMemoryPtrForWriting() const { return memory_; } -Status Tensor::WriteDataBHWDC(absl::Span in, - CLCommandQueue* queue) { +absl::Status Tensor::WriteDataBHWDC(absl::Span in, + CLCommandQueue* queue) { void* data_ptr = nullptr; const int aligned_channels = GetAlignedChannels(); const int elements_count = @@ -263,24 +273,26 @@ Status Tensor::WriteDataBHWDC(absl::Span in, queue->EnqueueWriteImage(memory_, GetFullTensorRegion(), data_ptr)); break; default: - return InternalError("Unsupported tensor storage type"); + return absl::InternalError("Unsupported tensor storage type"); } - return OkStatus(); + return absl::OkStatus(); } -Status Tensor::WriteData(CLCommandQueue* queue, const TensorFloat32& src) { +absl::Status Tensor::WriteData(CLCommandQueue* queue, + const TensorFloat32& src) { RETURN_IF_ERROR(IsValid(src.shape)); return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue); } -Status Tensor::WriteData(CLCommandQueue* queue, const Tensor5DFloat32& src) { +absl::Status Tensor::WriteData(CLCommandQueue* queue, + const Tensor5DFloat32& src) { RETURN_IF_ERROR(IsValid(src.shape)); return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue); } -Status Tensor::ReadDataBHWDC(absl::Span out, - CLCommandQueue* queue) const { +absl::Status Tensor::ReadDataBHWDC(absl::Span out, + CLCommandQueue* queue) const { void* data_ptr = nullptr; const int aligned_channels = GetAlignedChannels(); const int elements_count = @@ -309,7 +321,7 @@ Status Tensor::ReadDataBHWDC(absl::Span out, queue->EnqueueReadImage(memory_, GetFullTensorRegion(), data_ptr)); break; default: - return InternalError("Unsupported tensor storage type"); + return absl::InternalError("Unsupported tensor storage type"); } if (descriptor_.data_type == DataType::FLOAT32) { @@ -318,57 +330,62 @@ Status Tensor::ReadDataBHWDC(absl::Span out, DataToBHWDC(absl::MakeConstSpan(data_h.data(), data_h.size()), out); } - return OkStatus(); + return absl::OkStatus(); } -Status Tensor::ReadData(CLCommandQueue* queue, TensorFloat32* dst) const { +absl::Status Tensor::ReadData(CLCommandQueue* queue, TensorFloat32* dst) const { RETURN_IF_ERROR(IsValid(dst->shape)); return ReadDataBHWDC(absl::MakeSpan(dst->data), queue); } -Status Tensor::ReadData(CLCommandQueue* queue, Tensor5DFloat32* dst) const { +absl::Status Tensor::ReadData(CLCommandQueue* queue, + Tensor5DFloat32* dst) const { RETURN_IF_ERROR(IsValid(dst->shape)); return ReadDataBHWDC(absl::MakeSpan(dst->data), queue); } -Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWC& shape, const TensorDescriptor& descriptor, - Tensor* result) { +absl::Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWC& shape, const TensorDescriptor& descriptor, + Tensor* result) { const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c); return CreateTensor(context, device, shape5D, descriptor, nullptr, result); } -Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWDC& shape, const TensorDescriptor& 
descriptor, - Tensor* result) { +absl::Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWDC& shape, + const TensorDescriptor& descriptor, Tensor* result) { return CreateTensor(context, device, shape, descriptor, nullptr, result); } -Status CreateSharedTensor(const CLContext& context, const CLDevice& device, - cl_mem memory, const BHWC& shape, - const TensorDescriptor& descriptor, Tensor* result) { +absl::Status CreateSharedTensor(const CLContext& context, + const CLDevice& device, cl_mem memory, + const BHWC& shape, + const TensorDescriptor& descriptor, + Tensor* result) { const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c); return CreateTensor(context, device, shape5D, descriptor, memory, result); } -Status CreateSharedTensor(const CLContext& context, const CLDevice& device, - cl_mem memory, const BHWDC& shape, - const TensorDescriptor& descriptor, Tensor* result) { +absl::Status CreateSharedTensor(const CLContext& context, + const CLDevice& device, cl_mem memory, + const BHWDC& shape, + const TensorDescriptor& descriptor, + Tensor* result) { return CreateTensor(context, device, shape, descriptor, memory, result); } -Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, - const BHWC& shape, - const TensorDescriptor& descriptor, - CLMemory* result) { +absl::Status AllocateTensorMemory(const CLContext& context, + const CLDevice& device, const BHWC& shape, + const TensorDescriptor& descriptor, + CLMemory* result) { const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c); return AllocateTensorMemory(context, device, shape5D, descriptor, result); } -Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, - const BHWDC& shape, - const TensorDescriptor& descriptor, - CLMemory* result) { +absl::Status AllocateTensorMemory(const CLContext& context, + const CLDevice& device, const BHWDC& shape, + const TensorDescriptor& descriptor, + CLMemory* result) { const int slices = IntegralDivideRoundUp(shape.c, 4); switch (descriptor.storage_type) { case TensorStorageType::BUFFER: @@ -379,12 +396,12 @@ Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, cl_mem memory = clCreateBuffer(context.context(), CL_MEM_READ_WRITE, data_size, nullptr, &error_code); if (!memory) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to allocate device memory with clCreateBuffer", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return OkStatus(); + return absl::OkStatus(); } case TensorStorageType::TEXTURE_2D: { cl_image_desc desc; @@ -406,13 +423,13 @@ Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, cl_mem memory = CreateImage2DLegacy(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to create Texture2D (clCreateImage)", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return OkStatus(); + return absl::OkStatus(); } case TensorStorageType::TEXTURE_3D: { cl_image_desc desc; @@ -434,13 +451,13 @@ Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, cl_mem memory = CreateImage3DLegacy(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to create Texture3D (clCreateImage)", CLErrorCodeToString(error_code))); } *result = 
CLMemory(memory, true); - return OkStatus(); + return absl::OkStatus(); } case TensorStorageType::TEXTURE_ARRAY: { cl_image_desc desc; @@ -463,18 +480,18 @@ Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, cl_mem memory = clCreateImage(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to create TextureArray (clCreateImage)", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return OkStatus(); + return absl::OkStatus(); } case TensorStorageType::SINGLE_TEXTURE_2D: { if (slices != 1) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "SINGLE_TEXTURE_2D support only channels in range [1-4], but ", shape.c, "was provided")); } @@ -495,7 +512,7 @@ Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, format.image_channel_data_type = ToImageChannelType(descriptor.data_type); } else { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "This device doesn't support ", shape.c, "-channel textures.")); } @@ -503,17 +520,17 @@ Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, cl_mem memory = CreateImage2DLegacy(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to create Texture2D (clCreateImage)", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return OkStatus(); + return absl::OkStatus(); } default: - return InternalError("Unsupported tensor storage type"); + return absl::InternalError("Unsupported tensor storage type"); } } diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h index 34a45436386..a27c54a74e5 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.h +++ b/tensorflow/lite/delegates/gpu/cl/tensor.h @@ -87,20 +87,22 @@ class Tensor { // memory ptr. 
cl_mem GetMemoryPtrForWriting() const; - Status WriteData(CLCommandQueue* queue, const TensorFloat32& src); - Status WriteData(CLCommandQueue* queue, const Tensor5DFloat32& src); - Status ReadData(CLCommandQueue* queue, TensorFloat32* dst) const; - Status ReadData(CLCommandQueue* queue, Tensor5DFloat32* dst) const; + absl::Status WriteData(CLCommandQueue* queue, const TensorFloat32& src); + absl::Status WriteData(CLCommandQueue* queue, const Tensor5DFloat32& src); + absl::Status ReadData(CLCommandQueue* queue, TensorFloat32* dst) const; + absl::Status ReadData(CLCommandQueue* queue, Tensor5DFloat32* dst) const; private: - Status IsValid(const BHWC& shape) const; - Status IsValid(const BHWDC& shape) const; + absl::Status IsValid(const BHWC& shape) const; + absl::Status IsValid(const BHWDC& shape) const; int GetChannelsAlignment() const; int GetAlignedChannels() const; - Status WriteDataBHWDC(absl::Span in, CLCommandQueue* queue); - Status ReadDataBHWDC(absl::Span out, CLCommandQueue* queue) const; + absl::Status WriteDataBHWDC(absl::Span in, + CLCommandQueue* queue); + absl::Status ReadDataBHWDC(absl::Span out, + CLCommandQueue* queue) const; template void DataFromBHWDC(absl::Span src, absl::Span dst) const; @@ -145,31 +147,35 @@ class Tensor { using TensorPtr = std::shared_ptr; -Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, - const BHWC& shape, - const TensorDescriptor& descriptor, - CLMemory* result); +absl::Status AllocateTensorMemory(const CLContext& context, + const CLDevice& device, const BHWC& shape, + const TensorDescriptor& descriptor, + CLMemory* result); -Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, - const BHWDC& shape, - const TensorDescriptor& descriptor, - CLMemory* result); +absl::Status AllocateTensorMemory(const CLContext& context, + const CLDevice& device, const BHWDC& shape, + const TensorDescriptor& descriptor, + CLMemory* result); -Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWC& shape, const TensorDescriptor& descriptor, - Tensor* result); +absl::Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWC& shape, const TensorDescriptor& descriptor, + Tensor* result); -Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWDC& shape, const TensorDescriptor& descriptor, - Tensor* result); - -Status CreateSharedTensor(const CLContext& context, const CLDevice& device, - cl_mem memory, const BHWC& shape, +absl::Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWDC& shape, const TensorDescriptor& descriptor, Tensor* result); -Status CreateSharedTensor(const CLContext& context, const CLDevice& device, - cl_mem memory, const BHWDC& shape, - const TensorDescriptor& descriptor, Tensor* result); +absl::Status CreateSharedTensor(const CLContext& context, + const CLDevice& device, cl_mem memory, + const BHWC& shape, + const TensorDescriptor& descriptor, + Tensor* result); + +absl::Status CreateSharedTensor(const CLContext& context, + const CLDevice& device, cl_mem memory, + const BHWDC& shape, + const TensorDescriptor& descriptor, + Tensor* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_test.cc b/tensorflow/lite/delegates/gpu/cl/tensor_test.cc index 7c859c43e6e..99ba269cf60 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor_test.cc @@ -30,8 +30,9 @@ namespace gpu { namespace cl { namespace { -Status 
TensorGenericTest(const BHWC& shape, const TensorDescriptor& descriptor, - Environment* env) { +absl::Status TensorGenericTest(const BHWC& shape, + const TensorDescriptor& descriptor, + Environment* env) { TensorFloat32 tensor_cpu; tensor_cpu.shape = shape; tensor_cpu.data.resize(shape.DimensionsProduct()); @@ -53,15 +54,15 @@ Status TensorGenericTest(const BHWC& shape, const TensorDescriptor& descriptor, for (int i = 0; i < tensor_gpu.data.size(); ++i) { if (tensor_gpu.data[i] != tensor_cpu.data[i]) { - return InternalError("Wrong value."); + return absl::InternalError("Wrong value."); } } - return OkStatus(); + return absl::OkStatus(); } -Status Tensor5DGenericTest(const BHWDC& shape, - const TensorDescriptor& descriptor, - Environment* env) { +absl::Status Tensor5DGenericTest(const BHWDC& shape, + const TensorDescriptor& descriptor, + Environment* env) { Tensor5DFloat32 tensor_cpu; tensor_cpu.shape = shape; tensor_cpu.data.resize(shape.DimensionsProduct()); @@ -83,14 +84,14 @@ Status Tensor5DGenericTest(const BHWDC& shape, for (int i = 0; i < tensor_gpu.data.size(); ++i) { if (tensor_gpu.data[i] != tensor_cpu.data[i]) { - return InternalError("Wrong value."); + return absl::InternalError("Wrong value."); } } - return OkStatus(); + return absl::OkStatus(); } -Status TensorTests(DataType data_type, TensorStorageType storage_type, - Environment* env) { +absl::Status TensorTests(DataType data_type, TensorStorageType storage_type, + Environment* env) { RETURN_IF_ERROR(TensorGenericTest( BHWC(1, 6, 7, 3), {data_type, storage_type, Layout::HWC}, env)); RETURN_IF_ERROR(TensorGenericTest( @@ -125,7 +126,7 @@ Status TensorTests(DataType data_type, TensorStorageType storage_type, BHWDC(7, 6, 1, 3, 7), {data_type, storage_type, Layout::BHWDC}, env)); RETURN_IF_ERROR(Tensor5DGenericTest( BHWDC(13, 7, 3, 4, 3), {data_type, storage_type, Layout::BHWDC}, env)); - return OkStatus(); + return absl::OkStatus(); } TEST_F(OpenCLTest, BufferF32) { diff --git a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc index f231cf3143a..151924197c2 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc +++ b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc @@ -45,10 +45,11 @@ class DelegateContext { const TfLiteDelegateParams* delegate_params) { auto denormalized_graph = reinterpret_cast(delegate_params->delegate->data_); - Status status = BuildModel(context, delegate_params, denormalized_graph); + absl::Status status = + BuildModel(context, delegate_params, denormalized_graph); if (!status.ok()) { context->ReportError(context, "Failed to convert a model: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); } return status.ok(); } @@ -82,14 +83,14 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { return status; } -Status FlatBufferToGPUGraph( +absl::Status FlatBufferToGPUGraph( const std::unique_ptr& flatbuffer, GraphFloat32* graph) { tflite::ops::builtin::BuiltinOpResolver op_resolver; std::unique_ptr interpreter; tflite::InterpreterBuilder interpreter_builder(*flatbuffer, op_resolver); if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) { - return InternalError("Unable to prepare TfLite interpreter."); + return absl::InternalError("Unable to prepare TfLite interpreter."); } interpreter->UseNNAPI(false); TfLiteDelegate delegate; @@ -101,20 +102,20 @@ Status FlatBufferToGPUGraph( delegate.FreeBufferHandle = nullptr; 
if (interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) { - return InternalError("Conversion from TfLite model failed."); + return absl::InternalError("Conversion from TfLite model failed."); } NullTransformationReporter reporter; ModelTransformer transformer(graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return InternalError("Graph general transformations failed"); + return absl::InternalError("Graph general transformations failed"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status RunModelSample(const std::string& model_name) { +absl::Status RunModelSample(const std::string& model_name) { auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str()); GraphFloat32 graph_cl; RETURN_IF_ERROR(FlatBufferToGPUGraph(flatbuffer, &graph_cl)); @@ -160,7 +161,7 @@ Status RunModelSample(const std::string& model_name) { std::cout << "Total time - " << average_inference_time << "ms" << std::endl; } - return OkStatus(); + return absl::OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/texture2d.cc b/tensorflow/lite/delegates/gpu/cl/texture2d.cc index 907721dad8c..022c15660ce 100644 --- a/tensorflow/lite/delegates/gpu/cl/texture2d.cc +++ b/tensorflow/lite/delegates/gpu/cl/texture2d.cc @@ -21,8 +21,9 @@ namespace cl { namespace { // Creates new 4-channel 2D texture with cl_channel_type elements -Status CreateTexture2D(int width, int height, cl_channel_type type, void* data, - CLContext* context, Texture2D* result) { +absl::Status CreateTexture2D(int width, int height, cl_channel_type type, + void* data, CLContext* context, + Texture2D* result) { cl_image_desc desc; desc.image_type = CL_MEM_OBJECT_IMAGE2D; desc.image_width = width; @@ -47,14 +48,14 @@ Status CreateTexture2D(int width, int height, cl_channel_type type, void* data, cl_mem texture = CreateImage2DLegacy(context->context(), flags, &format, &desc, data, &error_code); if (error_code != CL_SUCCESS) { - return UnknownError( + return absl::UnknownError( absl::StrCat("Failed to create Texture2D (clCreateImage)", CLErrorCodeToString(error_code))); } *result = Texture2D(texture, width, height, type); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -95,20 +96,20 @@ void Texture2D::Release() { } // Creates new 4-channel 2D texture with f32 elements -Status CreateTexture2DRGBA32F(int width, int height, CLContext* context, - Texture2D* result) { +absl::Status CreateTexture2DRGBA32F(int width, int height, CLContext* context, + Texture2D* result) { return CreateTexture2D(width, height, CL_FLOAT, nullptr, context, result); } // Creates new 4-channel 2D texture with f16 elements -Status CreateTexture2DRGBA16F(int width, int height, CLContext* context, - Texture2D* result) { +absl::Status CreateTexture2DRGBA16F(int width, int height, CLContext* context, + Texture2D* result) { return CreateTexture2D(width, height, CL_HALF_FLOAT, nullptr, context, result); } -Status CreateTexture2DRGBA(DataType type, int width, int height, - CLContext* context, Texture2D* result) { +absl::Status CreateTexture2DRGBA(DataType type, int width, int height, + CLContext* context, Texture2D* result) { if (type == DataType::FLOAT32) { return CreateTexture2D(width, height, CL_FLOAT, nullptr, context, result); } else { @@ -117,8 +118,9 @@ Status CreateTexture2DRGBA(DataType type, int width, int height, } } -Status CreateTexture2DRGBA(DataType type, int width, int height, void* data, - CLContext* context, Texture2D* result) { +absl::Status CreateTexture2DRGBA(DataType 
type, int width, int height, + void* data, CLContext* context, + Texture2D* result) { if (type == DataType::FLOAT32) { return CreateTexture2D(width, height, CL_FLOAT, data, context, result); } else { diff --git a/tensorflow/lite/delegates/gpu/cl/texture2d.h b/tensorflow/lite/delegates/gpu/cl/texture2d.h index bdac984a2db..c12d8a2836c 100644 --- a/tensorflow/lite/delegates/gpu/cl/texture2d.h +++ b/tensorflow/lite/delegates/gpu/cl/texture2d.h @@ -50,11 +50,11 @@ class Texture2D { // Writes data to a texture. Data should point to a region that // has exact width * height * sizeof(pixel) bytes. template - Status WriteData(CLCommandQueue* queue, const absl::Span data); + absl::Status WriteData(CLCommandQueue* queue, const absl::Span data); // Reads data from Texture2D into CPU memory. template - Status ReadData(CLCommandQueue* queue, std::vector* result) const; + absl::Status ReadData(CLCommandQueue* queue, std::vector* result) const; private: void Release(); @@ -68,43 +68,45 @@ class Texture2D { using Texture2DPtr = std::shared_ptr; // Creates new 4-channel 2D texture with f32 elements -Status CreateTexture2DRGBA32F(int width, int height, CLContext* context, - Texture2D* result); +absl::Status CreateTexture2DRGBA32F(int width, int height, CLContext* context, + Texture2D* result); // Creates new 4-channel 2D texture with f16 elements -Status CreateTexture2DRGBA16F(int width, int height, CLContext* context, - Texture2D* result); +absl::Status CreateTexture2DRGBA16F(int width, int height, CLContext* context, + Texture2D* result); -Status CreateTexture2DRGBA(DataType type, int width, int height, - CLContext* context, Texture2D* result); +absl::Status CreateTexture2DRGBA(DataType type, int width, int height, + CLContext* context, Texture2D* result); -Status CreateTexture2DRGBA(DataType type, int width, int height, void* data, - CLContext* context, Texture2D* result); +absl::Status CreateTexture2DRGBA(DataType type, int width, int height, + void* data, CLContext* context, + Texture2D* result); template -Status Texture2D::WriteData(CLCommandQueue* queue, const absl::Span data) { +absl::Status Texture2D::WriteData(CLCommandQueue* queue, + const absl::Span data) { const int element_size = ChannelTypeToSizeInBytes(channel_type_); if (sizeof(T) % element_size != 0) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Template type T has not suitable element type for created texture."); } if (4 * width_ * height_ * element_size != data.size() * sizeof(T)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "absl::Span data size is different from texture allocated size."); } RETURN_IF_ERROR(queue->EnqueueWriteImage(texture_, int3(width_, height_, 1), data.data())); - return OkStatus(); + return absl::OkStatus(); } template -Status Texture2D::ReadData(CLCommandQueue* queue, - std::vector* result) const { +absl::Status Texture2D::ReadData(CLCommandQueue* queue, + std::vector* result) const { const int element_size = ChannelTypeToSizeInBytes(channel_type_); if (sizeof(T) != element_size) { - return InvalidArgumentError("Pixel format is different."); + return absl::InvalidArgumentError("Pixel format is different."); } const int elements_count = width_ * height_ * 4; diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 08612e37b3e..30ac016ff83 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -24,8 +24,8 @@ cc_library( srcs = ["custom_parsers.cc"], hdrs = 
["custom_parsers.h"], deps = [ - "//tensorflow/lite/delegates/gpu/common:shape", - "//tensorflow/lite/delegates/gpu/common:status", + ":shape", + ":status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", "@flatbuffers", @@ -193,6 +193,7 @@ cc_test( cc_library( name = "status", hdrs = ["status.h"], + deps = ["@com_google_absl//absl/status"], ) cc_library( diff --git a/tensorflow/lite/delegates/gpu/common/convert.cc b/tensorflow/lite/delegates/gpu/common/convert.cc index 81d09b2797e..cee2e8f0e60 100644 --- a/tensorflow/lite/delegates/gpu/common/convert.cc +++ b/tensorflow/lite/delegates/gpu/common/convert.cc @@ -30,15 +30,15 @@ constexpr int kPhwo4i4ChannelsInPlane = 4; constexpr int kPiohw4ChannelsInPlane = 4; // Layout is Po,H,W,OI4x4. -Status ConvertToPHWO4I4(absl::Span in, const OHWI& shape, - absl::Span out, bool reverse_space) { +absl::Status ConvertToPHWO4I4(absl::Span in, const OHWI& shape, + absl::Span out, bool reverse_space) { if (in.size() != shape.DimensionsProduct()) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertToPHWO4I4: Input data size does not match expected size: ", in.size(), " != ", shape.DimensionsProduct())); } if (out.size() != GetElementsSizeForPHWO4I4(shape)) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertToPHWO4I4: Output data size does not match expected size: ", out.size(), " != ", GetElementsSizeForPHWO4I4(shape))); } @@ -69,7 +69,7 @@ Status ConvertToPHWO4I4(absl::Span in, const OHWI& shape, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -110,15 +110,15 @@ uint3 Get3DSizeForPHWO4I4(const OHWI& shape) { } // Layout is Po,H,W,OI4x4. -Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, - absl::Span out) { +absl::Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, + absl::Span out) { if (in.size() != shape.DimensionsProduct()) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertToPHWO4I4: Input data size does not match expected size: ", in.size(), " != ", shape.DimensionsProduct())); } if (out.size() != GetElementsSizeForPHWO4I4(shape)) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertToPHWO4I4: Output data size does not match expected size: ", out.size(), " != ", GetElementsSizeForPHWO4I4(shape))); } @@ -147,7 +147,7 @@ Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, } } } - return OkStatus(); + return absl::OkStatus(); } std::vector ConvertToPHWO4I4( @@ -164,15 +164,15 @@ uint32_t GetElementsSizeForPIOHW4(const OHWI& shape) { shape.w; } -Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, - absl::Span out) { +absl::Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, + absl::Span out) { if (in.size() != shape.DimensionsProduct()) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertToPIOHW4: Input data size does not match expected size: ", in.size(), " != ", shape.DimensionsProduct())); } if (out.size() != GetElementsSizeForPIOHW4(shape)) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertToPIOHW4: Output data size does not match expected size: ", out.size(), " != ", GetElementsSizeForPIOHW4(shape))); } @@ -194,7 +194,7 @@ Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, } } } - return OkStatus(); + return absl::OkStatus(); } std::vector ConvertToPIOHW4( 
@@ -207,29 +207,29 @@ std::vector ConvertToPIOHW4( } template -Status ValidateConvertToPHWC4(absl::Span in, const BHWC& shape, - absl::Span out) { +absl::Status ValidateConvertToPHWC4(absl::Span in, + const BHWC& shape, absl::Span out) { if (in.size() != shape.DimensionsProduct()) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertToPHWC4: Input data size does not match expected size: ", in.size(), " != ", shape.DimensionsProduct())); } if (out.size() != GetElementsSizeForPHWC4(shape)) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertToPHWC4: Output data size does not match expected size: ", out.size(), " != ", GetElementsSizeForPHWC4(shape))); } - return OkStatus(); + return absl::OkStatus(); } // Layout is Pc,H,W,C4 where P - is a plane based on channels. -Status ConvertToPHWC4(absl::Span in, const BHWC& shape, - absl::Span out) { +absl::Status ConvertToPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { RETURN_IF_ERROR(ValidateConvertToPHWC4(in, shape, out)); if (shape.c == 4) { std::memcpy(out.data(), in.data(), shape.DimensionsProduct() * sizeof(float)); - return OkStatus(); + return absl::OkStatus(); } // Layout is Pc,H,W,C4 where P - is a plane based on channels. int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); @@ -256,7 +256,7 @@ Status ConvertToPHWC4(absl::Span in, const BHWC& shape, const int remaining_channels = shape.c - num_full_planes * kPhwc4ChannelsInPlane; if (remaining_channels == 0) { - return OkStatus(); + return absl::OkStatus(); } for (int b = 0; b < shape.b; b++) { const float* src = @@ -272,12 +272,12 @@ Status ConvertToPHWC4(absl::Span in, const BHWC& shape, dest += kPhwc4ChannelsInPlane; } } - return OkStatus(); + return absl::OkStatus(); } // Layout is Pc,H,W,C4 where P - is a plane based on channels. -Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, - absl::Span out) { +absl::Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, + absl::Span out) { RETURN_IF_ERROR(ValidateConvertToPHWC4(in, shape, out)); // Layout is Pc,H,W,C4 where P - is a plane based on channels. 
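Aside (illustrative sketch, not part of the patch): the ConvertToPHWC4 hunks above rely on the plane-based layout described in their comments — channels are zero-padded up to a multiple of 4 and split into planes of H*W*4 values, with batch outermost. The helper name PhwcIndex and the exact ordering below are inferred from those loops, not taken from the diff.

#include <cstddef>

// Maps a BHWC coordinate (b, h, w, c) to its position in a PHWC4 buffer,
// assuming the ordering implied by ConvertToPHWC4: batch, then 4-channel
// plane, then row, then column, then channel-within-plane.
std::size_t PhwcIndex(int b, int h, int w, int c, int H, int W, int C) {
  constexpr int kChannelsPerPlane = 4;
  const int planes = (C + kChannelsPerPlane - 1) / kChannelsPerPlane;
  const int plane = c / kChannelsPerPlane;       // which 4-channel plane
  const int c_in_plane = c % kChannelsPerPlane;  // offset inside the plane
  return static_cast<std::size_t>(((b * planes + plane) * H + h) * W + w) *
             kChannelsPerPlane +
         c_in_plane;
}
// Channels beyond C in the last plane are zero-filled by the conversion.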
@@ -308,7 +308,7 @@ Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, const int remaining_channels = shape.c - num_full_planes * kPhwc4ChannelsInPlane; if (remaining_channels == 0) { - return OkStatus(); + return absl::OkStatus(); } for (int b = 0; b < shape.b; b++) { @@ -349,11 +349,11 @@ Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, } break; default: - return UnimplementedError( + return absl::UnimplementedError( "ConvertToPHWC4Half: Unsupported channels per planes count."); } } - return OkStatus(); + return absl::OkStatus(); } std::vector ConvertToPHWC4( @@ -383,28 +383,28 @@ uint32_t GetElementsSizeForPHWC4(const BHWC& shape) { } template -Status ValidateConvertFromPHWC4(absl::Span in, const BHWC& shape, - absl::Span out) { +absl::Status ValidateConvertFromPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { if (in.size() != GetElementsSizeForPHWC4(shape)) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertFromPHWC4: Input data size does not match expected size: ", in.size(), " != ", GetElementsSizeForPHWC4(shape))); } if (out.size() != shape.DimensionsProduct()) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ConvertFromPHWC4: Output data size does not match expected size: ", out.size(), " != ", shape.DimensionsProduct())); } - return OkStatus(); + return absl::OkStatus(); } -Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, - absl::Span out) { +absl::Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { RETURN_IF_ERROR(ValidateConvertFromPHWC4(in, shape, out)); if (shape.c == 4) { std::memcpy(out.data(), in.data(), shape.DimensionsProduct() * sizeof(float)); - return OkStatus(); + return absl::OkStatus(); } int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); @@ -429,7 +429,7 @@ Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, const int remaining_channels = shape.c - num_full_planes * kPhwc4ChannelsInPlane; if (remaining_channels == 0) { - return OkStatus(); + return absl::OkStatus(); } for (int b = 0; b < shape.b; b++) { const float* src = in.data() + b * padded_size + @@ -443,11 +443,11 @@ Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, dest += shape.c; } } - return OkStatus(); + return absl::OkStatus(); } -Status ConvertFromPHWC4Half(absl::Span in, const BHWC& shape, - absl::Span out) { +absl::Status ConvertFromPHWC4Half(absl::Span in, + const BHWC& shape, absl::Span out) { RETURN_IF_ERROR(ValidateConvertFromPHWC4(in, shape, out)); int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); const int num_pixels = shape.h * shape.w; @@ -474,7 +474,7 @@ Status ConvertFromPHWC4Half(absl::Span in, const BHWC& shape, const int remaining_channels = shape.c - num_full_planes * kPhwc4ChannelsInPlane; if (remaining_channels == 0) { - return OkStatus(); + return absl::OkStatus(); } for (int b = 0; b < shape.b; b++) { const HalfBits* src = in.data() + b * padded_size + @@ -508,11 +508,11 @@ Status ConvertFromPHWC4Half(absl::Span in, const BHWC& shape, } break; default: - return UnimplementedError( + return absl::UnimplementedError( "ConvertToPHWC4Half: Unsupported channels per planes count."); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/convert.h b/tensorflow/lite/delegates/gpu/common/convert.h index 30a0a5f3183..3aba9c913c5 100644 --- a/tensorflow/lite/delegates/gpu/common/convert.h +++ 
b/tensorflow/lite/delegates/gpu/common/convert.h @@ -29,19 +29,19 @@ namespace gpu { // PHWC4 layout is where channels are grouped by 4 in a row and P stands for // a plane that was derived by dividing channels by 4. -Status ConvertToPHWC4(absl::Span in, const BHWC& shape, - absl::Span out); -Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, - absl::Span out); +absl::Status ConvertToPHWC4(absl::Span in, const BHWC& shape, + absl::Span out); +absl::Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, + absl::Span out); // @return number of elements when shape is converted into PHWC4. uint32_t GetElementsSizeForPHWC4(const BHWC& shape); // Operation is opposite to ConvertToPHWC4. -Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, - absl::Span out); -Status ConvertFromPHWC4Half(absl::Span in, const BHWC& shape, - absl::Span out); +absl::Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, + absl::Span out); +absl::Status ConvertFromPHWC4Half(absl::Span in, + const BHWC& shape, absl::Span out); // Convenience wrapper around a method above. std::vector ConvertToPHWC4( @@ -53,8 +53,8 @@ uint32_t GetElementsSizeForPIOHW4(const OHWI& shape); // PIOHW4 layout re-arranges weights in groups by 4, where outer dimension is // P which is OxI/4. -Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, - absl::Span out); +absl::Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, + absl::Span out); // Convenience wrapper around a method above. std::vector ConvertToPIOHW4( @@ -79,8 +79,8 @@ uint3 Get3DSizeForPHWO4I4(const OHWI& shape); uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape); // Layout is Po,H,W,OI4x4. -Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, - absl::Span out); +absl::Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, + absl::Span out); // Convenience wrapper around a method above. std::vector ConvertToPHWO4I4( diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.cc b/tensorflow/lite/delegates/gpu/common/custom_parsers.cc index d46a9247c81..e43cba05525 100644 --- a/tensorflow/lite/delegates/gpu/common/custom_parsers.cc +++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.cc @@ -25,10 +25,10 @@ limitations under the License. namespace tflite { namespace gpu { -Status ParseCustomAttributes(absl::string_view op_name, const void* data, - uint32_t data_size, absl::any* attr, - BHWC* output_shape) { - return UnimplementedError(absl::StrCat( +absl::Status ParseCustomAttributes(absl::string_view op_name, const void* data, + uint32_t data_size, absl::any* attr, + BHWC* output_shape) { + return absl::UnimplementedError(absl::StrCat( "Attributes parsing is not enabled for ", op_name, " operation")); } diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.h b/tensorflow/lite/delegates/gpu/common/custom_parsers.h index e9a191d46cb..707087e6fdb 100644 --- a/tensorflow/lite/delegates/gpu/common/custom_parsers.h +++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.h @@ -27,9 +27,9 @@ namespace gpu { // Matches the custom operation by the string name and parses attributes stored // as flexbuffers. 
-Status ParseCustomAttributes(absl::string_view op_name, const void* data, - uint32_t data_size, absl::any* attr, - BHWC* output_shape); +absl::Status ParseCustomAttributes(absl::string_view op_name, const void* data, + uint32_t data_size, absl::any* attr, + BHWC* output_shape); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.cc b/tensorflow/lite/delegates/gpu/common/memory_management.cc index 5cfd26b1832..d7e6a060eb2 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management.cc @@ -55,8 +55,9 @@ OffsetsAssignment ObjectsToOffsets( return result; } -Status BestGreedy(const std::vector>& usage_records, - ObjectsAssignment* assignment) { +absl::Status BestGreedy( + const std::vector>& usage_records, + ObjectsAssignment* assignment) { RETURN_IF_ERROR( GreedyBySizeDistPriorityAssignment(usage_records, assignment)); ObjectsAssignment assignment_by_breadth; @@ -64,11 +65,11 @@ Status BestGreedy(const std::vector>& usage_records, TotalSize(assignment_by_breadth) < TotalSize(*assignment)) { std::swap(*assignment, assignment_by_breadth); } - return OkStatus(); + return absl::OkStatus(); } template <> -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -89,14 +90,14 @@ Status AssignObjectsToTensors( case MemoryStrategy::MINCOSTFLOW: return MinCostFlowAssignment(usage_records, assignment); default: - return InternalError( + return absl::InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return OkStatus(); + return absl::OkStatus(); } template <> -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -106,14 +107,14 @@ Status AssignObjectsToTensors( case MemoryStrategy::EQUALITY: return EqualityAssignmentWithHash(usage_records, assignment); default: - return InternalError( + return absl::InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return OkStatus(); + return absl::OkStatus(); } template <> -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -125,14 +126,14 @@ Status AssignObjectsToTensors( case MemoryStrategy::GREEDY_IN_ORDER: return GreedyInOrderAssignmentMultidimensional(usage_records, assignment); default: - return InternalError( + return absl::InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return OkStatus(); + return absl::OkStatus(); } template <> -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -144,13 +145,13 @@ Status AssignObjectsToTensors( case MemoryStrategy::GREEDY_IN_ORDER: return GreedyInOrderAssignmentMultidimensional(usage_records, assignment); default: - return InternalError( + return absl::InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return OkStatus(); + return absl::OkStatus(); } -Status AssignOffsetsToTensors( +absl::Status AssignOffsetsToTensors( const std::vector>& usage_records, const 
MemoryStrategy& strategy, OffsetsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -161,7 +162,7 @@ Status AssignOffsetsToTensors( RETURN_IF_ERROR(AssignObjectsToTensors( usage_records, strategy, &objects_assignment, reallocation_graph)); *assignment = ObjectsToOffsets(objects_assignment); - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.h b/tensorflow/lite/delegates/gpu/common/memory_management.h index e45c361d955..7df4947ee3d 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management.h @@ -79,8 +79,9 @@ enum class MemoryStrategy { // Chooses greedy algorithm with the lowest memory consumption for given usage // records and returns corresponding shared objects assignment. -Status BestGreedy(const std::vector>& usage_records, - ObjectsAssignment* assignment); +absl::Status BestGreedy( + const std::vector>& usage_records, + ObjectsAssignment* assignment); // Calculates the assignment of shared objects to given tensors, including // objects' sizes. Below there are specializations for different types, that @@ -90,7 +91,7 @@ Status BestGreedy(const std::vector>& usage_records, // can be larger. Currently only GREEDY_IN_ORDER strategy can use this // reallocation_graph. template -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph = nullptr) { @@ -100,39 +101,39 @@ Status AssignObjectsToTensors( case MemoryStrategy::EQUALITY: return EqualityAssignment(usage_records, assignment); default: - return InternalError( + return absl::InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return OkStatus(); + return absl::OkStatus(); } template <> -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); template <> -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); template <> -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); template <> -Status AssignObjectsToTensors( +absl::Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); // Calculates the assignment of tensors to offsets, considering those tensors // are going to be allocated in one continuous memory block. 
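Aside (illustrative sketch, not part of the patch): the strategies declared above all consume per-tensor usage records (size plus first/last task) and produce an assignment of tensors to shared objects. The structs below are simplified stand-ins for TensorUsageRecord and ObjectsAssignment, and only the naive strategy is shown; names are hypothetical.

#include <cstddef>
#include <vector>

struct UsageRecordSketch {   // simplified stand-in for a tensor usage record
  std::size_t tensor_size;
  std::size_t first_task;
  std::size_t last_task;
};

struct AssignmentSketch {    // simplified stand-in for an objects assignment
  std::vector<std::size_t> object_ids;    // object_ids[i]: object of tensor i
  std::vector<std::size_t> object_sizes;  // size of each shared object
};

// Naive strategy: every tensor gets its own object of exactly its size.
AssignmentSketch NaiveAssignSketch(
    const std::vector<UsageRecordSketch>& usage_records) {
  AssignmentSketch result;
  result.object_ids.resize(usage_records.size());
  result.object_sizes.resize(usage_records.size());
  for (std::size_t i = 0; i < usage_records.size(); ++i) {
    result.object_ids[i] = i;
    result.object_sizes[i] = usage_records[i].tensor_size;
  }
  return result;
}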
-Status AssignOffsetsToTensors( +absl::Status AssignOffsetsToTensors( const std::vector>& usage_records, const MemoryStrategy& strategy, OffsetsAssignment* assignment, const UsageGraph* reallocation_graph = nullptr); diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h index 0955393e00c..fdccce5159f 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h @@ -29,7 +29,7 @@ namespace gpu { // Fast version of Equality Assignments for hashable types. template -Status EqualityAssignmentWithHash( +absl::Status EqualityAssignmentWithHash( const std::vector>& usage_records, ObjectsAssignment* assignment) { size_t num_records = usage_records.size(); @@ -69,12 +69,12 @@ Status EqualityAssignmentWithHash( {usage_records[i].last_task, assignment->object_ids[i]}); } } - return OkStatus(); + return absl::OkStatus(); } // Slower version of Equality Assignments for unhashable types. template -Status EqualityAssignment( +absl::Status EqualityAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { size_t num_records = usage_records.size(); @@ -109,7 +109,7 @@ Status EqualityAssignment( dealloc_task[best_obj] = usage_records[i].last_task; } } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc index 5d0f6b620b0..2c138b4c14c 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc @@ -46,7 +46,7 @@ struct TaskBreadthWithId { } // namespace -Status GreedyByBreadthAssignment( +absl::Status GreedyByBreadthAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { std::vector task_profiles = CalculateTaskProfiles(usage_records); @@ -133,10 +133,10 @@ Status GreedyByBreadthAssignment( // In the end all tensors must be assigned to some objects. for (const auto& obj_id : assignment->object_ids) { if (obj_id == kNotAssigned) { - return InternalError("Error while calculating the assignment."); + return absl::InternalError("Error while calculating the assignment."); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h index c139ba0fe0f..47035229920 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h @@ -44,7 +44,7 @@ namespace gpu { // tensor’s size, assign current tensor to the smallest of them; // - If there are suitable objects only with size less than current tensor’s // size, assign current tensor to the largest of them and increase its size. 
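Aside (illustrative sketch, not part of the patch): the per-tensor choice described at the end of the comment above — prefer the smallest free object that already fits, otherwise take the largest one and let the caller grow it — can be stated compactly. The function name and types below are hypothetical; this shows only the selection rule, not the breadth ordering or error handling of GreedyByBreadthAssignment.

#include <cstddef>
#include <vector>

// Given the sizes of currently free shared objects, pick one for a tensor of
// size `needed`: the smallest object with size >= needed, else the largest
// smaller object (to be enlarged by the caller). Returns an index into
// `free_object_sizes`, or npos if the pool is empty.
std::size_t PickObjectSketch(const std::vector<std::size_t>& free_object_sizes,
                             std::size_t needed) {
  constexpr std::size_t npos = static_cast<std::size_t>(-1);
  std::size_t best_fit = npos;  // smallest size >= needed
  std::size_t largest = npos;   // fallback: largest size < needed
  for (std::size_t i = 0; i < free_object_sizes.size(); ++i) {
    const std::size_t s = free_object_sizes[i];
    if (s >= needed) {
      if (best_fit == npos || s < free_object_sizes[best_fit]) best_fit = i;
    } else {
      if (largest == npos || s > free_object_sizes[largest]) largest = i;
    }
  }
  return best_fit != npos ? best_fit : largest;
}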
-Status GreedyByBreadthAssignment( +absl::Status GreedyByBreadthAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment); diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc index bf56c6d92dd..76309ce8f1b 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc @@ -60,7 +60,7 @@ struct SizeDistPriorityInfo { } // namespace -Status GreedyBySizeAssignment( +absl::Status GreedyBySizeAssignment( const std::vector>& usage_records, OffsetsAssignment* assignment) { const size_t num_tensors = usage_records.size(); @@ -104,7 +104,7 @@ Status GreedyBySizeAssignment( prev_offset, cur_offset + usage_records[allocated_id].tensor_size); } if (assignment->total_size < prev_offset) { - return InternalError("Total size is wrong."); + return absl::InternalError("Total size is wrong."); } // If no suitable gap found, we should allocate current tensor after the @@ -125,7 +125,7 @@ Status GreedyBySizeAssignment( assignment->total_size = std::max(assignment->total_size, best_offset + rec->tensor_size); } - return OkStatus(); + return absl::OkStatus(); } // Assigns given tensors to shared objects, using the following greedy @@ -152,7 +152,7 @@ Status GreedyBySizeAssignment( // object with size equal to current tensor's size; // - Modify SizeDistPriority records of tensors, that haven't been assigned yet, // to reflect distance changes after that assignment. -Status GreedyBySizeDistPriorityAssignment( +absl::Status GreedyBySizeDistPriorityAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { std::vector positional_max = @@ -175,7 +175,7 @@ Status GreedyBySizeDistPriorityAssignment( ++pos; } if (pos == 0) { - return InternalError("Variable pos must be positive."); + return absl::InternalError("Variable pos must be positive."); } priority_info[rec_id].position = pos - 1; } @@ -198,7 +198,7 @@ Status GreedyBySizeDistPriorityAssignment( if (best_info_id == kNotAssigned) { // During each iteration we assign exactly one of the tensors, so some not // yet assigned tensors must exist. - return InternalError("Invalid value for variable best_info_id."); + return absl::InternalError("Invalid value for variable best_info_id."); } size_t best_rec_id = priority_info[best_info_id].tensor_usage_id; @@ -271,7 +271,7 @@ Status GreedyBySizeDistPriorityAssignment( } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h index fb875fd0920..b0ad9d18911 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h @@ -38,7 +38,7 @@ namespace gpu { // gap. Otherwise we can allocate it after the rightmost tensor, which usage // interval intersects with usage interval of current tensor. So we assign // corresponding offset to current tensor and the tensor becomes assigned. 
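Aside (illustrative sketch, not part of the patch): the offset search described above places each tensor either into a gap between already-placed tensors whose task intervals overlap its own, or after the rightmost such tensor. The sketch below uses a best-fit (smallest) gap and hypothetical names/types; the real GreedyBySizeAssignment additionally orders tensors by size, tracks the total arena size, and validates the result.

#include <algorithm>
#include <cstddef>
#include <vector>

struct PlacedSketch {  // an already-assigned tensor: bytes [offset, offset+size)
  std::size_t offset;
  std::size_t size;
  std::size_t first_task, last_task;
};

bool OverlapsSketch(const PlacedSketch& a, std::size_t first, std::size_t last) {
  return a.first_task <= last && first <= a.last_task;
}

// Best-fit offset for a tensor of `size` live over tasks [first, last], given
// the tensors placed so far. Gaps are scanned left to right; the smallest gap
// that fits wins, otherwise the tensor goes after the rightmost overlap.
std::size_t BestOffsetSketch(const std::vector<PlacedSketch>& placed,
                             std::size_t size, std::size_t first,
                             std::size_t last) {
  std::vector<PlacedSketch> overlapping;
  for (const auto& p : placed) {
    if (OverlapsSketch(p, first, last)) overlapping.push_back(p);
  }
  std::sort(overlapping.begin(), overlapping.end(),
            [](const PlacedSketch& a, const PlacedSketch& b) {
              return a.offset < b.offset;
            });
  std::size_t best_offset = 0;
  std::size_t best_gap = static_cast<std::size_t>(-1);
  std::size_t prev_end = 0;  // right edge of the region scanned so far
  for (const auto& p : overlapping) {
    if (p.offset >= prev_end && p.offset - prev_end >= size &&
        p.offset - prev_end < best_gap) {
      best_gap = p.offset - prev_end;
      best_offset = prev_end;
    }
    prev_end = std::max(prev_end, p.offset + p.size);
  }
  return best_gap != static_cast<std::size_t>(-1) ? best_offset : prev_end;
}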
-Status GreedyBySizeAssignment( +absl::Status GreedyBySizeAssignment( const std::vector>& usage_records, OffsetsAssignment* assignment); @@ -66,7 +66,7 @@ Status GreedyBySizeAssignment( // object with size equal to current tensor's size; // - Modify SizeDistPriority records of tensors, that haven't been assigned yet, // to reflect distance changes after that assignment. -Status GreedyBySizeDistPriorityAssignment( +absl::Status GreedyBySizeDistPriorityAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment); diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h index b454920ffcb..8c3719e4a8b 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h @@ -46,7 +46,7 @@ namespace gpu { // // 3. Shared object size may increase when tensor requests larger size. template -Status GreedyInOrderAssignment( +absl::Status GreedyInOrderAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph = nullptr) { @@ -111,7 +111,7 @@ Status GreedyInOrderAssignment( } // best_it can't be equal to pool.end(), because pool is not empty if (best_it == pool.end()) { - return InternalError( + return absl::InternalError( "No shared object is found in non-empty pool in " "GreedyInOrderAssignment."); } @@ -135,14 +135,14 @@ Status GreedyInOrderAssignment( {usage_records[i].last_task, assignment->object_ids[i]}); } } - return OkStatus(); + return absl::OkStatus(); } // The same algorithm as above, but for multidimensional case. The only // difference is that shared object dimensions can't be increased to be reused // for tensor, that is larger (at least by one dimension). template -Status GreedyInOrderAssignmentMultidimensional( +absl::Status GreedyInOrderAssignmentMultidimensional( const std::vector>& usage_records, ObjectsAssignment* assignment) { size_t num_records = usage_records.size(); @@ -198,7 +198,7 @@ Status GreedyInOrderAssignmentMultidimensional( {usage_records[i].last_task, assignment->object_ids[i]}); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc index ab15af88429..059c23fab33 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc @@ -211,14 +211,14 @@ class MinCostFlowSolver { // auxiliary flow graph, find minimum-cost flow in it and calculates the // assignment of shared objects to tensors, using the result of the flow // algorithm. 
-Status MinCostFlowAssignment( +absl::Status MinCostFlowAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { MinCostFlowSolver solver; solver.Build(usage_records); solver.Solve(); solver.CalculateAssignment(assignment); - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h index 7e45f83c79e..1284c12c5c2 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h @@ -30,7 +30,7 @@ namespace gpu { // auxiliary flow graph, find minimum-cost flow in it and calculates the // assignment of shared objects to tensors, using the result of the flow // algorithm. -Status MinCostFlowAssignment( +absl::Status MinCostFlowAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment); diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h index 94cd41ed9a5..8a00c67d853 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h @@ -30,7 +30,7 @@ namespace gpu { // The problem of memory management is NP-complete. This implements a // naive algorithm that assigns each tensor to a separate object in memory. template -Status NaiveAssignment( +absl::Status NaiveAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { assignment->object_sizes.resize(usage_records.size()); @@ -40,7 +40,7 @@ Status NaiveAssignment( assignment->object_ids[i] = i; assignment->object_sizes[i] = record.tensor_size; } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/model.h b/tensorflow/lite/delegates/gpu/common/model.h index 6989584a24c..2e38bcc5f3f 100644 --- a/tensorflow/lite/delegates/gpu/common/model.h +++ b/tensorflow/lite/delegates/gpu/common/model.h @@ -136,33 +136,33 @@ class Graph { // for a value. If a value had another producer, it will reassign producer // appropriately. If a value didn't have a producer, it will be removed // from a graph's input. - virtual Status SetProducer(NodeId producer, ValueId value) = 0; + virtual absl::Status SetProducer(NodeId producer, ValueId value) = 0; // Removes a producer for the given value. Value becomes producer-less and // therefore becomes graph's input. - virtual Status RemoveProducer(ValueId value) = 0; + virtual absl::Status RemoveProducer(ValueId value) = 0; // Sets a consumer for the given value. There could be multiple consumers // for a value. - virtual Status AddConsumer(NodeId consumer, ValueId value) = 0; + virtual absl::Status AddConsumer(NodeId consumer, ValueId value) = 0; // Replace input value for given node. - virtual Status ReplaceInput(NodeId node, ValueId old_value, - ValueId new_value) = 0; + virtual absl::Status ReplaceInput(NodeId node, ValueId old_value, + ValueId new_value) = 0; // Removes a consumer for the given value. If value does not have any // consumers it becomes graph's output. - virtual Status RemoveConsumer(NodeId consumer, ValueId value) = 0; + virtual absl::Status RemoveConsumer(NodeId consumer, ValueId value) = 0; // Removes node from this graph. 
For all input values this node will be // removed from consumers and for all output values a producer will be // removed. - virtual Status DeleteNode(NodeId id) = 0; + virtual absl::Status DeleteNode(NodeId id) = 0; // Removes value from this graph. It will be removed from inputs for all // dependent nodes. A node that was a producer of this value will loose its // output. - virtual Status DeleteValue(ValueId id) = 0; + virtual absl::Status DeleteValue(ValueId id) = 0; }; // Implementation of a Graph interface. It keeps values and nodes referenced by @@ -268,7 +268,7 @@ class Model : public Graph { return values_[id].consumers; } - Status SetProducer(NodeId producer, ValueId value) final { + absl::Status SetProducer(NodeId producer, ValueId value) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(value, &v)); Value* value_ptr = v->value.get(); @@ -278,12 +278,13 @@ class Model : public Graph { // check if this value has the same producer already if (node_ptr == v->producer) { - return InvalidArgumentError("Node is already a producer of the value"); + return absl::InvalidArgumentError( + "Node is already a producer of the value"); } // Check if the node is a consumer of this value. if (IsInput(producer, value)) { - return InvalidArgumentError("Node is a consumer of the value"); + return absl::InvalidArgumentError("Node is a consumer of the value"); } // TODO(akulik): detect circular dependency? @@ -293,22 +294,23 @@ class Model : public Graph { } v->producer = node_ptr; n->outputs.push_back(value_ptr); - return OkStatus(); + return absl::OkStatus(); } - Status RemoveProducer(ValueId value) final { + absl::Status RemoveProducer(ValueId value) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(value, &v)); Value* value_ptr = v->value.get(); if (v->producer == nullptr) { - return InvalidArgumentError("Value does not have a producer"); + return absl::InvalidArgumentError("Value does not have a producer"); } Erase(&nodes_[v->producer->id].outputs, value_ptr); v->producer = nullptr; - return OkStatus(); + return absl::OkStatus(); } - Status ReplaceInput(NodeId node, ValueId old_value, ValueId new_value) final { + absl::Status ReplaceInput(NodeId node, ValueId old_value, + ValueId new_value) final { ValueDef* v_old; RETURN_IF_ERROR(LookupValue(old_value, &v_old)); Value* value_old_ptr = v_old->value.get(); @@ -321,17 +323,17 @@ class Model : public Graph { // Check if the node is a consumer of old_value. if (!IsInput(node, old_value)) { - return InvalidArgumentError("old_value must be input of node."); + return absl::InvalidArgumentError("old_value must be input of node."); } // Check if the node is not a consumer of new_value. 
if (IsInput(node, new_value)) { - return InvalidArgumentError("new_value can not be input of node."); + return absl::InvalidArgumentError("new_value can not be input of node."); } // Check if this value has the same producer already if (node_ptr == v_new->producer) { - return InvalidArgumentError("new_value can not be output of node."); + return absl::InvalidArgumentError("new_value can not be output of node."); } for (int i = 0; i < n->inputs.size(); ++i) { @@ -342,10 +344,10 @@ class Model : public Graph { } v_new->consumers.push_back(node_ptr); Erase(&v_old->consumers, node_ptr); - return OkStatus(); + return absl::OkStatus(); } - Status AddConsumer(NodeId consumer, ValueId value) final { + absl::Status AddConsumer(NodeId consumer, ValueId value) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(value, &v)); Value* value_ptr = v->value.get(); @@ -355,20 +357,21 @@ class Model : public Graph { // check if this value has the same producer already if (node_ptr == v->producer) { - return InvalidArgumentError("Node is a producer of the value"); + return absl::InvalidArgumentError("Node is a producer of the value"); } // check if this value has the same consumer already if (IsInput(consumer, value)) { - return InvalidArgumentError("Node is already a consumer of the value"); + return absl::InvalidArgumentError( + "Node is already a consumer of the value"); } n->inputs.push_back(value_ptr); v->consumers.push_back(node_ptr); - return OkStatus(); + return absl::OkStatus(); } - Status RemoveConsumer(NodeId consumer, ValueId value) final { + absl::Status RemoveConsumer(NodeId consumer, ValueId value) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(value, &v)); Value* value_ptr = v->value.get(); @@ -376,14 +379,14 @@ class Model : public Graph { RETURN_IF_ERROR(LookupNode(consumer, &n)); Node* node_ptr = n->node.get(); if (!IsInput(consumer, value)) { - return InvalidArgumentError("Node is not a consumer of the value"); + return absl::InvalidArgumentError("Node is not a consumer of the value"); } Erase(&n->inputs, value_ptr); Erase(&v->consumers, node_ptr); - return OkStatus(); + return absl::OkStatus(); } - Status DeleteNode(NodeId id) final { + absl::Status DeleteNode(NodeId id) final { NodeDef* n; RETURN_IF_ERROR(LookupNode(id, &n)); Node* node_ptr = n->node.get(); @@ -396,10 +399,10 @@ class Model : public Graph { n->inputs.clear(); n->outputs.clear(); n->node.reset(); - return OkStatus(); + return absl::OkStatus(); } - Status DeleteValue(ValueId id) final { + absl::Status DeleteValue(ValueId id) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(id, &v)); Value* value_ptr = v->value.get(); @@ -414,10 +417,10 @@ class Model : public Graph { v->producer = nullptr; v->consumers.clear(); v->value.reset(); - return OkStatus(); + return absl::OkStatus(); } - Status MakeExactCopy(Model* model) const { + absl::Status MakeExactCopy(Model* model) const { model->nodes_.clear(); model->values_.clear(); model->name_ = name_; @@ -440,7 +443,7 @@ class Model : public Graph { } } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -475,29 +478,29 @@ class Model : public Graph { } // @return non-nullptr NodeDef that has valid Node or an error - Status LookupNode(NodeId id, NodeDef** node_def) { + absl::Status LookupNode(NodeId id, NodeDef** node_def) { if (id >= nodes_.size()) { - return OutOfRangeError("NodeId is out of range"); + return absl::OutOfRangeError("NodeId is out of range"); } auto& n = nodes_[id]; if (!n.node) { - return OutOfRangeError("Node is already deleted"); + return 
absl::OutOfRangeError("Node is already deleted"); } *node_def = &n; - return OkStatus(); + return absl::OkStatus(); } // @return non-nullptr ValueDef that has valid Value or an error - Status LookupValue(ValueId id, ValueDef** value_def) { + absl::Status LookupValue(ValueId id, ValueDef** value_def) { if (id >= values_.size()) { - return OutOfRangeError("ValueId is out of range"); + return absl::OutOfRangeError("ValueId is out of range"); } auto& v = values_[id]; if (!v.value) { - return OutOfRangeError("Value is already deleted"); + return absl::OutOfRangeError("Value is already deleted"); } *value_def = &v; - return OkStatus(); + return absl::OkStatus(); } template @@ -537,14 +540,14 @@ class Model : public Graph { // outputs that are consumed only by to_keep. In such case to_keep inherits all // to_remove inputs. template -Status RemovePrecedingNode(Graph* graph, const Node* to_remove, - const Node* to_keep) { +absl::Status RemovePrecedingNode(Graph* graph, const Node* to_remove, + const Node* to_keep) { // Make sure all outputs from to_remove are consumed by to_keep. for (auto output : graph->FindOutputs(to_remove->id)) { auto consumers = graph->FindConsumers(output->id); if (consumers.size() > 1 || (consumers.size() == 1 && consumers[0] != to_keep)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Output from to_remove node has other consumers"); } } @@ -562,13 +565,13 @@ Status RemovePrecedingNode(Graph* graph, const Node* to_remove, // Removes to_remove node that follows to_keep node only if to_remove has inputs // that are produced by to_keep. to_keep inherits all to_remove inputs. template -Status RemoveFollowingNode(Graph* graph, const Node* to_remove, - const Node* to_keep) { +absl::Status RemoveFollowingNode(Graph* graph, const Node* to_remove, + const Node* to_keep) { // Make sure all inputs to to_remove are produced by to_keep. for (auto input : graph->FindInputs(to_remove->id)) { Node* producer = graph->FindProducer(input->id); if (producer->id != to_keep->id) { - return InvalidArgumentError("To_remove node has other inputs"); + return absl::InvalidArgumentError("To_remove node has other inputs"); } } @@ -584,12 +587,12 @@ Status RemoveFollowingNode(Graph* graph, const Node* to_remove, // Removes to_remove node. 
// Requires that node has one input and one output; template -Status RemoveOneInputOneOutputNode(Graph* graph, - const Node* to_remove) { +absl::Status RemoveOneInputOneOutputNode(Graph* graph, + const Node* to_remove) { auto inputs = graph->FindInputs(to_remove->id); auto outputs = graph->FindOutputs(to_remove->id); if (inputs.size() != 1 || outputs.size() != 1) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "To_remove node must have 1 input and 1 output"); } auto input_id = inputs[0]->id; @@ -604,26 +607,26 @@ Status RemoveOneInputOneOutputNode(Graph* graph, if (!producer && consumers.empty()) { RETURN_IF_ERROR(graph->DeleteValue(input_id)); } - return OkStatus(); + return absl::OkStatus(); } template -Status AddOutput(Graph* graph, const Node* from_node, - Value** output) { +absl::Status AddOutput(Graph* graph, const Node* from_node, + Value** output) { auto link = graph->NewValue(); RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id)); *output = link; - return OkStatus(); + return absl::OkStatus(); } template -Status ConnectTwoNodes(Graph* graph, const Node* from_node, - const Node* to_node, Value** output) { +absl::Status ConnectTwoNodes(Graph* graph, const Node* from_node, + const Node* to_node, Value** output) { Value* link; RETURN_IF_ERROR(AddOutput(graph, from_node, &link)); RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id)); *output = link; - return OkStatus(); + return absl::OkStatus(); } using GraphFloat32 = Model>; diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index b37c3542413..94899efe91e 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -65,9 +65,9 @@ namespace { // node(output) // will turn into: // node(copy(output)) <- passthrough_node(output) -Status NewPassthroughNode(GraphFloat32* graph, Node* node, - const Value>* output, - Node** passthru_node) { +absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node, + const Value>* output, + Node** passthru_node) { *passthru_node = graph->NewNode(); // Make copies for every output in the original node. 
RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id)); @@ -76,18 +76,18 @@ Status NewPassthroughNode(GraphFloat32* graph, Node* node, RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id)); copy_output->tensor = output->tensor; copy_output->tensor.ref = -1; - return OkStatus(); + return absl::OkStatus(); } template -Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) { +absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) { if (tensor.bytes % sizeof(T) != 0) { - return InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("Input data size ", tensor.bytes, " is not aligned to expected type: ", sizeof(T))); } std::memcpy(tensor_data, tensor.data.uint8, tensor.bytes); - return OkStatus(); + return absl::OkStatus(); } void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src, @@ -98,8 +98,8 @@ void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src, } template <> -Status CreateVectorCopyData(const TfLiteTensor& tensor, - float* tensor_data) { +absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, + float* tensor_data) { switch (tensor.type) { case kTfLiteFloat32: std::memcpy(tensor_data, tensor.data.f, tensor.bytes); @@ -110,104 +110,97 @@ Status CreateVectorCopyData(const TfLiteTensor& tensor, reinterpret_cast(tensor.data.f16), tensor_data); break; default: - return InvalidArgumentError("Unsupported data type for float32 tensor"); + return absl::InvalidArgumentError( + "Unsupported data type for float32 tensor"); } - return OkStatus(); + return absl::OkStatus(); } template -Status SetAllDimensions(const TfLiteIntArray* dimensions, ShapeT* shape); +absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, ShapeT* shape); template <> -Status SetAllDimensions(const TfLiteIntArray* dimensions, - Scalar* shape) { +absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, + Scalar* shape) { if (dimensions->size < 0) { - return InvalidArgumentError("Invalid Scalar dimensions"); + return absl::InvalidArgumentError("Invalid Scalar dimensions"); } for (int i = 0; i < dimensions->size; ++i) { if (dimensions->data[i] != 1) { - return InvalidArgumentError("Dimension can not be reduced to scalar."); + return absl::InvalidArgumentError( + "Dimension can not be reduced to scalar."); } } shape->v = 1; - return OkStatus(); + return absl::OkStatus(); } template <> -Status SetAllDimensions(const TfLiteIntArray* dimensions, - Linear* shape) { +absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, + Linear* shape) { if (dimensions->size <= 0) { - return InvalidArgumentError("Dimension is empty."); + return absl::InvalidArgumentError("Dimension is empty."); } for (int i = 0; i < dimensions->size - 1; ++i) { if (dimensions->data[i] != 1) { - return InvalidArgumentError("Dimension can not be reduced to linear."); + return absl::InvalidArgumentError( + "Dimension can not be reduced to linear."); } } shape->v = dimensions->data[dimensions->size - 1]; - return OkStatus(); + return absl::OkStatus(); } template <> -Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) { +absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, + HWC* shape) { if (dimensions->size != 4) { - return InvalidArgumentError("Dimensions are not HWC"); + return absl::InvalidArgumentError("Dimensions are not HWC"); } if (dimensions->data[0] != 1) { - return UnimplementedError("Batch size is not equal to 1."); + return absl::UnimplementedError("Batch size is not equal to 
1."); } shape->h = dimensions->data[1]; shape->w = dimensions->data[2]; shape->c = dimensions->data[3]; - return OkStatus(); + return absl::OkStatus(); } template <> -Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) { +absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) { if (dimensions->size != 2) { - return InvalidArgumentError("Dimensions are not HW"); + return absl::InvalidArgumentError("Dimensions are not HW"); } shape->h = dimensions->data[0]; shape->w = dimensions->data[1]; - return OkStatus(); + return absl::OkStatus(); } template <> -Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) { +absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, + OHWI* shape) { if (dimensions->size != 4) { - return InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("Dimensions are not OHWI: ", dimensions->size)); } shape->o = dimensions->data[0]; shape->h = dimensions->data[1]; shape->w = dimensions->data[2]; shape->i = dimensions->data[3]; - return OkStatus(); + return absl::OkStatus(); } template <> -Status SetAllDimensions(const TfLiteIntArray* dimensions, IHWO* shape) { +absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, + BHWC* shape) { if (dimensions->size != 4) { - return InvalidArgumentError( - absl::StrCat("Dimensions are not IHWO: ", dimensions->size)); - } - shape->i = dimensions->data[0]; - shape->h = dimensions->data[1]; - shape->w = dimensions->data[2]; - shape->o = dimensions->data[3]; - return OkStatus(); -} - -template <> -Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) { - if (dimensions->size != 4) { - return InvalidArgumentError("Dimensions are not BHWC"); + return absl::InvalidArgumentError("Dimensions are not BHWC"); } shape->b = dimensions->data[0]; shape->h = dimensions->data[1]; shape->w = dimensions->data[2]; shape->c = dimensions->data[3]; - return OkStatus(); + return absl::OkStatus(); } DataType ToDataType(TfLiteType type) { @@ -253,46 +246,46 @@ int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context, return number_of_runtime_outputs; } -Status CheckTensorIsAvailable(const TfLiteContext* context, - const TfLiteNode* tflite_node, int idx) { +absl::Status CheckTensorIsAvailable(const TfLiteContext* context, + const TfLiteNode* tflite_node, int idx) { // If tensor id is in range, it's guaranteed that it'll be available. 
if (idx >= tflite_node->inputs->size) { - return OutOfRangeError( + return absl::OutOfRangeError( absl::StrFormat("Requested index goes beyond array size (%d vs %d).", idx, tflite_node->inputs->data[idx])); } - return OkStatus(); + return absl::OkStatus(); } -Status CheckInputsOutputs(const TfLiteContext* context, - const TfLiteNode* tflite_node, int runtime_inputs, - int outputs) { +absl::Status CheckInputsOutputs(const TfLiteContext* context, + const TfLiteNode* tflite_node, + int runtime_inputs, int outputs) { int runtime_inputs_from_model = GetNumberOfRuntimeInputsForNode(context, tflite_node); if (runtime_inputs_from_model != runtime_inputs) { - return InternalError(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Expected %d runtime input tensor(s), but node has %d runtime " "input(s).", runtime_inputs, runtime_inputs_from_model)); } int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node); if (runtime_outputs != outputs) { - return InternalError( + return absl::InternalError( absl::StrFormat("Expected %d output tensor(s), but node has %d " "output(s).", outputs, runtime_outputs)); } - return OkStatus(); + return absl::OkStatus(); } -Status CheckInputsConstsOutputs(const TfLiteContext* context, - const TfLiteNode* tflite_node, - int runtime_inputs, int const_inputs, - int outputs) { +absl::Status CheckInputsConstsOutputs(const TfLiteContext* context, + const TfLiteNode* tflite_node, + int runtime_inputs, int const_inputs, + int outputs) { int const_inputs_from_model = GetNumberOfConstInputsForNode(context, tflite_node); if (const_inputs_from_model != const_inputs) { - return InternalError(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Expected %d const input tensor(s), but node has %d const " "input(s).", const_inputs, const_inputs_from_model)); @@ -310,9 +303,9 @@ class ObjectReader { tflite_node_(tflite_node), tensor_to_value_(tensor_to_value) {} - Status ReadValue(uint32_t idx, Value>** value) const { + absl::Status ReadValue(uint32_t idx, Value>** value) const { if (idx >= tflite_node_->inputs->size) { - return OutOfRangeError( + return absl::OutOfRangeError( absl::StrCat("ReadValue: input tensor index: ", idx)); } return ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value); @@ -322,21 +315,21 @@ class ObjectReader { return GetNumberOfRuntimeInputsForNode(context_, tflite_node_); } - Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const { + absl::Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const { if (idx >= tflite_node_->inputs->size) { - return OutOfRangeError(absl::StrCat("Input tensor index: ", idx)); + return absl::OutOfRangeError(absl::StrCat("Input tensor index: ", idx)); } const int tensor_idx = tflite_node_->inputs->data[idx]; if (tensor_idx < 0 || tensor_idx > context_->tensors_size) { - return OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx)); + return absl::OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx)); } const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx]; *dimensions = *tflite_tensor.dims; - return OkStatus(); + return absl::OkStatus(); } template - Status ReadTensor(uint32_t idx, TensorT* t) const { + absl::Status ReadTensor(uint32_t idx, TensorT* t) const { RETURN_IF_ERROR(CheckTensorIsAvailable(context_, tflite_node_, idx)); const int32_t tensor_idx = tflite_node_->inputs->data[idx]; const TfLiteTensor* tflite_tensor = context_->tensors + tensor_idx; @@ -349,9 +342,9 @@ class ObjectReader { return 
SetAllDimensions(tflite_tensor->dims, &t->shape); } - Status AddOutput(const Node* node, int id) { + absl::Status AddOutput(const Node* node, int id) { if (tflite_node_->outputs->size <= id) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Data id ", id, " must be less than tflite node outputs size ", tflite_node_->outputs->size)); } @@ -359,32 +352,32 @@ class ObjectReader { Value>* value; RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value)); RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id)); - return OkStatus(); + return absl::OkStatus(); } - Status AddOutputs(const Node* node) { + absl::Status AddOutputs(const Node* node) { for (int i = 0; i < tflite_node_->outputs->size; ++i) { RETURN_IF_ERROR(AddOutput(node, i)); } - return OkStatus(); + return absl::OkStatus(); } - Status AddInput(const Node* node, uint32_t idx) { + absl::Status AddInput(const Node* node, uint32_t idx) { Value>* input; RETURN_IF_ERROR(ReadValue(idx, &input)); return graph_->AddConsumer(node->id, input->id); } - Status ReadValueByTensorIdx(uint32_t tensor_idx, - Value>** value) const { + absl::Status ReadValueByTensorIdx(uint32_t tensor_idx, + Value>** value) const { if (tensor_idx >= tensor_to_value_->size()) { - return OutOfRangeError( + return absl::OutOfRangeError( absl::StrCat("ReadValue: input tensor index: ", tensor_idx)); } if ((*tensor_to_value_)[tensor_idx] == nullptr) { const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx]; if (tflite::IsConstantTensor(&tflite_tensor)) { - return NotFoundError(absl::StrCat( + return absl::NotFoundError(absl::StrCat( "ReadValue: value is a constant tensor: ", tensor_idx)); } Value>* value = graph_->NewValue(); @@ -394,7 +387,7 @@ class ObjectReader { (*tensor_to_value_)[tensor_idx] = value; } *value = (*tensor_to_value_)[tensor_idx]; - return OkStatus(); + return absl::OkStatus(); } TfLiteTensor* GetInputTensor(int index) const { @@ -409,9 +402,9 @@ class ObjectReader { : nullptr; } - Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node, - int runtime_inputs, int const_inputs, - int outputs) { + absl::Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node, + int runtime_inputs, int const_inputs, + int outputs) { return CheckInputsConstsOutputs(context_, tflite_node, runtime_inputs, const_inputs, outputs); } @@ -430,28 +423,30 @@ class TFLiteOperationParser { // Parses TFLite operation. This method allows expanding fused operations // into more than one node. - virtual Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) = 0; + virtual absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) = 0; // Verifies whether passed tflite node may be built by GPU delegate or not. 
- virtual Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) = 0; + virtual absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) = 0; }; -Status IsActivationSupported(TfLiteFusedActivation fused_activation) { +absl::Status IsActivationSupported(TfLiteFusedActivation fused_activation) { switch (fused_activation) { case kTfLiteActNone: case kTfLiteActRelu: case kTfLiteActRelu1: case kTfLiteActRelu6: case kTfLiteActTanh: - return OkStatus(); + return absl::OkStatus(); case kTfLiteActSignBit: - return UnimplementedError("TfLiteFusedActivation.kTfLiteActSignBit"); + return absl::UnimplementedError( + "TfLiteFusedActivation.kTfLiteActSignBit"); case kTfLiteActSigmoid: - return UnimplementedError("TfLiteFusedActivation.kTfLiteActSigmoid"); + return absl::UnimplementedError( + "TfLiteFusedActivation.kTfLiteActSigmoid"); // Do not add default; we want compilation error rather than run-time // error. @@ -461,15 +456,15 @@ Status IsActivationSupported(TfLiteFusedActivation fused_activation) { // If there is fused activation present, then there will be another node created // that will have identical output as the given node. New operation node will // depend on the given node output. -Status MaybeFuseActivation(TfLiteFusedActivation fused_activation, - const std::vector& output_indices, - GraphFloat32* graph, Node* node) { +absl::Status MaybeFuseActivation(TfLiteFusedActivation fused_activation, + const std::vector& output_indices, + GraphFloat32* graph, Node* node) { if (fused_activation == kTfLiteActNone) { - return OkStatus(); + return absl::OkStatus(); } const auto& outputs = graph->FindOutputs(node->id); if (outputs.empty()) { - return InternalError("Empty outputs in fused node"); + return absl::InternalError("Empty outputs in fused node"); } switch (fused_activation) { case kTfLiteActRelu: @@ -497,16 +492,16 @@ Status MaybeFuseActivation(TfLiteFusedActivation fused_activation, } break; default: - return NotFoundError( + return absl::NotFoundError( absl::StrCat("Unsupported fused activation: ", fused_activation)); } - return OkStatus(); + return absl::OkStatus(); } -Status MaybeFuseActivationToTheSingleOutput( +absl::Status MaybeFuseActivationToTheSingleOutput( TfLiteFusedActivation fused_activation, GraphFloat32* graph, Node* node) { if (graph->FindOutputs(node->id).size() != 1) { - return InternalError("Number of outputs exceeds 1"); + return absl::InternalError("Number of outputs exceeds 1"); } return MaybeFuseActivation(fused_activation, {0}, graph, node); } @@ -524,9 +519,10 @@ void UpdatePadding(const TfLitePadding& padding, const BHWC& input_shape, } } -Status GetFullyConnectedAttributes(int weights_tensor_id, int bias_tensor_id, - ObjectReader* reader, - FullyConnectedAttributes* attr) { +absl::Status GetFullyConnectedAttributes(int weights_tensor_id, + int bias_tensor_id, + ObjectReader* reader, + FullyConnectedAttributes* attr) { Tensor weights; RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights)); attr->weights.data = std::move(weights.data); @@ -537,100 +533,100 @@ Status GetFullyConnectedAttributes(int weights_tensor_id, int bias_tensor_id, attr->weights.shape.i = weights.shape.w; reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError(); // optional - return OkStatus(); + return absl::OkStatus(); } template -Status RetrieveBuiltinData(const TfLiteNode* tflite_node, - ParamsT** tf_options) { +absl::Status 
RetrieveBuiltinData(const TfLiteNode* tflite_node, + ParamsT** tf_options) { const auto* params = reinterpret_cast(tflite_node->builtin_data); if (!params) { - return InternalError("Unable to retrieve builtin_data."); + return absl::InternalError("Unable to retrieve builtin_data."); } *tf_options = const_cast(params); - return OkStatus(); + return absl::OkStatus(); } template -Status RetrieveCustomInitialData(const TfLiteNode* tflite_node, - ParamsType** tf_options) { +absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node, + ParamsType** tf_options) { const auto* params = reinterpret_cast(tflite_node->custom_initial_data); if (!params) { - return InternalError("Unable to retrieve custom_initial_data."); + return absl::InternalError("Unable to retrieve custom_initial_data."); } *tf_options = const_cast(params); - return OkStatus(); + return absl::OkStatus(); } -Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration, - int max_version) { +absl::Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration, + int max_version) { const int op_version = registration->version; if (op_version > max_version) { - return UnimplementedError( + return absl::UnimplementedError( absl::StrFormat("Max version supported: %d. Requested version %d.", max_version, op_version)); } - return OkStatus(); + return absl::OkStatus(); } -Status CheckExactSupportedOpVersion(const TfLiteRegistration* registration, - int expected_version) { +absl::Status CheckExactSupportedOpVersion( + const TfLiteRegistration* registration, int expected_version) { int op_version = registration->version; if (op_version != expected_version) { - return UnimplementedError( + return absl::UnimplementedError( absl::StrFormat("Only version %d is supported. Requested version %d.", expected_version, op_version)); } - return OkStatus(); + return absl::OkStatus(); } -Status CheckKernels(int kernel_h, int kernel_w) { +absl::Status CheckKernels(int kernel_h, int kernel_w) { if (kernel_h <= 0 || kernel_w <= 0) { - return InvalidArgumentError(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "Incorrect kernel values: kernel_height = %d, kernel_width = %d.", kernel_h, kernel_w)); } - return OkStatus(); + return absl::OkStatus(); } -Status CheckStrides(int strides_h, int strides_w) { +absl::Status CheckStrides(int strides_h, int strides_w) { if (strides_h <= 0 || strides_w <= 0) { - return InvalidArgumentError(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "Incorrect stride values: stride_height = %d, stride_width = %d.", strides_h, strides_w)); } - return OkStatus(); + return absl::OkStatus(); } -Status CheckDilation(int dilation_h, int dilation_w) { +absl::Status CheckDilation(int dilation_h, int dilation_w) { if (dilation_h <= 0 || dilation_w <= 0) { - return InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrFormat("Incorrect dilation values: dilation_factor = %d, " "dilation_factor = %d.", dilation_h, dilation_w)); } - return OkStatus(); + return absl::OkStatus(); } -Status CheckStridesAndDilation(int strides_h, int strides_w, int dilation_h, - int dilation_w) { +absl::Status CheckStridesAndDilation(int strides_h, int strides_w, + int dilation_h, int dilation_w) { RETURN_IF_ERROR(CheckStrides(strides_h, strides_w)); RETURN_IF_ERROR(CheckDilation(dilation_h, dilation_w)); - return OkStatus(); + return absl::OkStatus(); } -Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h, - int strides_w) { +absl::Status 
CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h, + int strides_w) { RETURN_IF_ERROR(CheckKernels(kernel_h, kernel_w)); RETURN_IF_ERROR(CheckStrides(strides_h, strides_w)); - return OkStatus(); + return absl::OkStatus(); } // Creates a simple node that holds tensor value. -Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, - Value>** value) { +absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, + Value>** value) { ConstTensorAttributes attr; attr.tensor = std::move(t); Node* node = graph->NewNode(); @@ -642,59 +638,59 @@ Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, (*value)->tensor.ref = attr.tensor.id; (*value)->tensor.type = attr.tensor.kType; (*value)->tensor.shape = attr.tensor.shape; - return OkStatus(); + return absl::OkStatus(); } -Status ParsePoolingAttributes(const TfLitePoolParams* tf_options, - const BHWC& input_shape, - Pooling2DAttributes* attr) { +absl::Status ParsePoolingAttributes(const TfLitePoolParams* tf_options, + const BHWC& input_shape, + Pooling2DAttributes* attr) { attr->kernel = ToHW(tf_options->filter_height, tf_options->filter_width); attr->strides = ToHW(tf_options->stride_height, tf_options->stride_width); UpdatePadding(tf_options->padding, input_shape, attr); - return OkStatus(); + return absl::OkStatus(); } -Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) { +absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) { const TfLiteIntArray* dims = tflite_tensor.dims; switch (dims->size) { case 1: *bhwc = BHWC(dims->data[0], 1, 1, 1); - return OkStatus(); + return absl::OkStatus(); case 2: *bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]); - return OkStatus(); + return absl::OkStatus(); case 3: *bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]); - return OkStatus(); + return absl::OkStatus(); case 4: *bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]); - return OkStatus(); + return absl::OkStatus(); default: - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Tensor \"", tflite_tensor.name ? tflite_tensor.name : "nullptr", "\" has bad input dims size: ", dims->size, ".")); } } -Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader, - TensorOrScalar* tensor_or_scalar) { +absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader, + TensorOrScalar* tensor_or_scalar) { const std::string& opname = node->operation.type; // Determine runtime/constant tensors. 
const TfLiteTensor* input0 = reader->GetInputTensor(0); if (!input0) { - return InvalidArgumentError("Couldn't get the 1st input tensor for " + - opname); + return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " + + opname); } const TfLiteTensor* input1 = reader->GetInputTensor(1); if (!input1) { - return InvalidArgumentError("Couldn't get the 2nd input tensor for " + - opname); + return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " + + opname); } const bool constant_tensor0 = IsConstantTensor(input0); const bool constant_tensor1 = IsConstantTensor(input1); if (constant_tensor0 && constant_tensor1) { - return InvalidArgumentError("No runtime input tensors for " + opname); + return absl::InvalidArgumentError("No runtime input tensors for " + opname); } const bool runtime_tensor0 = !constant_tensor0; const bool runtime_tensor1 = !constant_tensor1; @@ -722,26 +718,26 @@ Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader, *tensor_or_scalar = std::move(tensor); } } - return OkStatus(); + return absl::OkStatus(); } class AddOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); if (tflite_node->inputs->size != 2) { - return UnimplementedError("ADD requires two input tensors."); + return absl::UnimplementedError("ADD requires two input tensors."); } // TODO(eignasheva): Add shapes check. TfLiteAddParams* tf_options = nullptr; return RetrieveBuiltinData(tflite_node, &tf_options); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { // TFLite currently only supports 2 input ADDs. Thus, the logic below only // considers 2 input cases. The underlying GPU shader programs can accept // more inputs, but the logic below would have to be expanded. @@ -755,7 +751,7 @@ class AddOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return absl::InternalError("Missing tflite params"); } return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node); @@ -764,9 +760,9 @@ class AddOperationParser : public TFLiteOperationParser { class ConcatenationOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); // TODO(eignasheva): add proper tensor availability checking @@ -776,12 +772,12 @@ class ConcatenationOperationParser : public TFLiteOperationParser { // TODO(eignasheva): add axis checking. 
TfLiteConcatenationParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { ConcatAttributes attr; // Read inputs first to make sure const node is added to a graph before // concat node to ensure topological order. @@ -832,16 +828,16 @@ class ConcatenationOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return absl::InternalError("Missing tflite params"); } RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node)); node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } private: - Status SetAxis(const std::vector& input_shapes, Axis* axis) { + absl::Status SetAxis(const std::vector& input_shapes, Axis* axis) { *axis = Axis::BATCH; for (int i = 1; i < input_shapes.size(); i++) { if (input_shapes[0].h != input_shapes[i].h && @@ -851,7 +847,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser { break; } } - if (*axis == Axis::BATCH) return OkStatus(); + if (*axis == Axis::BATCH) return absl::OkStatus(); for (int i = 1; i < input_shapes.size(); i++) { if (input_shapes[0].b != input_shapes[i].b && input_shapes[0].w != input_shapes[i].w && @@ -860,7 +856,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser { break; } } - if (*axis == Axis::HEIGHT) return OkStatus(); + if (*axis == Axis::HEIGHT) return absl::OkStatus(); for (int i = 1; i < input_shapes.size(); i++) { if (input_shapes[0].b != input_shapes[i].b && input_shapes[0].h != input_shapes[i].h && @@ -869,25 +865,25 @@ class ConcatenationOperationParser : public TFLiteOperationParser { break; } } - if (*axis == Axis::WIDTH) return OkStatus(); + if (*axis == Axis::WIDTH) return absl::OkStatus(); for (int i = 1; i < input_shapes.size(); i++) { if (input_shapes[0].b != input_shapes[i].b && input_shapes[0].w != input_shapes[i].w && input_shapes[0].h != input_shapes[i].h) { - return UnimplementedError( + return absl::UnimplementedError( "Can concatenate tensors only by batch, height, width, or " "channels."); } } - return OkStatus(); + return absl::OkStatus(); } }; class Conv2DOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -900,9 +896,9 @@ class Conv2DOperationParser : public TFLiteOperationParser { return IsActivationSupported(tf_options->activation); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = 
ToString(OperationType::CONVOLUTION_2D); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -915,7 +911,7 @@ class Conv2DOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return absl::InternalError("Missing tflite params"); } attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); attr.dilations = HW(tf_options->dilation_height_factor, @@ -925,26 +921,26 @@ class Conv2DOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node)); node->operation.attributes = std::move(attr); - return OkStatus(); + return absl::OkStatus(); } }; class Convolution2DTransposeBiasParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); TfLiteTransposeConvParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); RETURN_IF_ERROR( CheckStrides(tf_options->stride_height, tf_options->stride_width)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -963,15 +959,15 @@ class Convolution2DTransposeBiasParser : public TFLiteOperationParser { &attr); node->operation.attributes = std::move(attr); - return OkStatus(); + return absl::OkStatus(); } }; class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -991,37 +987,38 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { : nullptr; const auto* output = context->tensors + tflite_node->outputs->data[0]; if (!input->dims || input->dims->size != 4) { - return InvalidArgumentError("input.dims.size != 4"); + return absl::InvalidArgumentError("input.dims.size != 4"); } if (!filter->dims || filter->dims->size != 4) { - return InvalidArgumentError("filter.dims.size != 4"); + return absl::InvalidArgumentError("filter.dims.size != 4"); } if (!output->dims || output->dims->size != 4) { - return InvalidArgumentError("output.dims.size != 4"); + return absl::InvalidArgumentError("output.dims.size != 4"); } if (input->dims->data[0] != output->dims->data[0]) { - return InvalidArgumentError("input.b != output.b"); + return absl::InvalidArgumentError("input.b != output.b"); } const int input_depth = input->dims->data[3]; const int output_depth = output->dims->data[3]; if 
(filter->dims->data[3] != output_depth) { - return InvalidArgumentError("filter.i != output.c"); + return absl::InvalidArgumentError("filter.i != output.c"); } if (output_depth != input_depth * depth_multiplier) { - return InvalidArgumentError("output.c != input.c * depth_multiplier"); + return absl::InvalidArgumentError( + "output.c != input.c * depth_multiplier"); } if (bias && NumElements(bias) != output_depth) { - return InvalidArgumentError("bias.size != output.c"); + return absl::InvalidArgumentError("bias.size != output.c"); } if (depth_multiplier != 1 && input_depth != 1) { - return UnimplementedError("depth_multiplier != 1 && input.c != 1"); + return absl::UnimplementedError("depth_multiplier != 1 && input.c != 1"); } - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1047,7 +1044,7 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { TransposeWeights(input, filter, output, depth_multiplier, &attr); } node->operation.attributes = std::move(attr); - return OkStatus(); + return absl::OkStatus(); } private: @@ -1086,9 +1083,9 @@ class ElementwiseOperationParser : public TFLiteOperationParser { explicit ElementwiseOperationParser(OperationType operation_type) : operation_type_(operation_type) {} - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); if (IsOneArgumentOperation()) { RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, @@ -1106,16 +1103,17 @@ class ElementwiseOperationParser : public TFLiteOperationParser { /*const_inputs=*/1, /*outputs=*/1)); } else { - return InvalidArgumentError("Op can only handle 1 or 2 operand(s)."); + return absl::InvalidArgumentError( + "Op can only handle 1 or 2 operand(s)."); } TfLiteFusedActivation activation; RETURN_IF_ERROR(GetActivation(tflite_node, &activation)); return IsActivationSupported(activation); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(operation_type_); @@ -1132,7 +1130,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser { /*const_inputs=*/0, /*outputs=*/1)); if (tflite_node->inputs->size != 2) { - return InvalidArgumentError("Applies only two input tensors"); + return absl::InvalidArgumentError("Applies only two input tensors"); } RETURN_IF_ERROR(reader->AddInput(node, 0)); RETURN_IF_ERROR(reader->AddInput(node, 1)); @@ -1173,32 +1171,32 @@ class ElementwiseOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); node->operation.attributes = std::move(attr); } else { - return 
InvalidArgumentError("Incorrect operation type passed"); + return absl::InvalidArgumentError("Incorrect operation type passed"); } return reader->AddOutputs(node); } private: - Status GetActivation(const TfLiteNode* tflite_node, - TfLiteFusedActivation* activation) const { + absl::Status GetActivation(const TfLiteNode* tflite_node, + TfLiteFusedActivation* activation) const { if (operation_type_ == OperationType::DIV) { TfLiteDivParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); *activation = tf_options ? tf_options->activation : kTfLiteActNone; - return OkStatus(); + return absl::OkStatus(); } if (operation_type_ == OperationType::SUB) { TfLiteSubParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); *activation = tf_options ? tf_options->activation : kTfLiteActNone; - return OkStatus(); + return absl::OkStatus(); } // Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or // TfLiteXxxParams.activation. *activation = kTfLiteActNone; - return OkStatus(); + return absl::OkStatus(); } bool IsOneArgumentOperation() const { @@ -1247,23 +1245,24 @@ class ElementwiseOperationParser : public TFLiteOperationParser { class FullyConnectedOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); TfLiteFullyConnectedParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) { - return UnimplementedError("Unsupported FullyConnected weights format."); + return absl::UnimplementedError( + "Unsupported FullyConnected weights format."); } // TODO(eignasheva): check input shape - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1272,7 +1271,8 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { tflite_node->builtin_data); if (tf_options->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) { - return UnimplementedError("Unsupported FullyConnected weights format."); + return absl::UnimplementedError( + "Unsupported FullyConnected weights format."); } FullyConnectedAttributes attr; @@ -1284,7 +1284,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { int batch_size = input->tensor.shape.b; if (input->tensor.shape.DimensionsProduct() / batch_size != weights.shape.w) { - return UnimplementedError( + return absl::UnimplementedError( "Amount of input data should match weights width"); } @@ -1306,7 +1306,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { conv->operation.type = ToString(OperationType::FULLY_CONNECTED); conv->operation.attributes = std::move(attr); - Status result = reader->AddOutputs(conv); + absl::Status result = reader->AddOutputs(conv); RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, conv)); @@ -1316,15 +1316,15 @@ 
class FullyConnectedOperationParser : public TFLiteOperationParser { class HardSwishOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration*) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration*) final { return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } - Status Parse(const TfLiteNode*, const TfLiteRegistration*, - GraphFloat32* graph, ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::HARD_SWISH); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1353,9 +1353,9 @@ class HardSwishOperationParser : public TFLiteOperationParser { // class LSTMOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2)); // TODO(eignasheva): Fix bad check. // RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, @@ -1364,23 +1364,23 @@ class LSTMOperationParser : public TFLiteOperationParser { TfLiteLSTMParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckParameters(tf_options)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { if (tflite_node->inputs->size != 5) { - return InvalidArgumentError("LSTM should have 5 input tensors"); + return absl::InvalidArgumentError("LSTM should have 5 input tensors"); } if (tflite_node->outputs->size != 4) { - return InvalidArgumentError("LSTM should have 4 output tensors"); + return absl::InvalidArgumentError("LSTM should have 4 output tensors"); } const auto* params = reinterpret_cast(tflite_node->builtin_data); if (!params) { - return InternalError("Missing tflite params"); + return absl::InternalError("Missing tflite params"); } RETURN_IF_ERROR(CheckParameters(params)); @@ -1423,58 +1423,61 @@ class LSTMOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation - return OkStatus(); + return absl::OkStatus(); } private: - Status CheckParameters(const TfLiteLSTMParams* tf_options) { + absl::Status CheckParameters(const TfLiteLSTMParams* tf_options) { if (tf_options->kernel_type != TfLiteLSTMKernelType::kTfLiteLSTMBasicKernel) { - return UnimplementedError("Only kTfLiteLSTMBasicKernel is supported."); + return absl::UnimplementedError( + "Only kTfLiteLSTMBasicKernel is supported."); } if (tf_options->activation != kTfLiteActTanh) { - return UnimplementedError("Only TANH activation is supported."); + return absl::UnimplementedError("Only TANH activation is supported."); } if (tf_options->cell_clip != 0.0f) { - return UnimplementedError("cell_clip is not supported."); + return 
absl::UnimplementedError("cell_clip is not supported."); } if (tf_options->proj_clip != 0.0f) { - return UnimplementedError("proj_clip is not supported."); + return absl::UnimplementedError("proj_clip is not supported."); } - return OkStatus(); + return absl::OkStatus(); } }; class MulOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); if (tflite_node->inputs->size != 2) { - return UnimplementedError("MUL requires two input tensors."); + return absl::UnimplementedError("MUL requires two input tensors."); } TfLiteMulParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); return IsActivationSupported(tf_options->activation); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { // Determine runtime/constant tensors. const TfLiteTensor* input0 = reader->GetInputTensor(0); if (!input0) { - return InvalidArgumentError("Couldn't get the 1st input tensor for MUL."); + return absl::InvalidArgumentError( + "Couldn't get the 1st input tensor for MUL."); } const TfLiteTensor* input1 = reader->GetInputTensor(1); if (!input1) { - return InvalidArgumentError("Couldn't get the 2nd input tensor for MUL."); + return absl::InvalidArgumentError( + "Couldn't get the 2nd input tensor for MUL."); } const bool constant_tensor0 = IsConstantTensor(input0); const bool constant_tensor1 = IsConstantTensor(input1); if (constant_tensor0 && constant_tensor1) { - return InvalidArgumentError("No runtime input tensors for MUL."); + return absl::InvalidArgumentError("No runtime input tensors for MUL."); } const bool runtime_tensor0 = !constant_tensor0; const bool runtime_tensor1 = !constant_tensor1; @@ -1516,24 +1519,24 @@ class MulOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing TfLiteMulParams"); + return absl::InternalError("Missing TfLiteMulParams"); } return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node); } private: - Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1, - GraphFloat32* graph, ObjectReader* reader) { + absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1, + GraphFloat32* graph, ObjectReader* reader) { RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); return reader->AddOutputs(node); } - Status ParseMultiplyScalar(Node* node, int runtime_tensor, - int constant_tensor, - const TfLiteIntArray* constant_dims, - GraphFloat32* graph, ObjectReader* reader) { + absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor, + int constant_tensor, + const TfLiteIntArray* constant_dims, + GraphFloat32* graph, ObjectReader* reader) { RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor)); MultiplyAttributes attr; if (constant_dims->size <= 0) { @@ -1552,16 +1555,16 @@ class MulOperationParser : public TFLiteOperationParser { class PReLUOperationParser : public 
TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); // TODO(eignasheva): add params check - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::PRELU); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1569,10 +1572,10 @@ class PReLUOperationParser : public TFLiteOperationParser { PReLUAttributes attr; Tensor linear_alpha; - Status status = reader->ReadTensor(1, &linear_alpha); + absl::Status status = reader->ReadTensor(1, &linear_alpha); if (status.ok()) { if (linear_alpha.shape.v != input_shape.c) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Linear alpha shape does not match the number of input channels."); } attr.alpha = std::move(linear_alpha); @@ -1582,7 +1585,8 @@ class PReLUOperationParser : public TFLiteOperationParser { if (hwc_alpha.shape.h != input_shape.h || hwc_alpha.shape.w != input_shape.w || hwc_alpha.shape.c != input_shape.c) { - return InvalidArgumentError("Alpha shape does not match input shape."); + return absl::InvalidArgumentError( + "Alpha shape does not match input shape."); } attr.alpha = std::move(hwc_alpha); } @@ -1595,15 +1599,15 @@ class PadOperationParser : public TFLiteOperationParser { public: explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {} - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { if (mirror_pad_) { auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (tf_options->mode != TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Only Reflective padding is supported for Mirror Pad operation."); } } @@ -1611,12 +1615,12 @@ class PadOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::PAD); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1634,14 +1638,15 @@ class PadOperationParser : public TFLiteOperationParser { // 4x2 tensor with paddings. 
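    // For a 4-D NHWC input the rows are expected in (batch, height, width,
    // channels) order, i.e. (illustrative sketch of the expected layout):
    //   paddings = {{b_before, b_after},
    //               {h_before, h_after},
    //               {w_before, w_after},
    //               {c_before, c_after}};
    // which the code below maps to
    //   attr.prepended = BHWC(b_before, h_before, w_before, c_before);
    //   attr.appended  = BHWC(b_after,  h_after,  w_after,  c_after);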
if (paddings.shape.h != 4 || paddings.shape.w != 2) { - return InvalidArgumentError("Paddings tensor has unexpected shape."); + return absl::InvalidArgumentError( + "Paddings tensor has unexpected shape."); } attr.prepended = BHWC(paddings.data[0], paddings.data[2], paddings.data[4], paddings.data[6]); attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5], paddings.data[7]); node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } private: @@ -1650,9 +1655,9 @@ class PadOperationParser : public TFLiteOperationParser { class Pooling2DOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); TfLitePoolParams* tf_options = nullptr; auto status = RetrieveCustomInitialData(tflite_node, &tf_options); @@ -1675,9 +1680,9 @@ class Pooling2DOperationParser : public TFLiteOperationParser { public: explicit Pooling2DOperationParser(PoolingType type) : type_(type) {} - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::POOLING_2D); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1699,7 +1704,7 @@ class Pooling2DOperationParser : public TFLiteOperationParser { reinterpret_cast(tflite_node->builtin_data); } if (!tf_options) { - return InternalError("Missing tflite params"); + return absl::InternalError("Missing tflite params"); } std::vector max_tensor_id{0}; @@ -1719,7 +1724,7 @@ class Pooling2DOperationParser : public TFLiteOperationParser { } RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr)); node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } private: @@ -1730,16 +1735,16 @@ class ReLUOperationParser : public TFLiteOperationParser { public: explicit ReLUOperationParser(int clip) : clip_(clip) {} - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::RELU); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1759,19 +1764,19 @@ class ReLUOperationParser : public TFLiteOperationParser { class ReshapeOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const 
TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); // TODO(eignasheva): add shape checking - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::RESHAPE); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1784,7 +1789,7 @@ class ReshapeOperationParser : public TFLiteOperationParser { ReshapeAttributes attr; attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape; node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } }; @@ -1793,9 +1798,9 @@ class Resize2DOperationParser : public TFLiteOperationParser { explicit Resize2DOperationParser(SamplingType sampling_type) : sampling_type_(sampling_type) {} - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -1805,12 +1810,12 @@ class Resize2DOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners)); bool half_pixel_centers; RETURN_IF_ERROR(GetHalfPixelCentersValue(tflite_node, &half_pixel_centers)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::RESIZE); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1826,12 +1831,12 @@ class Resize2DOperationParser : public TFLiteOperationParser { attr.new_shape.CopyAllDefinedAxis( graph->FindOutputs(node->id)[0]->tensor.shape); node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } private: - Status GetAlignCornersValue(const TfLiteNode* tflite_node, - bool* align_corners) { + absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node, + bool* align_corners) { switch (sampling_type_) { case SamplingType::BILINEAR: return GetAlignCornersValueForType( @@ -1840,61 +1845,62 @@ class Resize2DOperationParser : public TFLiteOperationParser { return GetAlignCornersValueForType( tflite_node, align_corners); case SamplingType::UNKNOWN: - return InternalError("Sampling type is not specified"); + return absl::InternalError("Sampling type is not specified"); } - return OkStatus(); + return absl::OkStatus(); } template - Status GetAlignCornersValueForType(const TfLiteNode* tflite_node, - bool* align_corners) { + absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node, + bool* align_corners) { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return 
absl::InternalError("Missing tflite params"); } *align_corners = tf_options->align_corners; - return OkStatus(); + return absl::OkStatus(); } - Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node, - bool* half_pixel_centers) { + absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node, + bool* half_pixel_centers) { if (sampling_type_ == SamplingType::BILINEAR) { const auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params for ResizeBilinear op"); + return absl::InternalError( + "Missing tflite params for ResizeBilinear op"); } if (tf_options->align_corners && tf_options->half_pixel_centers) { - return InternalError( + return absl::InternalError( "If half_pixel_centers is True, align_corners must be False."); } *half_pixel_centers = tf_options->half_pixel_centers; } else { *half_pixel_centers = false; } - return OkStatus(); + return absl::OkStatus(); } - Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node) { + absl::Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node) { const auto* input = context->tensors + tflite_node->inputs->data[0]; const auto* output = context->tensors + tflite_node->outputs->data[0]; if (!input->dims || input->dims->size != 4) { - return InvalidArgumentError("input.dims.size != 4"); + return absl::InvalidArgumentError("input.dims.size != 4"); } if (!output->dims || output->dims->size != 4) { - return InvalidArgumentError("output.dims.size != 4"); + return absl::InvalidArgumentError("output.dims.size != 4"); } if (output->dims->data[1] < input->dims->data[1] || output->dims->data[2] < input->dims->data[2]) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Only upsampling is supported, received output h,w = ", output->dims->data[1], ",", output->dims->data[2], " input h,w = ", input->dims->data[1], ",", input->dims->data[2])); } - return OkStatus(); + return absl::OkStatus(); } SamplingType sampling_type_ = SamplingType::UNKNOWN; @@ -1902,16 +1908,16 @@ class Resize2DOperationParser : public TFLiteOperationParser { class SliceOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::SLICE); RETURN_IF_ERROR(reader->AddOutputs(node)); @@ -1925,7 +1931,7 @@ class SliceOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->ReadTensor(1, &starts)); RETURN_IF_ERROR(reader->ReadTensor(2, &sizes)); if (starts.data.size() != sizes.data.size()) { - return InvalidArgumentError("Starts amount != sizes amount."); + return absl::InvalidArgumentError("Starts amount != sizes amount."); } if (starts.data.size() == 4) { attr.starts = @@ -1939,30 +1945,31 @@ class SliceOperationParser : public 
TFLiteOperationParser { BHWC(input->tensor.shape.b, starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]); } else { - return UnimplementedError( + return absl::UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); } RETURN_IF_ERROR(UpdateIfNegative(input->tensor.shape, &attr)); auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; if ((attr.ends.b - attr.starts.b) != out_shape.b) { - return UnimplementedError("Output batch don't match"); + return absl::UnimplementedError("Output batch don't match"); } if ((attr.ends.h - attr.starts.h) != out_shape.h) { - return UnimplementedError("Output height doesn't match"); + return absl::UnimplementedError("Output height doesn't match"); } if ((attr.ends.w - attr.starts.w) != out_shape.w) { - return UnimplementedError("Output width doesn't match"); + return absl::UnimplementedError("Output width doesn't match"); } if ((attr.ends.c - attr.starts.c) != out_shape.c) { - return UnimplementedError("Output channels don't match"); + return absl::UnimplementedError("Output channels don't match"); } node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } private: - Status UpdateIfNegative(const BHWC& input_shape, SliceAttributes* attr) { + absl::Status UpdateIfNegative(const BHWC& input_shape, + SliceAttributes* attr) { if (attr->ends.h < 0) { attr->ends.h = input_shape.h + attr->ends.h; } @@ -1975,15 +1982,15 @@ class SliceOperationParser : public TFLiteOperationParser { if (attr->ends.b < 0) { attr->ends.b = input_shape.b + attr->ends.b; } - return OkStatus(); + return absl::OkStatus(); } }; class SoftmaxOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -1991,14 +1998,14 @@ class SoftmaxOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->beta != 1) { // TODO(eignasheva): figure out, what's wrong with softmax. - return UnimplementedError("Softmax.beta != 1 is not supported."); + return absl::UnimplementedError("Softmax.beta != 1 is not supported."); } - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::SOFTMAX); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2007,27 +2014,27 @@ class SoftmaxOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return absl::InternalError("Missing tflite params"); } if (tf_options->beta != 1) { // there is multiply by scalar operation fused in softmax. Make a layer // out of it before softmax. 
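    // (Softmax with a beta parameter computes softmax(beta * x), so a
    // beta != 1 could in principle be lowered as an explicit
    // multiply-by-scalar node followed by a plain softmax; that rewrite is
    // only sketched in the commented-out code below and is not enabled, so
    // the parser keeps rejecting beta != 1 for now.)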
- return UnimplementedError("Softmax.beta != 1 is not supported."); + return absl::UnimplementedError("Softmax.beta != 1 is not supported."); // auto mul_node = reader->NewPassthroughNode(node); // mul_node->operation.type = ToString(OperationType::MUL); } SoftmaxAttributes attr; attr.axis = Axis::CHANNELS; // always by channels node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } }; class SpaceToDepthOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -2035,17 +2042,19 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser { TfLiteSpaceToDepthParams* s2d_params = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params)); if (s2d_params->block_size == 1) { - return InvalidArgumentError("SPACE_TO_DEPTH block_size = 1 is a no-op."); + return absl::InvalidArgumentError( + "SPACE_TO_DEPTH block_size = 1 is a no-op."); } if (s2d_params->block_size < 1) { - return InvalidArgumentError("SPACE_TO_DEPTH block_size must be > 1."); + return absl::InvalidArgumentError( + "SPACE_TO_DEPTH block_size must be > 1."); } - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::SPACE_TO_DEPTH); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2055,25 +2064,25 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser { SpaceToDepthAttributes attr; attr.block_size = tf_options->block_size; node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } }; class StridedSliceOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); TfLiteStridedSliceParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::SLICE); RETURN_IF_ERROR(reader->AddOutputs(node)); @@ -2087,7 +2096,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser { bool read_without_batch = tmp.data.size() == 3; bool read_with_batch = tmp.data.size() == 4; if (!read_without_batch && !read_with_batch) { - return UnimplementedError( + return 
absl::UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); } @@ -2095,7 +2104,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser { tflite_node->builtin_data); auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; if (!tf_options) { - return InternalError("Missing tflite params"); + return absl::InternalError("Missing tflite params"); } RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); @@ -2110,36 +2119,37 @@ class StridedSliceOperationParser : public TFLiteOperationParser { } if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 || attr.strides.c == 0) { - return InvalidArgumentError("stride values must be non-zero"); + return absl::InvalidArgumentError("stride values must be non-zero"); } if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 || attr.strides.c < 0) { - return UnimplementedError("Reverse slices are not supported."); + return absl::UnimplementedError("Reverse slices are not supported."); } if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b != out_shape.b) { - return UnimplementedError("Output batch don't match"); + return absl::UnimplementedError("Output batch don't match"); } if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h != out_shape.h) { - return UnimplementedError("Output height doesn't match"); + return absl::UnimplementedError("Output height doesn't match"); } if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w != out_shape.w) { - return UnimplementedError("Output width doesn't match"); + return absl::UnimplementedError("Output width doesn't match"); } if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c != out_shape.c) { - return UnimplementedError("Output channels don't match"); + return absl::UnimplementedError("Output channels don't match"); } node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } private: - Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options, - const BHWC& input_shape, int ignore_b, int ignore_h, - int ignore_w, int ignore_c, SliceAttributes* attr) { + absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, int ignore_b, + int ignore_h, int ignore_w, int ignore_c, + SliceAttributes* attr) { if (tf_options->begin_mask & ignore_h) { attr->starts.h = 0; } @@ -2165,10 +2175,11 @@ class StridedSliceOperationParser : public TFLiteOperationParser { if (tf_options->end_mask & ignore_b) { attr->ends.b = input_shape.b; } - return OkStatus(); + return absl::OkStatus(); } - Status UpdateIfNegative(const BHWC& input_shape, SliceAttributes* attr) { + absl::Status UpdateIfNegative(const BHWC& input_shape, + SliceAttributes* attr) { if (attr->ends.h < 0) { attr->ends.h = input_shape.h + attr->ends.h; } @@ -2181,17 +2192,18 @@ class StridedSliceOperationParser : public TFLiteOperationParser { if (attr->ends.b < 0) { attr->ends.b = input_shape.b + attr->ends.b; } - return OkStatus(); + return absl::OkStatus(); } - Status ReadAttribsWithBatch(const ObjectReader* reader, - const TfLiteStridedSliceParams* tf_options, - const BHWC& input_shape, SliceAttributes* attr) { - auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> Status { + absl::Status ReadAttribsWithBatch(const ObjectReader* reader, + const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, + SliceAttributes* attr) { + auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status { Tensor t; 
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]); - return OkStatus(); + return absl::OkStatus(); }; RETURN_IF_ERROR(read_bhwc(1, &attr->starts)); @@ -2199,18 +2211,17 @@ class StridedSliceOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(read_bhwc(3, &attr->strides)); RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr)); RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr)); - return OkStatus(); + return absl::OkStatus(); } - Status ReadAttribsWithoutBatch(const ObjectReader* reader, - const TfLiteStridedSliceParams* tf_options, - const BHWC& input_shape, - SliceAttributes* attr) { - auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> Status { + absl::Status ReadAttribsWithoutBatch( + const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, SliceAttributes* attr) { + auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status { Tensor t; RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]); - return OkStatus(); + return absl::OkStatus(); }; RETURN_IF_ERROR(read_hwc(1, &attr->starts)); @@ -2221,43 +2232,43 @@ class StridedSliceOperationParser : public TFLiteOperationParser { attr->starts.b = 0; attr->ends.b = input_shape.b; attr->strides.b = 1; - return OkStatus(); + return absl::OkStatus(); } - Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) { + absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) { if (tf_options->ellipsis_mask) { - return UnimplementedError("Slice does not support ellipsis_mask."); + return absl::UnimplementedError("Slice does not support ellipsis_mask."); } if (tf_options->new_axis_mask) { - return UnimplementedError("Slice does not support new_axis_mask."); + return absl::UnimplementedError("Slice does not support new_axis_mask."); } if (tf_options->shrink_axis_mask) { - return UnimplementedError( + return absl::UnimplementedError( "Slice does not support shrink_axis_mask parameter. "); } - return OkStatus(); + return absl::OkStatus(); } }; class TransposeConvOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); TfLiteTransposeConvParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR( CheckStrides(tf_options->stride_height, tf_options->stride_width)); - return OkStatus(); + return absl::OkStatus(); } // TFLite's TRANSPOSE_CONV expects 3 input (output shape, weights, and input) // and allows configurable padding & stride. // TODO(impjdi): Translate output_shape to attr.adjacent. 
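  // (That is, in operand order: input 0 = output_shape, input 1 = weights,
  //  input 2 = the runtime input tensor.)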
- Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); Value>* input; @@ -2268,7 +2279,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return InternalError("Missing tflite options."); + return absl::InternalError("Missing tflite options."); } ConvolutionTransposedAttributes attr; attr.stride = tf_options @@ -2281,24 +2292,24 @@ class TransposeConvOperationParser : public TFLiteOperationParser { UpdatePadding(tf_options->padding, graph->FindInputs(node->id)[0]->tensor.shape, &attr); node->operation.attributes = std::move(attr); - return OkStatus(); + return absl::OkStatus(); } }; class TransposeOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::TRANSPOSE); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2314,19 +2325,20 @@ class TransposeOperationParser : public TFLiteOperationParser { } else if (perm.data.size() == 2) { attr.perm = BHWC(0, 1, perm.data[0] + 2, perm.data[1] + 2); } else { - return InvalidArgumentError("Permutation for transpose is invalid."); + return absl::InvalidArgumentError( + "Permutation for transpose is invalid."); } node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } }; class Unpooling2DOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { TfLitePoolParams* tf_options = nullptr; RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2, /*outputs=*/1)); @@ -2334,12 +2346,12 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckKernelsAndStrides( tf_options->filter_height, tf_options->filter_width, tf_options->stride_height, tf_options->stride_width)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); 
node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2350,7 +2362,7 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->custom_initial_data); if (!tf_options) { - return InternalError("Missing tflite params"); + return absl::InternalError("Missing tflite params"); } attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width); attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); @@ -2360,22 +2372,22 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { auto output_value = graph->FindOutputs(node->id)[0]; output_value->tensor.shape = CalculateOutputShape(input_shape, attr); - return OkStatus(); + return absl::OkStatus(); } }; // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported. class BatchToSpaceOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return OkStatus(); + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::BATCH_TO_SPACE); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2385,7 +2397,7 @@ class BatchToSpaceOperationParser : public TFLiteOperationParser { Tensor block; RETURN_IF_ERROR(reader->ReadTensor(1, &block)); if (block.shape.v != 2) { - return InternalError("Space has to be HxW."); + return absl::InternalError("Space has to be HxW."); } bs_attr.block.h = block.data[0]; bs_attr.block.w = block.data[1]; @@ -2394,7 +2406,7 @@ class BatchToSpaceOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->ReadTensor(2, &crop)); auto crop_shape = crop.shape; if (crop_shape.h != 2 && crop_shape.w != 2) { - return InternalError("Space has to be HxW."); + return absl::InternalError("Space has to be HxW."); } bs_attr.crop.prepended.h = crop.data[0]; @@ -2404,21 +2416,21 @@ class BatchToSpaceOperationParser : public TFLiteOperationParser { bs_attr.crop.appended.w = crop.data[3]; node->operation.attributes = std::move(bs_attr); - return OkStatus(); + return absl::OkStatus(); } }; class SpaceToBatchOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return OkStatus(); + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::SPACE_TO_BATCH); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2427,7 +2439,7 @@ class SpaceToBatchOperationParser : public 
TFLiteOperationParser { Tensor block; RETURN_IF_ERROR(reader->ReadTensor(1, &block)); if (block.shape.v != 2) { - return InternalError("Space has to be HxW."); + return absl::InternalError("Space has to be HxW."); } sb_attr.block.h = block.data[0]; sb_attr.block.w = block.data[1]; @@ -2437,7 +2449,7 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser { auto padding_shape = padding.shape; if (padding_shape.h != 2 && padding_shape.w != 2) { - return InternalError("Space has to be HxW."); + return absl::InternalError("Space has to be HxW."); } sb_attr.padding.prepended.h = padding.data[0]; @@ -2447,23 +2459,23 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser { sb_attr.padding.appended.w = padding.data[3]; node->operation.attributes = std::move(sb_attr); - return OkStatus(); + return absl::OkStatus(); } }; class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox RETURN_IF_ERROR(reader->AddOutputs(node)); @@ -2478,7 +2490,7 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { auto output_value = graph->FindOutputs(node->id)[0]; output_value->tensor.shape = output_shape; - return OkStatus(); + return absl::OkStatus(); } private: @@ -2486,17 +2498,17 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { class TransformTensorOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2, /*outputs=*/1)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); // data RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox @@ -2515,7 +2527,7 @@ class TransformTensorOperationParser : public TFLiteOperationParser { output_value->tensor.shape = BHWC(1, output_shape.h, output_shape.w, graph->FindInputs(node->id)[0]->tensor.shape.c); - return OkStatus(); + return absl::OkStatus(); } private: @@ -2523,17 +2535,17 @@ class TransformTensorOperationParser : public TFLiteOperationParser { class TransformLandmarksOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* 
tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2, /*outputs=*/1)); - return OkStatus(); + return absl::OkStatus(); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); // data RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox @@ -2549,7 +2561,7 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser { auto output_value = graph->FindOutputs(node->id)[0]; output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape; - return OkStatus(); + return absl::OkStatus(); } private: @@ -2557,16 +2569,16 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser { class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix @@ -2581,7 +2593,7 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { auto output_value = graph->FindOutputs(node->id)[0]; output_value->tensor.shape = output_shape; - return OkStatus(); + return absl::OkStatus(); } private: @@ -2589,16 +2601,16 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { class MeanOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::MEAN); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2623,27 +2635,27 @@ class MeanOperationParser : public TFLiteOperationParser { unsupported = unsupported.empty() ? 
"channels" : unsupported; ABSL_FALLTHROUGH_INTENDED; default: - return UnimplementedError( + return absl::UnimplementedError( absl::StrCat("Unsupported mean dimension: ", unsupported)); } } node->operation.attributes = attr; - return OkStatus(); + return absl::OkStatus(); } }; class UnsupportedOperationParser : public TFLiteOperationParser { public: - Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return UnimplementedError("Operation is not supported."); + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return absl::UnimplementedError("Operation is not supported."); } - Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, GraphFloat32* graph, - ObjectReader* reader) final { - return UnimplementedError("Operation is not supported."); + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + return absl::UnimplementedError("Operation is not supported."); } }; @@ -2772,15 +2784,15 @@ std::unique_ptr NewOperationParser( return absl::make_unique(); } -Status GetNodeAndRegistration(TfLiteContext* context, int node_id, - TfLiteNode** tflite_node, - TfLiteRegistration** registration) { +absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id, + TfLiteNode** tflite_node, + TfLiteRegistration** registration) { if (context->GetNodeAndRegistration(context, node_id, tflite_node, registration) != kTfLiteOk) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Couldn't get node and registration info for op: ", node_id)); } - return OkStatus(); + return absl::OkStatus(); } using IsNodeSupportedFn = tflite::delegates::IsNodeSupportedFn; @@ -2963,8 +2975,8 @@ class GraphWithDequantPartitionHelper std::set dequant_nodes_to_save_; }; -Status IsSupported(const TfLiteContext* context, TfLiteNode* node, - const TfLiteRegistration* registration) { +absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node, + const TfLiteRegistration* registration) { return NewOperationParser(registration) ->IsSupported(context, node, registration); } @@ -2983,8 +2995,8 @@ bool IsAllFloatTensors(const TfLiteContext* context, } } // namespace -Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, - TensorRef* tensor_ref) { +absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, + TensorRef* tensor_ref) { tensor_ref->type = ToDataType(tflite_tensor.type); return ExtractTensorShape(tflite_tensor, &tensor_ref->shape); } @@ -2998,7 +3010,9 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { std::string* unsupported_details) -> bool { const auto status = IsSupported(context, node, registration); if (!status.ok()) { - if (unsupported_details) *unsupported_details = status.error_message(); + if (unsupported_details) { + *unsupported_details = std::string(status.message()); + } return false; } @@ -3048,9 +3062,9 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { return ConvertVectorToTfLiteIntArray(ops_to_replace); } -Status BuildModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph) { +absl::Status BuildModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph) { std::vector> operations; std::vector tflite_nodes; for (int i 
= 0; i < delegate_params->nodes_to_replace->size; ++i) { @@ -3065,7 +3079,7 @@ Status BuildModel(TfLiteContext* context, } auto op_parser = NewOperationParser(registration); if (!op_parser) { - return UnimplementedError( + return absl::UnimplementedError( absl::StrCat("Operation ", registration->builtin_code, "(", registration->custom_name, ") is not supported by TFLite GPU Delegate.")); @@ -3085,25 +3099,25 @@ Status BuildModel(TfLiteContext* context, const auto status = operations[i]->Parse(tflite_node, registration, graph, &reader); if (!status.ok()) { - return InternalError(absl::StrCat(GetOpNameByRegistration(*registration), - ": ", status.error_message())); + return absl::InternalError(absl::StrCat( + GetOpNameByRegistration(*registration), ": ", status.message())); } } - return OkStatus(); + return absl::OkStatus(); } -Status BuildFinalModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph) { +absl::Status BuildFinalModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph) { RETURN_IF_ERROR(BuildModel(context, delegate_params, graph)); // Apply general transformations on the graph. NullTransformationReporter reporter; ModelTransformer transformer(graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return InternalError("Graph general transformations failed"); + return absl::InternalError("Graph general transformations failed"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h index f81dd90933c..b8fcab0c5c8 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -32,19 +32,19 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context); // Extracts TFLite delegate execution plan from the input TFLite context and // converts it into generic graph format. -Status BuildModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph); +absl::Status BuildModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph); // Same as above but also apply all transformations on the final graph. // Prefer using this method instead of BuildModel. -Status BuildFinalModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph); +absl::Status BuildFinalModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph); // Module-internal converter, exposed for unit testing purpose only. 
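// A minimal usage sketch for a test (identifiers here are hypothetical):
//   TfLiteTensor tflite_tensor = ...;   // e.g. a kTfLiteFloat32, 1x8x8x3 tensor
//   TensorRef<BHWC> tensor_ref;
//   ASSERT_TRUE(ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref).ok());
//   // tensor_ref.type and tensor_ref.shape now mirror the TfLite tensor.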
-Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, - TensorRef* tensor_ref); +absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, + TensorRef* tensor_ref); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc index b20b24d28c3..771ed7378b9 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.cc +++ b/tensorflow/lite/delegates/gpu/common/operations.cc @@ -519,14 +519,15 @@ BHWC CalculateOutputShape(const BHWC& input, const MeanAttributes& attr) { return BHWC(b, h, w, c); } -Status CalculateOutputShape(const std::vector& input, - const ConcatAttributes& attr, BHWC* output_shape) { +absl::Status CalculateOutputShape(const std::vector& input, + const ConcatAttributes& attr, + BHWC* output_shape) { BHWC new_shape = input[0]; switch (attr.axis) { case Axis::CHANNELS: for (int i = 1; i < input.size(); i++) { if (input[i].h != new_shape.h || input[i].w != new_shape.w) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Height and Width must be the same when concatenating " "by channels axis"); } @@ -536,7 +537,7 @@ Status CalculateOutputShape(const std::vector& input, case Axis::HEIGHT: for (int i = 1; i < input.size(); i++) { if (input[i].w != new_shape.w || input[i].c != new_shape.c) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Channels and Width must be the same when concatenating " "by height axis"); } @@ -546,7 +547,7 @@ Status CalculateOutputShape(const std::vector& input, case Axis::WIDTH: for (int i = 1; i < input.size(); i++) { if (input[i].h != new_shape.h || input[i].c != new_shape.c) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Height and Channels must be the same when concatenating " "by width axis"); } @@ -554,11 +555,11 @@ Status CalculateOutputShape(const std::vector& input, } break; default: - return InvalidArgumentError("Invalid axis"); + return absl::InvalidArgumentError("Invalid axis"); break; } *output_shape = new_shape; - return OkStatus(); + return absl::OkStatus(); } Padding2D CalculateSamePadding(const BHWC& input, diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h index 16016d334cf..4eb41dfe1a3 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.h +++ b/tensorflow/lite/delegates/gpu/common/operations.h @@ -202,8 +202,9 @@ BHWDC CalculateOutputShape(const BHWDC& input, const Pooling3DAttributes& attr); // @return shape of a tensor after Concat operation is applied to the given // input. -Status CalculateOutputShape(const std::vector& input, - const ConcatAttributes& attr, BHWC* output_shape); +absl::Status CalculateOutputShape(const std::vector& input, + const ConcatAttributes& attr, + BHWC* output_shape); // @return padding for pooling operation to make sure output keep the same shape // as the given input. diff --git a/tensorflow/lite/delegates/gpu/common/status.h b/tensorflow/lite/delegates/gpu/common/status.h index 250a3b5e3eb..d6b5dd8a94a 100644 --- a/tensorflow/lite/delegates/gpu/common/status.h +++ b/tensorflow/lite/delegates/gpu/common/status.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,109 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ -#include - -namespace tflite { -namespace gpu { - -enum class StatusCode { - kOk = 0, - kCancelled = 1, - kUnknown = 2, - kInvalidArgument = 3, - kDeadlineExceeded = 4, - kNotFound = 5, - kAlreadyExists = 6, - kPermissionDenied = 7, - kResourceExhausted = 8, - kFailedPrecondition = 9, - kAborted = 10, - kOutOfRange = 11, - kUnimplemented = 12, - kInternal = 13, - kUnavailable = 14, - kDataLoss = 15, - kUnauthenticated = 16, - kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 -}; - -// Lite version of Status without dependency on protobuf. -// TODO(b/128867901): Migrate to absl::Status. -class Status { - public: - Status() = default; - Status(StatusCode code) : code_(code) {} - Status(StatusCode code, const std::string& error_message) - : code_(code), error_message_(error_message) {} - - const std::string& error_message() const { return error_message_; } - StatusCode code() const { return code_; } - bool ok() const { return code_ == StatusCode::kOk; } - - void IgnoreError() const {} - - private: - StatusCode code_ = StatusCode::kOk; - std::string error_message_; -}; - -#define RETURN_IF_ERROR(status) \ - { \ - const auto status2 = (status); \ - if (!status2.ok()) return status2; \ - } - -inline Status OkStatus() { return Status(); } - -inline Status AlreadyExistsError(const std::string& message) { - return Status(StatusCode::kAlreadyExists, message); -} - -inline Status DeadlineExceededError(const std::string& message) { - return Status(StatusCode::kDeadlineExceeded, message); -} - -inline Status FailedPreconditionError(const std::string& message) { - return Status(StatusCode::kFailedPrecondition, message); -} - -inline Status InternalError(const std::string& message) { - return Status(StatusCode::kInternal, message); -} - -inline Status InvalidArgumentError(const std::string& message) { - return Status(StatusCode::kInvalidArgument, message); -} - -inline Status NotFoundError(const std::string& message) { - return Status(StatusCode::kNotFound, message); -} - -inline Status OutOfRangeError(const std::string& message) { - return Status(StatusCode::kOutOfRange, message); -} - -inline Status PermissionDeniedError(const std::string& message) { - return Status(StatusCode::kPermissionDenied, message); -} - -inline Status ResourceExhaustedError(const std::string& message) { - return Status(StatusCode::kResourceExhausted, message); -} - -inline Status UnavailableError(const std::string& message) { - return Status(StatusCode::kUnavailable, message); -} - -inline Status UnimplementedError(const std::string& message) { - return Status(StatusCode::kUnimplemented, message); -} - -inline Status UnknownError(const std::string& message) { - return Status(StatusCode::kUnknown, message); -} - -} // namespace gpu -} // namespace tflite +#include "absl/status/status.h" +#define RETURN_IF_ERROR(s) {auto c=(s);if(!c.ok())return c;} #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc index cbd62fa6853..08d9448f7e5 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc @@ -30,21 +30,21 @@ namespace tflite { namespace gpu { namespace testing { -Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, - TfLiteDelegate* delegate, - const 
OpResolver& op_resolver, - const std::vector& inputs, - std::vector* outputs) { +absl::Status InterpreterInvokeWithOpResolver( + const ::tflite::Model* model, TfLiteDelegate* delegate, + const OpResolver& op_resolver, const std::vector& inputs, + std::vector* outputs) { auto interpreter = absl::make_unique(); if (InterpreterBuilder(model, op_resolver)(&interpreter) != kTfLiteOk) { - return InternalError("Unable to create TfLite InterpreterBuilder"); + return absl::InternalError("Unable to create TfLite InterpreterBuilder"); } if (delegate && interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) { - return InternalError("Unable to modify TfLite graph with the delegate"); + return absl::InternalError( + "Unable to modify TfLite graph with the delegate"); } interpreter->SetNumThreads(1); if (interpreter->AllocateTensors() != kTfLiteOk) { - return InternalError("Unable to allocate TfLite tensors"); + return absl::InternalError("Unable to allocate TfLite tensors"); } for (int i = 0; i < inputs.size(); ++i) { DCHECK_EQ(interpreter->tensor(interpreter->inputs()[i])->type, @@ -57,10 +57,10 @@ Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, inputs[i].data.size() * sizeof(float)); } if (interpreter->Invoke() != kTfLiteOk) { - return InternalError("Unable to invoke TfLite interpreter"); + return absl::InternalError("Unable to invoke TfLite interpreter"); } if (!outputs || !outputs->empty()) { - return InternalError("Invalid outputs pointer"); + return absl::InternalError("Invalid outputs pointer"); } outputs->reserve(interpreter->outputs().size()); for (auto t : interpreter->outputs()) { @@ -69,7 +69,7 @@ Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, bhwc.id = t; // TODO(impjdi) Relax this condition to arbitrary batch size. if (out_tensor->dims->data[0] != 1) { - return InternalError("Batch dimension is expected to be 1"); + return absl::InternalError("Batch dimension is expected to be 1"); } bhwc.shape.b = out_tensor->dims->data[0]; switch (out_tensor->dims->size) { @@ -89,20 +89,21 @@ Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, bhwc.shape.c = out_tensor->dims->data[3]; break; default: - return InternalError("Unsupported dimensions size " + - std::to_string(out_tensor->dims->size)); + return absl::InternalError("Unsupported dimensions size " + + std::to_string(out_tensor->dims->size)); } bhwc.data = std::vector( out_tensor->data.f, out_tensor->data.f + out_tensor->bytes / sizeof(float)); outputs->push_back(bhwc); } - return OkStatus(); + return absl::OkStatus(); } -Status InterpreterInvoke(const ::tflite::Model* model, TfLiteDelegate* delegate, - const std::vector& inputs, - std::vector* outputs) { +absl::Status InterpreterInvoke(const ::tflite::Model* model, + TfLiteDelegate* delegate, + const std::vector& inputs, + std::vector* outputs) { ops::builtin::BuiltinOpResolver builtin_op_resolver; return InterpreterInvokeWithOpResolver(model, delegate, builtin_op_resolver, inputs, outputs); diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h index a38a5d1363a..ca2825b7563 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h +++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h @@ -31,18 +31,18 @@ namespace testing { // Runs Tensorflow Lite model using Tensorflow Lite with a delegate and // an appropriate operations resolver. If delegate is nullptr, inference will // be done only on CPU. 
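// A minimal calling sketch (illustrative; model and delegate setup elided):
//   std::vector<TensorFloat32> inputs = ...;   // one entry per model input
//   std::vector<TensorFloat32> outputs;
//   RETURN_IF_ERROR(InterpreterInvokeWithOpResolver(
//       model, /*delegate=*/nullptr, op_resolver, inputs, &outputs));
//   // With delegate == nullptr the model runs entirely on CPU.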
-Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, - TfLiteDelegate* delegate, - const OpResolver& op_resolver, - const std::vector& inputs, - std::vector* outputs); +absl::Status InterpreterInvokeWithOpResolver( + const ::tflite::Model* model, TfLiteDelegate* delegate, + const OpResolver& op_resolver, const std::vector& inputs, + std::vector* outputs); // Runs Tensorflow Lite model using Tensorflow Lite with a delegate and // builtin operations resolver. If delegate is nullptr, inference will // be done only on CPU. -Status InterpreterInvoke(const ::tflite::Model* model, TfLiteDelegate* delegate, - const std::vector& inputs, - std::vector* outputs); +absl::Status InterpreterInvoke(const ::tflite::Model* model, + TfLiteDelegate* delegate, + const std::vector& inputs, + std::vector* outputs); } // namespace testing } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc index 872c4bcd903..0011cc24dfa 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc @@ -61,7 +61,7 @@ class AddQuantAdjustments : public NodeTransformation { // The tensor information should rename the same. Value>* adjusted_value = graph->NewValue(); adjusted_value->tensor = output_value->tensor; - Status status = + absl::Status status = graph->SetProducer(quant_and_dequant_node->id, adjusted_value->id); if (!status.ok()) { return {TransformStatus::INVALID, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc index 586c7a34a37..4efb98a6847 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -81,11 +81,11 @@ class MergeConvolutionWithAdd : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } - Status status = RemoveFollowingNode(graph, &add_node, &conv_node); + absl::Status status = RemoveFollowingNode(graph, &add_node, &conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove add node after convolution: " + - status.error_message()}; + std::string(status.message())}; } return {TransformStatus::APPLIED, ""}; } @@ -131,11 +131,11 @@ class MergeAddWithConvolution : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } - Status status = RemovePrecedingNode(graph, &add_node, &conv_node); + absl::Status status = RemovePrecedingNode(graph, &add_node, &conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove add node after convolution: " + - status.error_message()}; + std::string(status.message())}; } return {TransformStatus::APPLIED, ""}; } diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc index 6b106a4be62..055327d3534 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc @@ -74,11 +74,11 @@ class MergeConvolutionWithMul : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } - Status status = RemoveFollowingNode(graph, &mul_node, &conv_node); + absl::Status status = RemoveFollowingNode(graph, &mul_node, 
&conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove mul node after convolution: " + - status.error_message()}; + std::string(status.message())}; } return {TransformStatus::APPLIED, ""}; } @@ -134,11 +134,11 @@ class MergeMulWithConvolution : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } - Status status = RemovePrecedingNode(graph, &mul_node, &conv_node); + absl::Status status = RemovePrecedingNode(graph, &mul_node, &conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove mul node after convolution: " + - status.error_message()}; + std::string(status.message())}; } return {TransformStatus::APPLIED, ""}; } diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc index 5e98edac943..17aac83baf7 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc @@ -76,10 +76,10 @@ class MakePaddingFromZerosConcat : public NodeTransformation { "Padding for concat axis is unsupported: " + ToString(concat_attr.axis)}; } - Status status = RemovePrecedingNode(graph, dep, node); + absl::Status status = RemovePrecedingNode(graph, dep, node); if (!status.ok()) { - return {TransformStatus::INVALID, - "Unable to remove const node: " + status.error_message()}; + return {TransformStatus::INVALID, "Unable to remove const node: " + + std::string(status.message())}; } node->operation.attributes = pad_attr; node->operation.type = ToString(OperationType::PAD); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc index 5257ba44f0e..f1c56477834 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc @@ -72,7 +72,7 @@ class MatchDilatedConvolution : public SequenceTransformation { conv_node.operation.attributes = std::move(conv2d_attr); } - Status status = RemoveFollowingNode(graph, &bs_node, &conv_node); + absl::Status status = RemoveFollowingNode(graph, &bs_node, &conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove batch_to_space node after convolution."}; diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc index 5e2f1e17f54..23e99bc3305 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc @@ -62,11 +62,11 @@ class MergePaddingWith2DOperation : public SequenceTransformation { } Attr* node_attr = absl::any_cast(&op_node->operation.attributes); - Status status = RemovePrecedingNode(graph, pad_node, op_node); + absl::Status status = RemovePrecedingNode(graph, pad_node, op_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove Pad node with Operation node: " + - status.error_message()}; + std::string(status.message())}; } node_attr->padding.appended.h += pad_attr.appended.h; @@ -154,10 +154,10 @@ class MergePaddingWithAddOperation : public NodeTransformation { "Cannot remove padding when this broadcast/scalar ADD"}; } - Status status = RemovePrecedingNode(graph, node, add_node); + absl::Status status = 
RemovePrecedingNode(graph, node, add_node); if (!status.ok()) { return {TransformStatus::INVALID, - "Unable to remove Pad node " + status.error_message()}; + "Unable to remove Pad node " + std::string(status.message())}; } return {TransformStatus::APPLIED, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc index 64779990178..e80b244b34f 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc @@ -44,10 +44,10 @@ class RemoveOperation : public SequenceTransformation { if (!remove_predicate_(graph, op_node)) { return {TransformStatus::SKIPPED, ""}; } - Status status = RemoveFollowingNode(graph, op_node, prev_op_node); + absl::Status status = RemoveFollowingNode(graph, op_node, prev_op_node); if (!status.ok()) { return {TransformStatus::INVALID, - "Unable to remove a node: " + status.error_message()}; + "Unable to remove a node: " + std::string(status.message())}; } return {TransformStatus::APPLIED, ""}; } @@ -116,10 +116,10 @@ class RemoveIdentityReshape : public NodeTransformation { return {TransformStatus::SKIPPED, "Can not apply transformation when node output is graph output"}; } - Status status = RemoveOneInputOneOutputNode(graph, node); + absl::Status status = RemoveOneInputOneOutputNode(graph, node); if (!status.ok()) { return {TransformStatus::INVALID, - "Unable to remove a node: " + status.error_message()}; + "Unable to remove a node: " + std::string(status.message())}; } return {TransformStatus::APPLIED, "Removed reshape with input_shape == output_shape."}; diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc index d6d22aa6a62..d18e3726a1c 100644 --- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc @@ -184,10 +184,9 @@ template std::vector GenerateWorkGroupSizes( WorkGroupSizeAlignment z_alignment); template -Status GenerateWorkGroupSizesAlignedToGrid(const T& grid, - const T& max_work_group_size, - const int max_work_group_invocations, - std::vector* work_groups) { +absl::Status GenerateWorkGroupSizesAlignedToGrid( + const T& grid, const T& max_work_group_size, + const int max_work_group_invocations, std::vector* work_groups) { auto alignment = WorkGroupSizeAlignment::PRECISE; *work_groups = GenerateWorkGroupSizes( grid, /*min_work_group_total_size = */ 32, max_work_group_invocations, @@ -197,16 +196,16 @@ Status GenerateWorkGroupSizesAlignedToGrid(const T& grid, AddCornerCases(grid, max_work_group_invocations, max_work_group_size, alignment, alignment, alignment, work_groups); } - return OkStatus(); + return absl::OkStatus(); } // Specializations of GenerateWorkGroupSizesAlignedToGrid for int3 and uint3 -template Status GenerateWorkGroupSizesAlignedToGrid( +template absl::Status GenerateWorkGroupSizesAlignedToGrid( const int3& grid, const int3& max_work_group_size, const int max_work_group_invocations, std::vector* work_groups); -template Status GenerateWorkGroupSizesAlignedToGrid( +template absl::Status GenerateWorkGroupSizesAlignedToGrid( const uint3& grid, const uint3& max_work_group_size, const int max_work_group_invocations, std::vector* work_groups); diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h index 80915ff5c95..75967cb04df 100644 --- 
a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h @@ -42,10 +42,9 @@ std::vector GenerateWorkGroupSizes( WorkGroupSizeAlignment y_alignment, WorkGroupSizeAlignment z_alignment); template -Status GenerateWorkGroupSizesAlignedToGrid(const T& grid, - const T& max_work_group_size, - const int max_work_group_invocations, - std::vector* work_groups); +absl::Status GenerateWorkGroupSizesAlignedToGrid( + const T& grid, const T& max_work_group_size, + const int max_work_group_invocations, std::vector* work_groups); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 452f81f536d..3451119c71d 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -70,8 +70,8 @@ class Delegate { options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default(); } - Status Prepare(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params) { + absl::Status Prepare(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { thread_id_prepare_ = std::this_thread::get_id(); // Extract TFLite delegate execution plan from the context and convert it @@ -98,9 +98,10 @@ class Delegate { std::unique_ptr builder; bool graph_is_destroyed; - Status status = InitializeOpenClApi(&graph, &builder, &graph_is_destroyed); + absl::Status status = + InitializeOpenClApi(&graph, &builder, &graph_is_destroyed); if (!status.ok()) { - context->ReportError(context, "%s", status.error_message().c_str()); + TF_LITE_KERNEL_LOG(context, std::string(status.message()).c_str()); context->ReportError(context, "Falling back to OpenGL"); // Graph need to be re-created because it is moved above. 
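Not part of the patch: a sketch of the OpenCL-to-OpenGL fallback that the hunk above logs, under the new absl::Status signatures; BuildModel is a hypothetical name for the graph-conversion step that precedes it in the real code.

// Inside Delegate::Prepare (sketch): try OpenCL first, then fall back.
absl::Status status =
    InitializeOpenClApi(&graph, &builder, &graph_is_destroyed);
if (!status.ok()) {
  TF_LITE_KERNEL_LOG(context, "%s", std::string(status.message()).c_str());
  TF_LITE_KERNEL_LOG(context, "Falling back to OpenGL");
  // The graph was moved into the OpenCL path, so it has to be rebuilt
  // (BuildModel is hypothetical shorthand for that step).
  GraphFloat32 graph2;
  if (graph_is_destroyed) {
    RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph2));
  }
  RETURN_IF_ERROR(
      InitializeOpenGlApi(graph_is_destroyed ? &graph2 : &graph, &builder));
}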
@@ -132,7 +133,7 @@ class Delegate { return builder->Build(&runner_); } - Status SetInputsAndOutputs(TfLiteContext* context) { + absl::Status SetInputsAndOutputs(TfLiteContext* context) { int i = 0; for (auto index : input_indices_) { RETURN_IF_ERROR( @@ -143,15 +144,15 @@ class Delegate { RETURN_IF_ERROR( runner_->SetOutputObject(i++, GetTensorObject(index, context))); } - return OkStatus(); + return absl::OkStatus(); } - Status Invoke(TfLiteContext* context) { + absl::Status Invoke(TfLiteContext* context) { if (thread_id_prepare_ != std::this_thread::get_id()) { TFLITE_LOG(tflite::TFLITE_LOG_WARNING, "GpuDelegate invoke thread != prepare thread"); if (enforce_same_thread_) { - return FailedPreconditionError( + return absl::FailedPreconditionError( "GpuDelegate must run on the same thread where it was " "initialized."); } @@ -178,9 +179,9 @@ class Delegate { TfLiteDelegate* tflite_delegate() { return &delegate_; } private: - Status InitializeOpenClApi(GraphFloat32* graph, - std::unique_ptr* builder, - bool* graph_is_destroyed) { + absl::Status InitializeOpenClApi(GraphFloat32* graph, + std::unique_ptr* builder, + bool* graph_is_destroyed) { *graph_is_destroyed = false; cl::InferenceEnvironmentOptions env_options; cl::InferenceEnvironmentProperties properties; @@ -207,11 +208,11 @@ class Delegate { options, std::move(*graph), builder)); TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Initialized OpenCL-based API."); - return OkStatus(); + return absl::OkStatus(); } - Status InitializeOpenGlApi(GraphFloat32* graph, - std::unique_ptr* builder) { + absl::Status InitializeOpenGlApi(GraphFloat32* graph, + std::unique_ptr* builder) { gl::InferenceEnvironmentOptions env_options; gl::InferenceEnvironmentProperties properties; RETURN_IF_ERROR( @@ -226,7 +227,7 @@ class Delegate { enforce_same_thread_ = true; TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Initialized OpenGL-based API."); - return OkStatus(); + return absl::OkStatus(); } TfLiteDelegate delegate_ = { @@ -269,7 +270,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = gpu_delegate->Prepare(context, params); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Init: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return nullptr; } return gpu_delegate; @@ -294,7 +295,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = GetDelegate(node)->Invoke(context); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Invoke: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/lite/delegates/gpu/gl/api.cc b/tensorflow/lite/delegates/gpu/gl/api.cc index f9adbf253c1..f50f3458a8f 100644 --- a/tensorflow/lite/delegates/gpu/gl/api.cc +++ b/tensorflow/lite/delegates/gpu/gl/api.cc @@ -58,20 +58,20 @@ class InferenceContextImpl : public InferenceContext { explicit InferenceContextImpl(std::unique_ptr runtime) : runtime_(std::move(runtime)) {} - Status Execute() final { + absl::Status Execute() final { std::lock_guard lock(guard_); if (state_ != InferenceContextState::NOT_STARTED) { - return FailedPreconditionError("InferenceContext is not reset"); + return absl::FailedPreconditionError("InferenceContext is not reset"); } state_ = InferenceContextState::IN_PROGRESS; return runtime_->Execute(); } - Status Reset() final { + absl::Status Reset() final { std::lock_guard lock(guard_); // TODO(akulik): 
should Reset not return Status? state_ = InferenceContextState::NOT_STARTED; - return OkStatus(); + return absl::OkStatus(); } RuntimeStats stats() const final { return runtime_->stats(); } @@ -94,10 +94,10 @@ class InferenceContextWithBatchImpl : public InferenceContext { refs_(std::move(refs)), runtime_(std::move(runtime)) {} - Status Execute() final { + absl::Status Execute() final { std::lock_guard lock(guard_); if (state_ != InferenceContextState::NOT_STARTED) { - return FailedPreconditionError("InferenceContext is not reset"); + return absl::FailedPreconditionError("InferenceContext is not reset"); } state_ = InferenceContextState::IN_PROGRESS; @@ -112,7 +112,7 @@ class InferenceContextWithBatchImpl : public InferenceContext { if (!buffer) continue; if (buffer->bytes_size() % byte_size) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Object ", id, " does not match expected byte size: ", byte_size)); } @@ -120,7 +120,7 @@ class InferenceContextWithBatchImpl : public InferenceContext { if (num_batches == 0) { num_batches = b; } else if (num_batches != b) { - return InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Object ", id, " size does not match expected batch size: ", b, " vs ", num_batches)); } @@ -135,7 +135,7 @@ class InferenceContextWithBatchImpl : public InferenceContext { if (buffer) { auto ref = refs_->FindBuffer(id); if (!ref) { - return InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("Reference to ", id, " is not found")); } RETURN_IF_ERROR(buffer->MakeView(b * byte_size, byte_size, ref)); @@ -143,14 +143,14 @@ class InferenceContextWithBatchImpl : public InferenceContext { } RETURN_IF_ERROR(runtime_->Execute()); } - return OkStatus(); + return absl::OkStatus(); } - Status Reset() final { + absl::Status Reset() final { std::lock_guard lock(guard_); state_ = InferenceContextState::NOT_STARTED; // TODO(akulik): should Reset not return Status? - return OkStatus(); + return absl::OkStatus(); } RuntimeStats stats() const final { return runtime_->stats(); } @@ -197,8 +197,8 @@ class CompiledModelImpl explicit CompiledModelImpl(const GpuInfo& gpu_info) : gpu_info_(gpu_info) {} // Called while compiling shaders from scratch - Status Add(const WorkgroupsCalculator& workgroup_calculator, - ShaderCode code) { + absl::Status Add(const WorkgroupsCalculator& workgroup_calculator, + ShaderCode code) { // Calculate workgroup size. uint3 workgroup_size = workgroup_calculator.Calculate(code); uint3 num_workgroups = IntegralDivideRoundUp(code.workload, workgroup_size); @@ -220,13 +220,13 @@ class CompiledModelImpl num_workgroups, shader_idx, }); - return OkStatus(); + return absl::OkStatus(); } // Store full shader and compile it if necessary. 
// Returns full_shader_index - Status AddFullShader(const std::string& partial_shader, - const uint3& workgroup_size, size_t* size) { + absl::Status AddFullShader(const std::string& partial_shader, + const uint3& workgroup_size, size_t* size) { std::string shader_src = GetShaderHeader(workgroup_size) + partial_shader; auto it = shader_to_index_.find(shader_src); if (it == shader_to_index_.end()) { @@ -239,10 +239,10 @@ class CompiledModelImpl } else { *size = it->second; } - return OkStatus(); + return absl::OkStatus(); } - Status NewRun( + absl::Status NewRun( const RuntimeOptions& options, const ObjectManager* objects, CommandQueue* command_queue, std::unique_ptr* inference_context) const final { @@ -273,15 +273,16 @@ class CompiledModelImpl *inference_context = absl::make_unique(std::move(runtime)); } - return OkStatus(); + return absl::OkStatus(); } #ifndef TFLITE_GPU_BINARY_RELEASE // Called on deserialization - Status OnProgram(const std::vector& parameters, - const std::vector& objects, - const uint3& workgroup_size, const uint3& num_workgroups, - size_t partial_shader_index) final { + absl::Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, + const uint3& num_workgroups, + size_t partial_shader_index) final { for (auto& object : objects) { if (IsRef(object)) { object_sizes_[GetRef(object)] = ByteSizeOf(object); @@ -298,10 +299,10 @@ class CompiledModelImpl num_workgroups, shader_idx, }); - return OkStatus(); + return absl::OkStatus(); } - Status Serialize( + absl::Status Serialize( std::vector* serialized_compiled_model) const final { SerializedCompiledModelBuilder builder; @@ -338,13 +339,13 @@ class CompiledModelImpl auto data = builder.Finalize(options); serialized_compiled_model->insert(serialized_compiled_model->end(), data.begin(), data.end()); - return OkStatus(); + return absl::OkStatus(); } - Status OnShader(absl::Span shader_src) final { + absl::Status OnShader(absl::Span shader_src) final { std::string source(shader_src.data(), shader_src.size()); partial_shaders_.push_back(source); - return OkStatus(); + return absl::OkStatus(); } void OnOptions(const CompiledModelOptions& options) final { @@ -371,45 +372,48 @@ class CompiledModelImpl }; } // namespace -Status Compile(const CompilationOptions& options, const GraphFloat32& model, - const std::unordered_set& tflite_graph_io, - const NodeShader& node_shader, - const WorkgroupsCalculator& workgroup_calculator, - std::unique_ptr* compiled_model) { +absl::Status Compile(const CompilationOptions& options, + const GraphFloat32& model, + const std::unordered_set& tflite_graph_io, + const NodeShader& node_shader, + const WorkgroupsCalculator& workgroup_calculator, + std::unique_ptr* compiled_model) { if (!IsBatchMatchesForAllValues(model)) { - return InvalidArgumentError("Only identical batch dimension is supported"); + return absl::InvalidArgumentError( + "Only identical batch dimension is supported"); } GpuInfo gpu_info; RETURN_IF_ERROR(RequestGpuInfo(&gpu_info)); if (!IsOpenGl31OrAbove(gpu_info)) { - return InternalError( + return absl::InternalError( "OpenGL ES 3.1 or above is required to use OpenGL inference."); } auto compiled_model_impl = absl::make_unique(gpu_info); compiled_model_impl->set_dynamic_batch(options.dynamic_batch); auto compiler = NewCompiler(&node_shader, &gpu_info, options); - RETURN_IF_ERROR( - compiler->Compile(model, tflite_graph_io, [&](ShaderCode code) -> Status { + RETURN_IF_ERROR(compiler->Compile( + model, tflite_graph_io, [&](ShaderCode 
code) -> absl::Status { return compiled_model_impl->Add(workgroup_calculator, std::move(code)); })); *compiled_model = std::move(compiled_model_impl); - return OkStatus(); + return absl::OkStatus(); } #ifndef TFLITE_GPU_BINARY_RELEASE -Status ReadSerializedModel(const std::vector& serialized_model, - std::unique_ptr* compiled_model) { +absl::Status ReadSerializedModel( + const std::vector& serialized_model, + std::unique_ptr* compiled_model) { GpuInfo gpu_info; RETURN_IF_ERROR(RequestGpuInfo(&gpu_info)); if (!IsOpenGl31OrAbove(gpu_info)) { - return InternalError( + return absl::InternalError( "OpenGL ES 3.1 or above is required to use OpenGL inference."); } auto compiled_model_impl = absl::make_unique(gpu_info); RETURN_IF_ERROR(DeserializeCompiledModel( absl::MakeConstSpan(serialized_model), compiled_model_impl.get())); *compiled_model = std::move(compiled_model_impl); - return OkStatus(); + return absl::OkStatus(); } #endif // TFLITE_GPU_BINARY_RELEASE diff --git a/tensorflow/lite/delegates/gpu/gl/api.h b/tensorflow/lite/delegates/gpu/gl/api.h index 78b277852d0..c37eb9b7772 100644 --- a/tensorflow/lite/delegates/gpu/gl/api.h +++ b/tensorflow/lite/delegates/gpu/gl/api.h @@ -51,7 +51,7 @@ class CompiledModel { // // NewRun call as well as subsequent calls to InferenceContext methods should // be done from the same EGL context. - virtual Status NewRun( + virtual absl::Status NewRun( const RuntimeOptions& options, const ObjectManager* objects, CommandQueue* command_queue, std::unique_ptr* inference_context) const = 0; @@ -59,23 +59,25 @@ class CompiledModel { #ifndef TFLITE_GPU_BINARY_RELEASE // Serializes compiled model to a string. // @return true if serialization finished successfully. - virtual Status Serialize( + virtual absl::Status Serialize( std::vector* serialized_compiled_model) const = 0; #endif // TFLITE_GPU_BINARY_RELEASE }; // Turns the given model into "compiled" form that is suitable for inference. -Status Compile(const CompilationOptions& options, const GraphFloat32& model, - const std::unordered_set& tflite_graph_io, - const NodeShader& node_shader, - const WorkgroupsCalculator& workgroup_calculator, - std::unique_ptr* compiled_model); +absl::Status Compile(const CompilationOptions& options, + const GraphFloat32& model, + const std::unordered_set& tflite_graph_io, + const NodeShader& node_shader, + const WorkgroupsCalculator& workgroup_calculator, + std::unique_ptr* compiled_model); #ifndef TFLITE_GPU_BINARY_RELEASE // Reads serialized representation previously created with // CompiledModel::Serialize call. -Status ReadSerializedModel(const std::vector& serialized_model, - std::unique_ptr* compiled_model); +absl::Status ReadSerializedModel( + const std::vector& serialized_model, + std::unique_ptr* compiled_model); #endif // TFLITE_GPU_BINARY_RELEASE // Encapsulates everything needed for one or more inference executions done @@ -89,13 +91,13 @@ class InferenceContext { virtual RuntimeStats stats() const = 0; // Executes inference. - virtual Status Execute() = 0; + virtual absl::Status Execute() = 0; // Asks context to reset it for another round. Keep in mind that does not // affect inputs nor outputs which are not cleared, so it is possible to // re-use them. // It is an error to call Reset while previous run is still in progress. 
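Not part of the patch: a sketch of the Execute/Reset contract described above, assuming a built InferenceContext.

absl::Status RunTwice(InferenceContext* context) {
  RETURN_IF_ERROR(context->Execute());
  // Without Reset(), a second Execute() fails with a FailedPrecondition
  // ("InferenceContext is not reset"), as the implementations above show.
  RETURN_IF_ERROR(context->Reset());
  return context->Execute();
}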
- virtual Status Reset() = 0; + virtual absl::Status Reset() = 0; }; } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/api2.cc b/tensorflow/lite/delegates/gpu/gl/api2.cc index 68bfa42411f..64e301338e1 100644 --- a/tensorflow/lite/delegates/gpu/gl/api2.cc +++ b/tensorflow/lite/delegates/gpu/gl/api2.cc @@ -50,16 +50,16 @@ std::string GetShaderHeader(uint3 localsize) { } // Wraps given SSBO into GlBuffer object that does not have ownership. -Status WrapSSBO(OpenGlBuffer ssbo, GlBuffer* buffer) { +absl::Status WrapSSBO(OpenGlBuffer ssbo, GlBuffer* buffer) { int64_t size_bytes; RETURN_IF_ERROR(GetSSBOSize(ssbo.id, &size_bytes)); *buffer = GlBuffer(GL_SHADER_STORAGE_BUFFER, ssbo.id, size_bytes, 0, false); - return OkStatus(); + return absl::OkStatus(); } -Status MaybeAllocateGlBuffer(const TensorObjectDef& def, GlBuffer* ssbo) { +absl::Status MaybeAllocateGlBuffer(const TensorObjectDef& def, GlBuffer* ssbo) { if (def.object_def.object_type != gpu::ObjectType::OPENGL_SSBO) { - return InvalidArgumentError("Tensor object is not GL SSBO"); + return absl::InvalidArgumentError("Tensor object is not GL SSBO"); } const uint32_t num_elements = NumElements(def); switch (def.object_def.data_type) { @@ -68,10 +68,10 @@ Status MaybeAllocateGlBuffer(const TensorObjectDef& def, GlBuffer* ssbo) { case DataType::FLOAT16: return CreateReadWriteShaderStorageBuffer(num_elements, ssbo); default: - return InternalError( + return absl::InternalError( "Unable to create new GL SSBO. Unsupported data type."); } - return OkStatus(); + return absl::OkStatus(); } // Does one-step conversion between internal and external objects. @@ -89,58 +89,59 @@ class DefaultTensorTie : public TensorTie { converter_builder.IsSupported(def.external_def, def.internal_def); } - static Status New(const TensorTieDef& def, - TensorObjectConverterBuilder* converter_builder, - ObjectManager* objects, std::unique_ptr* tie) { + static absl::Status New(const TensorTieDef& def, + TensorObjectConverterBuilder* converter_builder, + ObjectManager* objects, + std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def, TensorObject{}, objects); RETURN_IF_ERROR(tie_impl->Init(converter_builder)); *tie = std::move(tie_impl); - return OkStatus(); + return absl::OkStatus(); } - static Status New(const TensorTieDef& def, - TensorObjectConverterBuilder* converter_builder, - TensorObject internal_object, - std::unique_ptr* tie) { + static absl::Status New(const TensorTieDef& def, + TensorObjectConverterBuilder* converter_builder, + TensorObject internal_object, + std::unique_ptr* tie) { if (!IsValid(def.internal_def, internal_object)) { - return InternalError("Internal object does not match definition."); + return absl::InternalError("Internal object does not match definition."); } auto tie_impl = absl::make_unique(def, internal_object, nullptr); RETURN_IF_ERROR(tie_impl->Init(converter_builder)); *tie = std::move(tie_impl); - return OkStatus(); + return absl::OkStatus(); } - Status CopyToExternalObject() final { + absl::Status CopyToExternalObject() final { if (!converter_to_) { - return OkStatus(); + return absl::OkStatus(); } return converter_to_->Convert(internal_obj_, GetExternalObject()); } - Status CopyFromExternalObject() final { + absl::Status CopyFromExternalObject() final { if (!converter_from_) { - return OkStatus(); + return absl::OkStatus(); } return converter_from_->Convert(GetExternalObject(), internal_obj_); } - Status SetExternalObject(TensorObject obj) final { + absl::Status SetExternalObject(TensorObject obj) final { if 
(!def().external_def.object_def.user_provided) { - return InvalidArgumentError("External object is read-only"); + return absl::InvalidArgumentError("External object is read-only"); } if (!IsValid(def().external_def, obj)) { - return InvalidArgumentError("Given object is not valid"); + return absl::InvalidArgumentError("Given object is not valid"); } // TODO(akulik): external object should propagate to internal. if (IsSameDef()) { - return UnimplementedError("Not supported"); + return absl::UnimplementedError("Not supported"); } external_obj_ = obj; - return OkStatus(); + return absl::OkStatus(); } TensorObject GetExternalObject() final { return external_obj_; } @@ -159,7 +160,8 @@ class DefaultTensorTie : public TensorTie { internal_def.data_layout == DataLayout::DHWC4 && def().external_def.dimensions.c == 4); } - Status Init(TensorObjectConverterBuilder* converter_builder) { + + absl::Status Init(TensorObjectConverterBuilder* converter_builder) { // First check is an object is user provided. const auto& external_def = def().external_def.object_def; @@ -174,7 +176,7 @@ class DefaultTensorTie : public TensorTie { if (external_def.user_provided) { if (is_same_def) { - return OkStatus(); + return absl::OkStatus(); } // Object is provided by a user, but runtime expects different object // type. Therefore, we have to allocate internal object and convert. @@ -186,19 +188,19 @@ class DefaultTensorTie : public TensorTie { // Object is NOT provided by a user, but it matches definition expected // by runtime. Conversion is not needed. external_obj_ = internal_obj_; - return OkStatus(); + return absl::OkStatus(); } // Object is NOT provided by a user. return MaybeAllocateExternalObject(); } - return OkStatus(); + return absl::OkStatus(); } - Status MaybeAllocateInternalObject() { + absl::Status MaybeAllocateInternalObject() { const TensorObjectDef& d = def().internal_def; if (d.object_def.user_provided) { - return OkStatus(); + return absl::OkStatus(); } switch (d.object_def.object_type) { case gpu::ObjectType::OPENGL_SSBO: { @@ -210,12 +212,12 @@ class DefaultTensorTie : public TensorTie { } // TODO(akulik): support textures as internal object when compiler permits default: - return InternalError("Unexpected object type"); + return absl::InternalError("Unexpected object type"); } - return OkStatus(); + return absl::OkStatus(); } - Status MaybeAllocateExternalObject() { + absl::Status MaybeAllocateExternalObject() { const TensorObjectDef& d = def().external_def; switch (d.object_def.object_type) { case gpu::ObjectType::CPU_MEMORY: { @@ -232,9 +234,9 @@ class DefaultTensorTie : public TensorTie { break; } default: - return InternalError("Unexpected object type"); + return absl::InternalError("Unexpected object type"); } - return OkStatus(); + return absl::OkStatus(); } ObjectManager* objects_; @@ -266,26 +268,27 @@ class TwoStepTensorTie : public TensorTie { DefaultTensorTie::IsSupported(defs.second, converter_builder); } - static Status New(const TensorTieDef& def, - TensorObjectConverterBuilder* converter_builder, - ObjectManager* objects, std::unique_ptr* tie) { + static absl::Status New(const TensorTieDef& def, + TensorObjectConverterBuilder* converter_builder, + ObjectManager* objects, + std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def); RETURN_IF_ERROR(tie_impl->Init(converter_builder, objects)); *tie = std::move(tie_impl); - return OkStatus(); + return absl::OkStatus(); } - Status CopyToExternalObject() final { + absl::Status CopyToExternalObject() final { 
RETURN_IF_ERROR(inner_tie_->CopyToExternalObject()); return outer_tie_->CopyToExternalObject(); } - Status CopyFromExternalObject() final { + absl::Status CopyFromExternalObject() final { RETURN_IF_ERROR(outer_tie_->CopyFromExternalObject()); return inner_tie_->CopyFromExternalObject(); } - Status SetExternalObject(TensorObject obj) final { + absl::Status SetExternalObject(TensorObject obj) final { return outer_tie_->SetExternalObject(obj); } @@ -321,8 +324,8 @@ class TwoStepTensorTie : public TensorTie { return std::make_pair(outer_def, inner_def); } - Status Init(TensorObjectConverterBuilder* converter_builder, - ObjectManager* objects) { + absl::Status Init(TensorObjectConverterBuilder* converter_builder, + ObjectManager* objects) { auto defs = MakeOuterInnerDefs(def()); RETURN_IF_ERROR(DefaultTensorTie::New(defs.second, converter_builder, objects, &inner_tie_)); @@ -346,8 +349,8 @@ class TensorTieFactory { TwoStepTensorTie::IsSupported(def, *converter_builder_)); } - Status NewTensorTie(const TensorTieDef& def, ObjectManager* objects, - std::unique_ptr* tie) { + absl::Status NewTensorTie(const TensorTieDef& def, ObjectManager* objects, + std::unique_ptr* tie) { auto converter = converter_builder_.get(); if (DefaultTensorTie::IsSupported(def, *converter)) { return DefaultTensorTie::New(def, converter, objects, tie); @@ -355,7 +358,7 @@ class TensorTieFactory { if (TwoStepTensorTie::IsSupported(def, *converter)) { return TwoStepTensorTie::New(def, converter, objects, tie); } - return UnimplementedError("Unsupported tensor tie definition."); + return absl::UnimplementedError("Unsupported tensor tie definition."); } private: @@ -368,16 +371,16 @@ class InferenceRunnerImpl : public InferenceRunner { std::unique_ptr objects) : runtime_(std::move(runtime)), objects_(std::move(objects)) {} - Status Initialize(const std::vector& inputs, - const std::vector& outputs, - TensorTieFactory* tie_factory) { + absl::Status Initialize(const std::vector& inputs, + const std::vector& outputs, + TensorTieFactory* tie_factory) { RETURN_IF_ERROR(LinkTensors(inputs, tie_factory, &inputs_)); RETURN_IF_ERROR(LinkTensors(outputs, tie_factory, &outputs_)); for (const auto& def : outputs) { output_to_cpu_ |= def.external_def.object_def.object_type == gpu::ObjectType::CPU_MEMORY; } - return OkStatus(); + return absl::OkStatus(); } std::vector inputs() const override { @@ -388,37 +391,37 @@ class InferenceRunnerImpl : public InferenceRunner { return GetExternalDefinitions(outputs_); } - Status GetInputObject(int index, TensorObject* object) override { + absl::Status GetInputObject(int index, TensorObject* object) override { if (index < 0 || index >= inputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } *object = inputs_[index]->GetExternalObject(); - return OkStatus(); + return absl::OkStatus(); } - Status GetOutputObject(int index, TensorObject* object) override { + absl::Status GetOutputObject(int index, TensorObject* object) override { if (index < 0 || index >= outputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } *object = outputs_[index]->GetExternalObject(); - return OkStatus(); + return absl::OkStatus(); } - Status SetInputObject(int index, TensorObject object) override { + absl::Status SetInputObject(int index, TensorObject object) override { if (index < 0 || index >= inputs_.size()) { - return OutOfRangeError("Index is out of range"); + return 
absl::OutOfRangeError("Index is out of range"); } return inputs_[index]->SetExternalObject(object); } - Status SetOutputObject(int index, TensorObject object) override { + absl::Status SetOutputObject(int index, TensorObject object) override { if (index < 0 || index >= outputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } return outputs_[index]->SetExternalObject(object); } - Status Run() override { + absl::Status Run() override { for (auto& obj : inputs_) { RETURN_IF_ERROR(obj->CopyFromExternalObject()); } @@ -430,20 +433,20 @@ class InferenceRunnerImpl : public InferenceRunner { if (output_to_cpu_) { RETURN_IF_ERROR(runtime_->command_queue()->WaitForCompletion()); } - return OkStatus(); + return absl::OkStatus(); } private: - Status LinkTensors(const std::vector& defs, - TensorTieFactory* tie_factory, - std::vector>* objects) { + absl::Status LinkTensors(const std::vector& defs, + TensorTieFactory* tie_factory, + std::vector>* objects) { objects->reserve(defs.size()); for (auto& def : defs) { std::unique_ptr object; RETURN_IF_ERROR(tie_factory->NewTensorTie(def, objects_.get(), &object)); objects->push_back(std::move(object)); } - return OkStatus(); + return absl::OkStatus(); } static std::vector GetExternalDefinitions( @@ -474,10 +477,10 @@ class InferenceBuilderImpl : public InferenceBuilder { gpu_info_(gpu_info), tie_factory_(env_options_) {} - Status Initialize() { + absl::Status Initialize() { inputs_ = LinkTensors(graph_.inputs()); outputs_ = LinkTensors(graph_.outputs()); - return OkStatus(); + return absl::OkStatus(); } std::vector inputs() const final { @@ -488,40 +491,42 @@ class InferenceBuilderImpl : public InferenceBuilder { return GetExternalDefinitions(outputs_); } - Status SetInputShape(int index, const Dimensions& dimensions) final { + absl::Status SetInputShape(int index, const Dimensions& dimensions) final { if (index < 0 || index >= inputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } - return UnimplementedError("Changing input shapes is not supported"); + return absl::UnimplementedError("Changing input shapes is not supported"); } - Status SetInputObjectDef(int index, ObjectDef new_def) final { + absl::Status SetInputObjectDef(int index, ObjectDef new_def) final { if (index < 0 || index >= inputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } auto def = inputs_[index]; def.external_def.object_def = new_def; if (!tie_factory_.IsSupported(def)) { - return InvalidArgumentError("New object definition is not supported."); + return absl::InvalidArgumentError( + "New object definition is not supported."); } inputs_[index] = def; - return OkStatus(); + return absl::OkStatus(); } - Status SetOutputObjectDef(int index, ObjectDef new_def) final { + absl::Status SetOutputObjectDef(int index, ObjectDef new_def) final { if (index < 0 || index >= outputs_.size()) { - return OutOfRangeError("Index is out of range"); + return absl::OutOfRangeError("Index is out of range"); } auto def = outputs_[index]; def.external_def.object_def = new_def; if (!tie_factory_.IsSupported(def)) { - return InvalidArgumentError("New object definition is not supported."); + return absl::InvalidArgumentError( + "New object definition is not supported."); } outputs_[index] = def; - return OkStatus(); + return absl::OkStatus(); } - Status Build(std::unique_ptr* runner) final { + 
absl::Status Build(std::unique_ptr* runner) final { auto kernels = NewNodeShaderRegistry(); CompilationOptions compiler_options; compiler_options.allow_precision_loss = @@ -551,7 +556,7 @@ class InferenceBuilderImpl : public InferenceBuilder { std::move(runtime), std::move(external_objects)); RETURN_IF_ERROR(runner_impl->Initialize(inputs_, outputs_, &tie_factory_)); RETURN_IF_ERROR( - compiler->Compile(graph_, {}, [&](ShaderCode code) -> Status { + compiler->Compile(graph_, {}, [&](ShaderCode code) -> absl::Status { auto workgroup = workgroup_calculator->Calculate(code); size_t shader_index; std::string shader_src = @@ -574,7 +579,7 @@ class InferenceBuilderImpl : public InferenceBuilder { })); RETURN_IF_ERROR(runtime_ptr->PrepareForExecution()); *runner = std::move(runner_impl); - return OkStatus(); + return absl::OkStatus(); } private: @@ -624,39 +629,39 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { explicit InferenceEnvironmentImpl(const InferenceEnvironmentOptions& options) : env_options_(options) {} - Status Init() { + absl::Status Init() { RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&egl_env_)); RETURN_IF_ERROR(RequestGpuInfo(&gpu_info_)); properties_.is_opengl_available = IsOpenGl31OrAbove(gpu_info_); if (!properties_.is_opengl_available) { - return InternalError( + return absl::InternalError( "OpenGL ES 3.1 or above is required to use OpenGL inference."); } if (!env_options_.queue) { queue_ = NewCommandQueue(gpu_info_); env_options_.queue = queue_.get(); } - return OkStatus(); + return absl::OkStatus(); } - Status NewInferenceBuilder(GraphFloat32&& model, - const InferenceOptions& options, - std::unique_ptr* builder) final { + absl::Status NewInferenceBuilder( + GraphFloat32&& model, const InferenceOptions& options, + std::unique_ptr* builder) final { if (!IsValid(options)) { - return InvalidArgumentError("InferenceOptions are invalid."); + return absl::InvalidArgumentError("InferenceOptions are invalid."); } InferenceOptions resolved_options = options; ResolveAutoPriority(&resolved_options); if (!IsBatchMatchesForAllValues(model)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Only identical batch dimension is supported"); } auto builder_impl = absl::make_unique( env_options_, resolved_options, std::move(model), &gpu_info_); RETURN_IF_ERROR(builder_impl->Initialize()); *builder = std::move(builder_impl); - return OkStatus(); + return absl::OkStatus(); } const InferenceEnvironmentProperties& properties() const { @@ -673,18 +678,18 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { } // namespace -Status NewInferenceEnvironment( +absl::Status NewInferenceEnvironment( const InferenceEnvironmentOptions& options, std::unique_ptr* environment, InferenceEnvironmentProperties* properties) { auto env_impl = absl::make_unique(options); - Status status = env_impl->Init(); + absl::Status status = env_impl->Init(); if (properties) { *properties = env_impl->properties(); } RETURN_IF_ERROR(status); *environment = std::move(env_impl); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/api2.h b/tensorflow/lite/delegates/gpu/gl/api2.h index ac58fef0ffa..05062064dd6 100644 --- a/tensorflow/lite/delegates/gpu/gl/api2.h +++ b/tensorflow/lite/delegates/gpu/gl/api2.h @@ -41,7 +41,7 @@ class InferenceEnvironment { public: virtual ~InferenceEnvironment() = default; - virtual Status NewInferenceBuilder( + virtual absl::Status NewInferenceBuilder( GraphFloat32&& model, const 
InferenceOptions& options, std::unique_ptr* builder) = 0; }; @@ -52,7 +52,7 @@ struct InferenceEnvironmentOptions { // Creates a new OpenGL environment that needs to stay around until all // inference runners are destroyed. -Status NewInferenceEnvironment( +absl::Status NewInferenceEnvironment( const InferenceEnvironmentOptions& options, std::unique_ptr* environment, InferenceEnvironmentProperties* properties /* optional */); diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.cc b/tensorflow/lite/delegates/gpu/gl/command_queue.cc index 87823761127..8500a50859c 100644 --- a/tensorflow/lite/delegates/gpu/gl/command_queue.cc +++ b/tensorflow/lite/delegates/gpu/gl/command_queue.cc @@ -30,17 +30,18 @@ namespace { class DefaultCommandQueue : public CommandQueue { public: - Status Dispatch(const GlProgram& program, const uint3& workgroups) override { + absl::Status Dispatch(const GlProgram& program, + const uint3& workgroups) override { RETURN_IF_ERROR(program.Dispatch(workgroups)); return TFLITE_GPU_CALL_GL(glMemoryBarrier, GL_ALL_BARRIER_BITS); } - Status WaitForCompletion() override { + absl::Status WaitForCompletion() override { // TODO(akulik): Maybe let the user choose which wait method to use. return GlActiveSyncWait(); } - Status Flush() override { return OkStatus(); } + absl::Status Flush() override { return absl::OkStatus(); } }; // On Adreno do flush periodically as this affects performance. Command queue @@ -54,26 +55,27 @@ class AdrenoCommandQueue : public DefaultCommandQueue { explicit AdrenoCommandQueue(int flush_every_n) : flush_every_n_(flush_every_n) {} - Status Dispatch(const GlProgram& program, const uint3& workgroups) final { + absl::Status Dispatch(const GlProgram& program, + const uint3& workgroups) final { RETURN_IF_ERROR(DefaultCommandQueue::Dispatch(program, workgroups)); if ((++program_counter_ % flush_every_n_) == 0) { glFlush(); } - return OkStatus(); + return absl::OkStatus(); } - Status WaitForCompletion() override { + absl::Status WaitForCompletion() override { program_counter_ = 0; return DefaultCommandQueue::WaitForCompletion(); } - Status Flush() final { + absl::Status Flush() final { // Flush exactly once after the last dispatch. if (program_counter_ != 0) { program_counter_ = 0; glFlush(); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.h b/tensorflow/lite/delegates/gpu/gl/command_queue.h index 6695852fc86..d9bff04a837 100644 --- a/tensorflow/lite/delegates/gpu/gl/command_queue.h +++ b/tensorflow/lite/delegates/gpu/gl/command_queue.h @@ -35,14 +35,14 @@ class CommandQueue { virtual ~CommandQueue() = default; // Dispatches a program. It may or may not call glFlush. - virtual Status Dispatch(const GlProgram& program, - const uint3& workgroups) = 0; + virtual absl::Status Dispatch(const GlProgram& program, + const uint3& workgroups) = 0; // Called at the end of dispatching of all programs. - virtual Status Flush() = 0; + virtual absl::Status Flush() = 0; // Waits until all programs dispatched prior this call are completed. - virtual Status WaitForCompletion() = 0; + virtual absl::Status WaitForCompletion() = 0; }; // By default memory barrier is inserted after every dispatch. 
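Not part of the patch: a sketch of a CommandQueue implementation under the new absl::Status signatures; the class name is hypothetical, and only calls shown elsewhere in this diff are used.

class CountingCommandQueue : public CommandQueue {  // hypothetical
 public:
  absl::Status Dispatch(const GlProgram& program,
                        const uint3& workgroups) override {
    ++dispatched_;
    RETURN_IF_ERROR(program.Dispatch(workgroups));
    return TFLITE_GPU_CALL_GL(glMemoryBarrier, GL_ALL_BARRIER_BITS);
  }

  absl::Status Flush() override { return absl::OkStatus(); }

  absl::Status WaitForCompletion() override { return GlActiveSyncWait(); }

 private:
  int dispatched_ = 0;
};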
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.cc b/tensorflow/lite/delegates/gpu/gl/compiler.cc index cef8139fe1e..a5f5b35f2d2 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler.cc @@ -102,9 +102,9 @@ class CompilerImpl : public Compiler { } } - Status Compile(const GraphFloat32& graph, - const std::unordered_set& tflite_graph_io, - const ShaderCodeCallback& callback) final { + absl::Status Compile(const GraphFloat32& graph, + const std::unordered_set& tflite_graph_io, + const ShaderCodeCallback& callback) final { // It is important to have ids in a compiled graph identical to the given // graph. RETURN_IF_ERROR(graph.MakeExactCopy(&compiled_graph_)); @@ -129,22 +129,22 @@ class CompilerImpl : public Compiler { if (options_.fuse_operations) { FuseAutoOutputWithInline fuse_inline; if (!transformer.Apply("fuse_auto_with_inline", &fuse_inline)) { - return InternalError("fuse_auto_with_inline failed"); + return absl::InternalError("fuse_auto_with_inline failed"); } FuseInplaceUpdate fuse_inplace; if (!transformer.Apply("fuse_inplace_update", &fuse_inplace)) { - return InternalError("fuse_inplace failed"); + return absl::InternalError("fuse_inplace failed"); } if (options_.auto_input_fusion) { FuseAutoInput fuse_auto_input; if (!transformer.Apply("fuse_auto_input", &fuse_auto_input)) { - return InternalError("fuse_auto_input failed"); + return absl::InternalError("fuse_auto_input failed"); } } } RemoveUnusedInplaceUpdates remove_inplace_updates; if (!transformer.Apply("remove_inplace_updates", &remove_inplace_updates)) { - return InternalError("remove_inplace_updates failed"); + return absl::InternalError("remove_inplace_updates failed"); } // Prepare internal objects. @@ -176,7 +176,7 @@ class CompilerImpl : public Compiler { auto shape = outputs[0]->tensor.shape; for (auto output : outputs) { if (shape != output->tensor.shape) { - return FailedPreconditionError( + return absl::FailedPreconditionError( "Workload uint3() requires all output sizes to match"); } } @@ -274,7 +274,7 @@ class CompilerImpl : public Compiler { RETURN_IF_ERROR(codegen.Build(std::move(attr), &shader_code)); RETURN_IF_ERROR(callback(std::move(shader_code))); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.h b/tensorflow/lite/delegates/gpu/gl/compiler.h index e8b434869e2..7769890b769 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler.h @@ -31,7 +31,7 @@ namespace tflite { namespace gpu { namespace gl { -using ShaderCodeCallback = std::function; +using ShaderCodeCallback = std::function; class Compiler { public: @@ -40,9 +40,9 @@ class Compiler { // Goes over a graph and generates OpenGL shaders for the given graph. // Callback is called for every generated shader. Callback may execute shaders // as they come or store them elsewhere to execute later. 
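Not part of the patch: a sketch of the "store them elsewhere" option under the new callback type, assuming a compiler instance and the graph's I/O set are in scope.

std::vector<ShaderCode> shaders;
RETURN_IF_ERROR(compiler->Compile(
    graph, tflite_graph_io, [&](ShaderCode code) -> absl::Status {
      // Collect generated shaders now; compile and execute them later.
      shaders.push_back(std::move(code));
      return absl::OkStatus();
    }));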
- virtual Status Compile(const GraphFloat32& graph, - const std::unordered_set& tflite_graph_io, - const ShaderCodeCallback& callback) = 0; + virtual absl::Status Compile(const GraphFloat32& graph, + const std::unordered_set& tflite_graph_io, + const ShaderCodeCallback& callback) = 0; }; std::unique_ptr NewCompiler( diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc index 923b0bd47ec..4048a07d087 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc @@ -25,8 +25,8 @@ namespace tflite { namespace gpu { namespace gl { -Status MergeCode(CompiledNodeAttributes* attr, - CompiledNodeAttributes* merged_attr) { +absl::Status MergeCode(CompiledNodeAttributes* attr, + CompiledNodeAttributes* merged_attr) { // build a map of known names. std::unordered_set known_names; for (const auto& parameter : merged_attr->code.parameters) { @@ -56,7 +56,7 @@ Status MergeCode(CompiledNodeAttributes* attr, std::back_inserter(merged_attr->code.parameters)); std::move(attr->node_indices.begin(), attr->node_indices.end(), std::back_inserter(merged_attr->node_indices)); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h index d41a734f4e2..8d36504d0c3 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h @@ -42,8 +42,8 @@ struct CompiledNodeAttributes { // Moves all code objects, parameters and node indices from attr to merged_attr. // Parameters and objects in attr.code.source_code are renamed to ensure // uniqueness. 
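Not part of the patch: a sketch of merging several compiled nodes with the new MergeCode signature; the loop and vector are illustrative.

absl::Status MergeAll(std::vector<CompiledNodeAttributes>* nodes,
                      CompiledNodeAttributes* fused) {
  for (auto& attr : *nodes) {
    // Parameters and objects inside `attr` are renamed as needed so they
    // stay unique inside `fused`; `attr` should not be reused afterwards.
    RETURN_IF_ERROR(MergeCode(&attr, fused));
  }
  return absl::OkStatus();
}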
-Status MergeCode(CompiledNodeAttributes* attr, - CompiledNodeAttributes* merged_attr); +absl::Status MergeCode(CompiledNodeAttributes* attr, + CompiledNodeAttributes* merged_attr); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc index 01ea764b0b0..55e6d94eb7d 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc @@ -46,8 +46,8 @@ absl::string_view PastSubstr(absl::string_view s, absl::string_view subs) { } // namespace -Status TextPreprocessor::Rewrite(const std::string& input, - std::string* output) { +absl::Status TextPreprocessor::Rewrite(const std::string& input, + std::string* output) { absl::string_view s = input; std::string result; while (true) { @@ -57,7 +57,7 @@ Status TextPreprocessor::Rewrite(const std::string& input, break; } if (inline_block.size() == 1) { - return NotFoundError("Unable to find end of inline block"); + return absl::NotFoundError("Unable to find end of inline block"); } s = PastSubstr(s, inline_block); bool processed = false; @@ -74,20 +74,20 @@ Status TextPreprocessor::Rewrite(const std::string& input, processed = true; break; case RewriteStatus::ERROR: - return InternalError(absl::StrCat("Error while rewriting '", - inline_block, "': ", result)); + return absl::InternalError(absl::StrCat("Error while rewriting '", + inline_block, "': ", result)); } } if (!processed) { if (!keep_unknown_rewrites_) { - return NotFoundError(absl::StrCat("Didn't find inline rewrite for '", - inline_block, "'")); + return absl::NotFoundError(absl::StrCat( + "Didn't find inline rewrite for '", inline_block, "'")); } absl::StrAppend(&result, inline_block); } } *output = std::move(result); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h index f01698e784f..29fad004d3c 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h @@ -58,7 +58,7 @@ class TextPreprocessor { } // input and output may point to the same object. 
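Not part of the patch: a sketch of in-place rewriting, which the comment above explicitly allows; the delimiter and flag mirror the rename.cc call site, and the input string is made up.

TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
// (Inline rewrites would be registered on the preprocessor here.)
std::string source = "value = $some_inline$;";
// Passing the same string as input and output rewrites it in place.
RETURN_IF_ERROR(preprocessor.Rewrite(source, &source));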
- Status Rewrite(const std::string& input, std::string* output); + absl::Status Rewrite(const std::string& input, std::string* output); private: const char inline_delimiter_; diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc index 674002b74b2..956f6afae28 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc @@ -174,17 +174,17 @@ class ObjectRewriter : public InlineRewrite { } // namespace -Status Rename(const NameFunctor& name_func, GeneratedCode* code) { +absl::Status Rename(const NameFunctor& name_func, GeneratedCode* code) { VariableRewriter variable_rewriter("$", name_func); ObjectRewriter object_rewriter("$", name_func); for (auto&& uniform_parameter : code->parameters) { if (!variable_rewriter.AddVariable(std::move(uniform_parameter))) { - return InternalError("Variable name already exists"); + return absl::InternalError("Variable name already exists"); } } for (auto&& object : code->objects) { if (!object_rewriter.AddObject(object.first, std::move(object.second))) { - return InternalError("Object name already exists"); + return absl::InternalError("Object name already exists"); } } TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true); @@ -195,7 +195,7 @@ Status Rename(const NameFunctor& name_func, GeneratedCode* code) { code->source_code = source_code; code->parameters = variable_rewriter.GetUniformParameters(); code->objects = object_rewriter.GetObjects(); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.h b/tensorflow/lite/delegates/gpu/gl/compiler/rename.h index 06921dbe3da..e38ade1a3b9 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/rename.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.h @@ -32,7 +32,7 @@ using NameFunctor = std::function; // Rewrites source code, objects and parameters with the new names supplied // by the given functor. 
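Not part of the patch: a sketch of prefixing every parameter and object name in a node's GeneratedCode; the NameFunctor signature (old name in, new name out) is assumed from how Rename is used here.

absl::Status PrefixNames(GeneratedCode* code) {
  return Rename(
      [](absl::string_view name) { return absl::StrCat("node0_", name); },
      code);
}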
-Status Rename(const NameFunctor& name_func, GeneratedCode* code); +absl::Status Rename(const NameFunctor& name_func, GeneratedCode* code); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc index e6100919097..e473f9e77ff 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc @@ -32,8 +32,8 @@ ShaderCodegen::ShaderCodegen(const CompilationOptions& options, const GpuInfo& gpu_info) : options_(options), gpu_type_(gpu_info.type) {} -Status ShaderCodegen::Build(CompiledNodeAttributes attr, - ShaderCode* shader_code) const { +absl::Status ShaderCodegen::Build(CompiledNodeAttributes attr, + ShaderCode* shader_code) const { VariableAccessor variable_accessor(options_.inline_parameters, options_.vulkan_support); ObjectAccessor object_accessor(gpu_type_ == GpuType::MALI, @@ -41,18 +41,18 @@ Status ShaderCodegen::Build(CompiledNodeAttributes attr, const auto add_object = [&](const std::string& name, Object&& object) { if (!object_accessor.AddObject(name, std::forward(object))) { - return AlreadyExistsError(absl::StrCat("Object \"", name, "\"")); + return absl::AlreadyExistsError(absl::StrCat("Object \"", name, "\"")); } - return OkStatus(); + return absl::OkStatus(); }; const auto add_uniform_parameter = [&](Variable&& variable) { const std::string name = variable.name; if (!variable_accessor.AddUniformParameter(std::move(variable))) { - return AlreadyExistsError( + return absl::AlreadyExistsError( absl::StrCat("Uniform parameter \"", name, "\"")); } - return OkStatus(); + return absl::OkStatus(); }; for (auto&& object : attr.code.objects) { @@ -62,7 +62,8 @@ Status ShaderCodegen::Build(CompiledNodeAttributes attr, for (auto&& variable : attr.code.shared_variables) { const std::string name = variable.name; if (!variable_accessor.AddSharedVariable(std::move(variable))) { - return AlreadyExistsError(absl::StrCat("Shared variable \"", name, "\"")); + return absl::AlreadyExistsError( + absl::StrCat("Shared variable \"", name, "\"")); } } @@ -169,7 +170,7 @@ Status ShaderCodegen::Build(CompiledNodeAttributes attr, ShaderCode(variable_accessor.GetUniformParameters(), object_accessor.GetObjects(), attr.code.workload, attr.code.workgroup, partial_source_code, attr.node_indices); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h index c4f09a3b6b9..12d2708d221 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h @@ -39,7 +39,8 @@ class ShaderCodegen { ShaderCodegen(const CompilationOptions& options, const GpuInfo& gpu_info); // Builds final program representation. 
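Not part of the patch: a sketch mirroring the compiler.cc call site above, turning fused node attributes into a ShaderCode and handing it to the callback.

absl::Status EmitShader(const ShaderCodegen& codegen,
                        CompiledNodeAttributes attr,
                        const ShaderCodeCallback& callback) {
  ShaderCode shader_code;
  RETURN_IF_ERROR(codegen.Build(std::move(attr), &shader_code));
  return callback(std::move(shader_code));
}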
- Status Build(CompiledNodeAttributes attr, ShaderCode* shader_code) const; + absl::Status Build(CompiledNodeAttributes attr, + ShaderCode* shader_code) const; private: const CompilationOptions options_; diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc index 3b37ba26058..fc86b0f3cb1 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc @@ -31,7 +31,7 @@ namespace tflite { namespace gpu { namespace gl { -Status ConverterBhwcToPhwc4::Create(ConverterBhwcToPhwc4* converter) { +absl::Status ConverterBhwcToPhwc4::Create(ConverterBhwcToPhwc4* converter) { uint3 workgroup_size = uint3(4, 4, 4); std::string shader_source = GetShaderHeader(workgroup_size) + R"( layout(std430) buffer; @@ -69,22 +69,24 @@ Status ConverterBhwcToPhwc4::Create(ConverterBhwcToPhwc4* converter) { GlProgram program; RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); *converter = ConverterBhwcToPhwc4(std::move(program), workgroup_size); - return OkStatus(); + return absl::OkStatus(); } -Status ConverterBhwcToPhwc4::Convert(const BHWC& shape, const GlBuffer& source, - CommandQueue* command_queue, - GlBuffer* destination) { +absl::Status ConverterBhwcToPhwc4::Convert(const BHWC& shape, + const GlBuffer& source, + CommandQueue* command_queue, + GlBuffer* destination) { if (source.bytes_size() < BytesForBHWC(shape)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "BhwcToPhwc4: Input data size does not match expected size."); } if (destination->bytes_size() < BytesForPHWC4(shape)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "BhwcToPhwc4: output data size does not match expected size."); } if (shape.b != 1) { - return UnimplementedError("BhwcToPhwc4: Batch size is not equal to 1."); + return absl::UnimplementedError( + "BhwcToPhwc4: Batch size is not equal to 1."); } uint3 workload = uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)); uint3 num_workgroups = IntegralDivideRoundUp(workload, workgroup_size_); diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h index 9d9e6402ffa..9f699433a50 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h @@ -32,11 +32,11 @@ class ConverterBhwcToPhwc4 { // Creates invalid object. 
ConverterBhwcToPhwc4() : program_(), workgroup_size_() {} - static Status Create(ConverterBhwcToPhwc4* converter); + static absl::Status Create(ConverterBhwcToPhwc4* converter); - Status Convert(const BHWC& shape, const GlBuffer& source, - CommandQueue* command_queue /* optional */, - GlBuffer* destination); + absl::Status Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue /* optional */, + GlBuffer* destination); private: explicit ConverterBhwcToPhwc4(GlProgram program, const uint3& workgroup_size) diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc index 6fc424047a1..73ab9f67d94 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc @@ -41,7 +41,7 @@ inline std::vector GenerateFloats(float multiplier, int size) { return v; } -Status RunTest(const BHWC& shape) { +absl::Status RunTest(const BHWC& shape) { // Create random input and calculate expected output for it. std::vector input = GenerateFloats(0.01, shape.DimensionsProduct()); std::vector output(GetElementsSizeForPHWC4(shape), 0); @@ -71,9 +71,9 @@ Status RunTest(const BHWC& shape) { RETURN_IF_ERROR(output_buffer.Read( absl::MakeSpan(converted_output.data(), converted_output.size()))); if (output != converted_output) { - return InternalError("Outputs don't match"); + return absl::InternalError("Outputs don't match"); } - return OkStatus(); + return absl::OkStatus(); } TEST(HwcToPhwc4, Smoke) { diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc index c63fee9f8bd..5a9f51c0425 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc @@ -31,7 +31,7 @@ namespace tflite { namespace gpu { namespace gl { -Status ConverterPhwc4ToBhwc::Create(ConverterPhwc4ToBhwc* converter) { +absl::Status ConverterPhwc4ToBhwc::Create(ConverterPhwc4ToBhwc* converter) { uint3 workgroup_size = uint3(4, 4, 4); std::string shader_source = GetShaderHeader(workgroup_size) + R"( layout(std430) buffer; @@ -62,22 +62,24 @@ Status ConverterPhwc4ToBhwc::Create(ConverterPhwc4ToBhwc* converter) { GlProgram program; RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); *converter = ConverterPhwc4ToBhwc(std::move(program), workgroup_size); - return OkStatus(); + return absl::OkStatus(); } -Status ConverterPhwc4ToBhwc::Convert(const BHWC& shape, const GlBuffer& source, - CommandQueue* command_queue, - GlBuffer* destination) { +absl::Status ConverterPhwc4ToBhwc::Convert(const BHWC& shape, + const GlBuffer& source, + CommandQueue* command_queue, + GlBuffer* destination) { if (source.bytes_size() < BytesForPHWC4(shape)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Phwc4ToBhwc: Input data size does not match expected size."); } if (destination->bytes_size() < BytesForBHWC(shape)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Phwc4ToBhwc: output data size does not match expected size."); } if (shape.b != 1) { - return UnimplementedError("Phwc4ToBhwc: Batch size is not equal to 1."); + return absl::UnimplementedError( + "Phwc4ToBhwc: Batch size is not equal to 1."); } uint3 workload = uint3(shape.w, shape.h, shape.c); diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h 
b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h index c8b181223ae..d9a4dd34ee8 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h @@ -32,11 +32,11 @@ class ConverterPhwc4ToBhwc { // Creates invalid object. ConverterPhwc4ToBhwc() : program_(), workgroup_size_() {} - static Status Create(ConverterPhwc4ToBhwc* converter); + static absl::Status Create(ConverterPhwc4ToBhwc* converter); - Status Convert(const BHWC& shape, const GlBuffer& source, - CommandQueue* command_queue /* optional */, - GlBuffer* destination); + absl::Status Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue /* optional */, + GlBuffer* destination); private: explicit ConverterPhwc4ToBhwc(GlProgram program, const uint3& workgroup_size) diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc index 6f969bb7801..34346e3ce9d 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc @@ -41,7 +41,7 @@ inline std::vector GenerateFloats(float multiplier, int size) { return v; } -Status RunTest(const BHWC& shape) { +absl::Status RunTest(const BHWC& shape) { // Create random input and calculate expected output for it. std::vector input = GenerateFloats(0.01, GetElementsSizeForPHWC4(shape)); @@ -72,9 +72,9 @@ Status RunTest(const BHWC& shape) { RETURN_IF_ERROR(output_buffer.Read( absl::MakeSpan(converted_output.data(), converted_output.size()))); if (output != converted_output) { - return InternalError("Outputs don't match"); + return absl::InternalError("Outputs don't match"); } - return OkStatus(); + return absl::OkStatus(); } TEST(Phwc4ToHwc, Smoke) { diff --git a/tensorflow/lite/delegates/gpu/gl/egl_context.cc b/tensorflow/lite/delegates/gpu/gl/egl_context.cc index 46fbed24291..f01bafcacff 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_context.cc +++ b/tensorflow/lite/delegates/gpu/gl/egl_context.cc @@ -26,19 +26,19 @@ namespace gpu { namespace gl { namespace { -Status GetConfig(EGLDisplay display, const EGLint* attributes, - EGLConfig* config) { +absl::Status GetConfig(EGLDisplay display, const EGLint* attributes, + EGLConfig* config) { EGLint config_count; bool chosen = eglChooseConfig(display, attributes, config, 1, &config_count); RETURN_IF_ERROR(GetOpenGlErrors()); if (!chosen || config_count == 0) { - return InternalError("No EGL error, but eglChooseConfig failed."); + return absl::InternalError("No EGL error, but eglChooseConfig failed."); } - return OkStatus(); + return absl::OkStatus(); } -Status CreateContext(EGLDisplay display, EGLContext shared_context, - EGLConfig config, EglContext* egl_context) { +absl::Status CreateContext(EGLDisplay display, EGLContext shared_context, + EGLConfig config, EglContext* egl_context) { static const EGLint attributes[] = {EGL_CONTEXT_CLIENT_VERSION, 3, #ifdef _DEBUG // Add debugging bit EGL_CONTEXT_FLAGS_KHR, @@ -49,10 +49,10 @@ Status CreateContext(EGLDisplay display, EGLContext shared_context, eglCreateContext(display, config, shared_context, attributes); RETURN_IF_ERROR(GetOpenGlErrors()); if (context == EGL_NO_CONTEXT) { - return InternalError("No EGL error, but eglCreateContext failed."); + return absl::InternalError("No EGL error, but eglCreateContext failed."); } *egl_context = EglContext(context, display, config, true); - return OkStatus(); + return 
absl::OkStatus(); } bool HasExtension(EGLDisplay display, const char* name) { @@ -93,34 +93,36 @@ EglContext& EglContext::operator=(EglContext&& other) { return *this; } -Status EglContext::MakeCurrent(EGLSurface read, EGLSurface write) { +absl::Status EglContext::MakeCurrent(EGLSurface read, EGLSurface write) { bool is_made_current = eglMakeCurrent(display_, write, read, context_); RETURN_IF_ERROR(GetOpenGlErrors()); if (!is_made_current) { - return InternalError("No EGL error, but eglMakeCurrent failed."); + return absl::InternalError("No EGL error, but eglMakeCurrent failed."); } - return OkStatus(); + return absl::OkStatus(); } bool EglContext::IsCurrent() const { return context_ == eglGetCurrentContext(); } -Status CreateConfiglessContext(EGLDisplay display, EGLContext shared_context, - EglContext* egl_context) { +absl::Status CreateConfiglessContext(EGLDisplay display, + EGLContext shared_context, + EglContext* egl_context) { if (!HasExtension(display, "EGL_KHR_no_config_context")) { - return UnavailableError("EGL_KHR_no_config_context not supported"); + return absl::UnavailableError("EGL_KHR_no_config_context not supported"); } return CreateContext(display, shared_context, EGL_NO_CONFIG_KHR, egl_context); } -Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context, - EglContext* egl_context) { +absl::Status CreateSurfacelessContext(EGLDisplay display, + EGLContext shared_context, + EglContext* egl_context) { if (!HasExtension(display, "EGL_KHR_create_context")) { - return UnavailableError("EGL_KHR_create_context not supported"); + return absl::UnavailableError("EGL_KHR_create_context not supported"); } if (!HasExtension(display, "EGL_KHR_surfaceless_context")) { - return UnavailableError("EGL_KHR_surfaceless_context not supported"); + return absl::UnavailableError("EGL_KHR_surfaceless_context not supported"); } const EGLint attributes[] = {EGL_RENDERABLE_TYPE, EGL_OPENGL_ES3_BIT_KHR, EGL_NONE}; @@ -129,8 +131,8 @@ Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context, return CreateContext(display, shared_context, config, egl_context); } -Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, - EglContext* egl_context) { +absl::Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context) { const EGLint attributes[] = { EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_BIND_TO_TEXTURE_RGB, EGL_TRUE, EGL_RENDERABLE_TYPE, EGL_OPENGL_ES3_BIT_KHR, diff --git a/tensorflow/lite/delegates/gpu/gl/egl_context.h b/tensorflow/lite/delegates/gpu/gl/egl_context.h index 72c53d2dd2e..a93f1fdc4c4 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_context.h +++ b/tensorflow/lite/delegates/gpu/gl/egl_context.h @@ -61,9 +61,9 @@ class EglContext { // Make this EglContext the current EGL context on this thread, replacing // the existing current. - Status MakeCurrent(EGLSurface read, EGLSurface write); + absl::Status MakeCurrent(EGLSurface read, EGLSurface write); - Status MakeCurrentSurfaceless() { + absl::Status MakeCurrentSurfaceless() { return MakeCurrent(EGL_NO_SURFACE, EGL_NO_SURFACE); } @@ -86,14 +86,16 @@ class EglContext { // It uses the EGL_KHR_no_config_context extension to create a no config context // since most modern hardware supports the extension. 
-Status CreateConfiglessContext(EGLDisplay display, EGLContext shared_context, - EglContext* egl_context); +absl::Status CreateConfiglessContext(EGLDisplay display, + EGLContext shared_context, + EglContext* egl_context); -Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context, - EglContext* egl_context); +absl::Status CreateSurfacelessContext(EGLDisplay display, + EGLContext shared_context, + EglContext* egl_context); -Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, - EglContext* egl_context); +absl::Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/egl_environment.cc b/tensorflow/lite/delegates/gpu/gl/egl_environment.cc index baf6002e6c1..8ae75acd933 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_environment.cc +++ b/tensorflow/lite/delegates/gpu/gl/egl_environment.cc @@ -28,28 +28,28 @@ namespace { // TODO(akulik): detect power management event when all contexts are destroyed // and OpenGL ES is reinitialized. See eglMakeCurrent -Status InitDisplay(EGLDisplay* egl_display) { +absl::Status InitDisplay(EGLDisplay* egl_display) { RETURN_IF_ERROR( TFLITE_GPU_CALL_EGL(eglGetDisplay, egl_display, EGL_DEFAULT_DISPLAY)); if (*egl_display == EGL_NO_DISPLAY) { - return UnavailableError("eglGetDisplay returned nullptr"); + return absl::UnavailableError("eglGetDisplay returned nullptr"); } bool is_initialized; RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglInitialize, &is_initialized, *egl_display, nullptr, nullptr)); if (!is_initialized) { - return InternalError("No EGL error, but eglInitialize failed"); + return absl::InternalError("No EGL error, but eglInitialize failed"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status EglEnvironment::NewEglEnvironment( +absl::Status EglEnvironment::NewEglEnvironment( std::unique_ptr* egl_environment) { *egl_environment = absl::make_unique(); RETURN_IF_ERROR((*egl_environment)->Init()); - return OkStatus(); + return absl::OkStatus(); } EglEnvironment::~EglEnvironment() { @@ -61,12 +61,12 @@ EglEnvironment::~EglEnvironment() { } } -Status EglEnvironment::Init() { +absl::Status EglEnvironment::Init() { bool is_bound; RETURN_IF_ERROR( TFLITE_GPU_CALL_EGL(eglBindAPI, &is_bound, EGL_OPENGL_ES_API)); if (!is_bound) { - return InternalError("No EGL error, but eglBindAPI failed"); + return absl::InternalError("No EGL error, but eglBindAPI failed"); } // Re-use context and display if it was created on this thread. @@ -77,7 +77,7 @@ Status EglEnvironment::Init() { } else { RETURN_IF_ERROR(InitDisplay(&display_)); - Status status = InitConfiglessContext(); + absl::Status status = InitConfiglessContext(); if (!status.ok()) { status = InitSurfacelessContext(); } @@ -94,33 +94,30 @@ Status EglEnvironment::Init() { } // TODO(akulik): when do we need ForceSyncTurning? 
ForceSyncTurning(); - return OkStatus(); + return absl::OkStatus(); } -Status EglEnvironment::InitConfiglessContext() { +absl::Status EglEnvironment::InitConfiglessContext() { RETURN_IF_ERROR(CreateConfiglessContext(display_, EGL_NO_CONTEXT, &context_)); return context_.MakeCurrentSurfaceless(); } -Status EglEnvironment::InitSurfacelessContext() { +absl::Status EglEnvironment::InitSurfacelessContext() { RETURN_IF_ERROR( CreateSurfacelessContext(display_, EGL_NO_CONTEXT, &context_)); - Status status = context_.MakeCurrentSurfaceless(); - if (!status.ok()) { - return status; - } + RETURN_IF_ERROR(context_.MakeCurrentSurfaceless()); // PowerVR supports EGL_KHR_surfaceless_context, but glFenceSync crashes on // PowerVR when it is surface-less. RETURN_IF_ERROR(RequestGpuInfo(&gpu_info_)); if (gpu_info_.type == GpuType::POWERVR) { - return UnavailableError( + return absl::UnavailableError( "Surface-less context is not properly supported on powervr."); } - return OkStatus(); + return absl::OkStatus(); } -Status EglEnvironment::InitPBufferContext() { +absl::Status EglEnvironment::InitPBufferContext() { RETURN_IF_ERROR(CreatePBufferContext(display_, EGL_NO_CONTEXT, &context_)); RETURN_IF_ERROR(CreatePbufferRGBSurface(context_.config(), display_, 1, 1, &surface_read_)); diff --git a/tensorflow/lite/delegates/gpu/gl/egl_environment.h b/tensorflow/lite/delegates/gpu/gl/egl_environment.h index fa7ca047b6e..cb6616496dd 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_environment.h +++ b/tensorflow/lite/delegates/gpu/gl/egl_environment.h @@ -36,7 +36,7 @@ namespace gl { // EGL environment needs to be created once per thread. class EglEnvironment { public: - static Status NewEglEnvironment( + static absl::Status NewEglEnvironment( std::unique_ptr<EglEnvironment>* egl_environment); EglEnvironment() = default; @@ -47,10 +47,10 @@ class EglEnvironment { const GpuInfo& gpu_info() const { return gpu_info_; } private: - Status Init(); - Status InitConfiglessContext(); - Status InitSurfacelessContext(); - Status InitPBufferContext(); + absl::Status Init(); + absl::Status InitConfiglessContext(); + absl::Status InitSurfacelessContext(); + absl::Status InitPBufferContext(); EGLDisplay display_ = EGL_NO_DISPLAY; EglSurface surface_draw_; diff --git a/tensorflow/lite/delegates/gpu/gl/egl_surface.cc b/tensorflow/lite/delegates/gpu/gl/egl_surface.cc index eaccea6411e..d0f062af392 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_surface.cc +++ b/tensorflow/lite/delegates/gpu/gl/egl_surface.cc @@ -44,9 +44,9 @@ void EglSurface::Invalidate() { } } -Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, - uint32_t height, uint32_t width, - EglSurface* egl_surface) { +absl::Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, + uint32_t height, uint32_t width, + EglSurface* egl_surface) { const EGLint pbuffer_attributes[] = {EGL_WIDTH, static_cast<EGLint>(width), EGL_HEIGHT, @@ -60,10 +60,11 @@ Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, eglCreatePbufferSurface(display, config, pbuffer_attributes); RETURN_IF_ERROR(GetOpenGlErrors()); if (surface == EGL_NO_SURFACE) { - return InternalError("No EGL error, but eglCreatePbufferSurface failed"); + return absl::InternalError( + "No EGL error, but eglCreatePbufferSurface failed"); } *egl_surface = EglSurface(surface, display); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/egl_surface.h b/tensorflow/lite/delegates/gpu/gl/egl_surface.h index 793dc7a9dc6..5d39aed33fb 100644 ---
a/tensorflow/lite/delegates/gpu/gl/egl_surface.h +++ b/tensorflow/lite/delegates/gpu/gl/egl_surface.h @@ -56,9 +56,9 @@ class EglSurface { }; // Creates off-screen pbuffer-based surface of the given height and width. -Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, - uint32_t height, uint32_t width, - EglSurface* egl_surface); +absl::Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, + uint32_t height, uint32_t width, + EglSurface* egl_surface); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc b/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc index 509cadca60d..1de49676219 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc @@ -21,9 +21,10 @@ namespace tflite { namespace gpu { namespace gl { -Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer) { +absl::Status CopyBuffer(const GlBuffer& read_buffer, + const GlBuffer& write_buffer) { if (read_buffer.bytes_size() != write_buffer.bytes_size()) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Read buffer does not match write buffer size."); } gl_buffer_internal::BufferBinder read_buffer_binder(GL_COPY_READ_BUFFER, @@ -35,7 +36,7 @@ Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer) { write_buffer.offset(), read_buffer.bytes_size()); } -Status GetSSBOSize(GLuint id, int64_t* size_bytes) { +absl::Status GetSSBOSize(GLuint id, int64_t* size_bytes) { GLuint prev_id; RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetIntegerv, GL_SHADER_STORAGE_BUFFER_BINDING, @@ -75,19 +76,19 @@ void GlBuffer::Invalidate() { } } -Status GlBuffer::BindToIndex(uint32_t index) const { +absl::Status GlBuffer::BindToIndex(uint32_t index) const { return TFLITE_GPU_CALL_GL(glBindBufferRange, target_, index, id_, offset_, bytes_size_); } -Status GlBuffer::MakeView(size_t offset, size_t bytes_size, - GlBuffer* gl_buffer) { +absl::Status GlBuffer::MakeView(size_t offset, size_t bytes_size, + GlBuffer* gl_buffer) { if (offset + bytes_size > bytes_size_) { - return OutOfRangeError("GlBuffer view is out of range."); + return absl::OutOfRangeError("GlBuffer view is out of range."); } *gl_buffer = GlBuffer(target_, id_, bytes_size, offset_ + offset, /*has_ownership=*/false); - return OkStatus(); + return absl::OkStatus(); } GlBuffer GlBuffer::MakeRef() { @@ -121,12 +122,13 @@ GlPersistentBuffer::~GlPersistentBuffer() { glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); } -Status CreatePersistentBuffer(size_t size, GlPersistentBuffer* gl_buffer) { +absl::Status CreatePersistentBuffer(size_t size, + GlPersistentBuffer* gl_buffer) { PFNGLBUFFERSTORAGEEXTPROC glBufferStorageEXT = nullptr; glBufferStorageEXT = reinterpret_cast( eglGetProcAddress("glBufferStorageEXT")); if (!glBufferStorageEXT) { - return UnavailableError("glBufferStorageEXT is not supported"); + return absl::UnavailableError("glBufferStorageEXT is not supported"); } gl_buffer_internal::BufferId id; gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id()); @@ -140,7 +142,7 @@ Status CreatePersistentBuffer(size_t size, GlPersistentBuffer* gl_buffer) { GL_MAP_READ_BIT | GL_MAP_WRITE_BIT | GL_MAP_PERSISTENT_BIT_EXT)); *gl_buffer = GlPersistentBuffer{ GL_SHADER_STORAGE_BUFFER, id.Release(), size, 0, true, data}; - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer.h b/tensorflow/lite/delegates/gpu/gl/gl_buffer.h index a7e19abde70..3225679ec5a 
100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_buffer.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer.h @@ -60,30 +60,31 @@ class GlBuffer { // Reads data from buffer into CPU memory. Data should point to a region that // has at least bytes_size available. template - Status Read(absl::Span data) const; + absl::Status Read(absl::Span data) const; // Writes data to a buffer. template - Status Write(absl::Span data); + absl::Status Write(absl::Span data); // Maps GPU memory to CPU address space and calls reader that may read from // that memory. template - Status MappedRead( - const std::function)>& reader) const; + absl::Status MappedRead( + const std::function)>& reader) const; // Maps GPU memory to CPU address space and calls writer that may write into // that memory. template - Status MappedWrite(const std::function)>& writer); + absl::Status MappedWrite( + const std::function)>& writer); - Status MakeView(size_t offset, size_t bytes_size, GlBuffer* gl_buffer); + absl::Status MakeView(size_t offset, size_t bytes_size, GlBuffer* gl_buffer); // Makes a copy without ownership of the buffer. GlBuffer MakeRef(); // Binds a buffer to an index. - Status BindToIndex(uint32_t index) const; + absl::Status BindToIndex(uint32_t index) const; // Releases the ownership of the buffer object. void Release() { has_ownership_ = false; } @@ -112,9 +113,10 @@ class GlBuffer { bool has_ownership_; }; -Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer); +absl::Status CopyBuffer(const GlBuffer& read_buffer, + const GlBuffer& write_buffer); -Status GetSSBOSize(GLuint id, int64_t* size_bytes); +absl::Status GetSSBOSize(GLuint id, int64_t* size_bytes); // Creates new shader storage buffer that will be modified and used many // times. @@ -122,20 +124,20 @@ Status GetSSBOSize(GLuint id, int64_t* size_bytes); // See https://www.khronos.org/opengl/wiki/Shader_Storage_Buffer_Object for // details. template -Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, - GlBuffer* gl_buffer); +absl::Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, + GlBuffer* gl_buffer); // Creates new shader storage buffer that will be filled with data once which // will be used many times. template -Status CreateReadOnlyShaderStorageBuffer(absl::Span data, - GlBuffer* gl_buffer); +absl::Status CreateReadOnlyShaderStorageBuffer(absl::Span data, + GlBuffer* gl_buffer); // Adapts raw Buffer::Read method to read data into a vector. template -Status AppendFromBuffer(const GlBuffer& buffer, std::vector* data) { +absl::Status AppendFromBuffer(const GlBuffer& buffer, std::vector* data) { if (buffer.bytes_size() % sizeof(T) != 0) { - return InvalidArgumentError("Buffer is not aligned"); + return absl::InvalidArgumentError("Buffer is not aligned"); } size_t num_elements = buffer.bytes_size() / sizeof(T); data->resize(data->size() + num_elements); @@ -167,7 +169,7 @@ class GlPersistentBuffer : public GlBuffer { }; // Creates read-write persistent buffer with valid CPU pointer -Status CreatePersistentBuffer(size_t size, GlPersistentBuffer* gl_buffer); +absl::Status CreatePersistentBuffer(size_t size, GlPersistentBuffer* gl_buffer); //////////////////////////////////////////////////////////////////////////////// // Implementation details are below. 
@@ -243,8 +245,8 @@ class BufferMapper { } // namespace gl_buffer_internal template -Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, - GlBuffer* gl_buffer) { +absl::Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, + GlBuffer* gl_buffer) { gl_buffer_internal::BufferId id; gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id()); // TODO(akulik): benchmark DYNAMIC vs STREAM buffer @@ -253,12 +255,12 @@ Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, GL_STREAM_COPY)); *gl_buffer = GlBuffer{GL_SHADER_STORAGE_BUFFER, id.Release(), num_elements * sizeof(T), 0, true}; - return OkStatus(); + return absl::OkStatus(); } template -Status CreateReadOnlyShaderStorageBuffer(absl::Span data, - GlBuffer* gl_buffer) { +absl::Status CreateReadOnlyShaderStorageBuffer(absl::Span data, + GlBuffer* gl_buffer) { gl_buffer_internal::BufferId id; gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id()); RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glBufferData, GL_SHADER_STORAGE_BUFFER, @@ -266,26 +268,26 @@ Status CreateReadOnlyShaderStorageBuffer(absl::Span data, GL_STATIC_READ)); *gl_buffer = GlBuffer{GL_SHADER_STORAGE_BUFFER, id.Release(), data.size() * sizeof(T), 0, true}; - return OkStatus(); + return absl::OkStatus(); } template -Status GlBuffer::Read(absl::Span data) const { +absl::Status GlBuffer::Read(absl::Span data) const { if (data.size() * sizeof(T) < bytes_size()) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Read from buffer failed. Destination data is shorter than buffer."); } // TODO(akulik): glCopyBufferSubData is actually available in ES 3.1, try it. return MappedRead([this, data](absl::Span src) { std::memcpy(data.data(), src.data(), bytes_size()); - return OkStatus(); + return absl::OkStatus(); }); } template -Status GlBuffer::Write(absl::Span data) { +absl::Status GlBuffer::Write(absl::Span data) { if (data.size() * sizeof(T) > bytes_size_) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Write to buffer failed. 
Source data is larger than buffer."); } gl_buffer_internal::BufferBinder binder(target_, id_); @@ -294,10 +296,10 @@ Status GlBuffer::Write(absl::Span data) { } template -Status GlBuffer::MappedRead( - const std::function d)>& reader) const { +absl::Status GlBuffer::MappedRead( + const std::function d)>& reader) const { if (bytes_size_ % sizeof(T) != 0) { - return InvalidArgumentError("Buffer is not aligned"); + return absl::InvalidArgumentError("Buffer is not aligned"); } gl_buffer_internal::BufferBinder binder(target_, id_); gl_buffer_internal::BufferMapper mapper(target_, offset_, bytes_size_, @@ -310,10 +312,10 @@ Status GlBuffer::MappedRead( } template -Status GlBuffer::MappedWrite( - const std::function d)>& writer) { +absl::Status GlBuffer::MappedWrite( + const std::function d)>& writer) { if (bytes_size_ % sizeof(T) != 0) { - return InvalidArgumentError("Buffer is not aligned"); + return absl::InvalidArgumentError("Buffer is not aligned"); } gl_buffer_internal::BufferBinder binder(target_, id_); gl_buffer_internal::BufferMapper mapper(target_, offset_, bytes_size_, diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc b/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc index 1d8031fcf39..863f5ec6020 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc @@ -89,7 +89,7 @@ TEST(Buffer, SubView) { GlBuffer view1; ASSERT_TRUE(buffer.MakeView(4, 16, &view1).ok()); GlBuffer view2; - EXPECT_NE(view1.MakeView(1, 16, &view2), OkStatus()); + EXPECT_FALSE(view1.MakeView(1, 16, &view2).ok()); ASSERT_TRUE(view1.MakeView(2, 2, &view2).ok()); EXPECT_FALSE(view2.has_ownership()); diff --git a/tensorflow/lite/delegates/gpu/gl/gl_call.h b/tensorflow/lite/delegates/gpu/gl/gl_call.h index a8a81bae608..1a392d6aca3 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_call.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_call.h @@ -53,12 +53,13 @@ namespace gl_call_internal { template struct Caller { template - Status operator()(const std::string& context, F func, ErrorF error_func, - T* result, Params&&... params) { + absl::Status operator()(const std::string& context, F func, ErrorF error_func, + T* result, Params&&... params) { *result = func(std::forward(params)...); const auto status = error_func(); - if (status.ok()) return OkStatus(); - return Status(status.code(), status.error_message() + ": " + context); + if (status.ok()) return absl::OkStatus(); + return absl::Status(status.code(), + std::string(status.message()) + ": " + context); } }; @@ -66,25 +67,27 @@ struct Caller { template<> struct Caller { template - Status operator()(const std::string& context, F func, ErrorF error_func, - Params&&... params) { + absl::Status operator()(const std::string& context, F func, ErrorF error_func, + Params&&... params) { func(std::forward(params)...); const auto status = error_func(); - if (status.ok()) return OkStatus(); - return Status(status.code(), status.error_message() + ": " + context); + if (status.ok()) return absl::OkStatus(); + return absl::Status(status.code(), + std::string(status.message()) + ": " + context); } }; template -Status CallAndCheckError(const std::string& context, F func, ErrorF error_func, - ResultT* result, ParamsT&&... params) { +absl::Status CallAndCheckError(const std::string& context, F func, + ErrorF error_func, ResultT* result, + ParamsT&&... 
params) { return Caller()(context, func, error_func, result, std::forward(params)...); } template -Status CallAndCheckError(const std::string& context, F func, ErrorF error_func, - Params&&... params) { +absl::Status CallAndCheckError(const std::string& context, F func, + ErrorF error_func, Params&&... params) { return Caller()(context, func, error_func, std::forward(params)...); } diff --git a/tensorflow/lite/delegates/gpu/gl/gl_errors.cc b/tensorflow/lite/delegates/gpu/gl/gl_errors.cc index 1a40e38ea9c..3ad6be8a25e 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_errors.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_errors.cc @@ -58,83 +58,83 @@ struct ErrorFormatter { // TODO(akulik): create new error space for GL error. -Status GetOpenGlErrors() { +absl::Status GetOpenGlErrors() { auto error = glGetError(); if (error == GL_NO_ERROR) { - return OkStatus(); + return absl::OkStatus(); } auto error2 = glGetError(); if (error2 == GL_NO_ERROR) { - return InternalError(ErrorToString(error)); + return absl::InternalError(ErrorToString(error)); } std::vector errors = {error, error2}; for (error = glGetError(); error != GL_NO_ERROR; error = glGetError()) { errors.push_back(error); } - return InternalError(absl::StrJoin(errors, ",", ErrorFormatter())); + return absl::InternalError(absl::StrJoin(errors, ",", ErrorFormatter())); } -Status GetEglError() { +absl::Status GetEglError() { EGLint error = eglGetError(); switch (error) { case EGL_SUCCESS: - return OkStatus(); + return absl::OkStatus(); case EGL_NOT_INITIALIZED: - return InternalError( + return absl::InternalError( "EGL is not initialized, or could not be initialized, for the " "specified EGL display connection."); case EGL_BAD_ACCESS: - return InternalError( + return absl::InternalError( "EGL cannot access a requested resource (for example a context is " "bound in another thread)."); case EGL_BAD_ALLOC: - return InternalError( + return absl::InternalError( "EGL failed to allocate resources for the requested operation."); case EGL_BAD_ATTRIBUTE: - return InternalError( + return absl::InternalError( "An unrecognized attribute or attribute value was passed in the " "attribute list."); case EGL_BAD_CONTEXT: - return InternalError( + return absl::InternalError( "An EGLContext argument does not name a valid EGL rendering " "context."); case EGL_BAD_CONFIG: - return InternalError( + return absl::InternalError( "An EGLConfig argument does not name a valid EGL frame buffer " "configuration."); case EGL_BAD_CURRENT_SURFACE: - return InternalError( + return absl::InternalError( "The current surface of the calling thread is a window, pixel buffer " "or pixmap that is no longer valid."); case EGL_BAD_DISPLAY: - return InternalError( + return absl::InternalError( "An EGLDisplay argument does not name a valid EGL display " "connection."); case EGL_BAD_SURFACE: - return InternalError( + return absl::InternalError( "An EGLSurface argument does not name a valid surface (window, pixel " "buffer or pixmap) configured for GL rendering."); case EGL_BAD_MATCH: - return InternalError( + return absl::InternalError( "Arguments are inconsistent (for example, a valid context requires " "buffers not supplied by a valid surface)."); case EGL_BAD_PARAMETER: - return InternalError("One or more argument values are invalid."); + return absl::InternalError("One or more argument values are invalid."); case EGL_BAD_NATIVE_PIXMAP: - return InternalError( + return absl::InternalError( "A NativePixmapType argument does not refer to a valid native " "pixmap."); case 
EGL_BAD_NATIVE_WINDOW: - return InternalError( + return absl::InternalError( "A NativeWindowType argument does not refer to a valid native " "window."); case EGL_CONTEXT_LOST: - return InternalError( + return absl::InternalError( "A power management event has occurred. The application must destroy " "all contexts and reinitialize OpenGL ES state and objects to " "continue rendering."); } - return UnknownError("EGL error: " + std::to_string(error)); + return absl::UnknownError("EGL error: " + std::to_string(error)); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_errors.h b/tensorflow/lite/delegates/gpu/gl/gl_errors.h index 978e642abaa..761eddd8901 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_errors.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_errors.h @@ -23,10 +23,10 @@ namespace gpu { namespace gl { // @return recent opengl errors and packs them into Status. -Status GetOpenGlErrors(); +absl::Status GetOpenGlErrors(); // @return the error of the last called EGL function in the current thread. -Status GetEglError(); +absl::Status GetEglError(); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/gl_program.cc b/tensorflow/lite/delegates/gpu/gl/gl_program.cc index def82357a6a..d6e56ca64c4 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_program.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_program.cc @@ -29,19 +29,19 @@ namespace gpu { namespace gl { namespace { -Status CreateNewProgramId(GLuint* program_id) { +absl::Status CreateNewProgramId(GLuint* program_id) { RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glCreateProgram, program_id)); if (!*program_id) { - return UnknownError("Can't create opengl program: 0 program_id"); + return absl::UnknownError("Can't create opengl program: 0 program_id"); } - return OkStatus(); + return absl::OkStatus(); } -Status CheckProgramLinked(GLuint program_id) { +absl::Status CheckProgramLinked(GLuint program_id) { GLint linked; glGetProgramiv(program_id, GL_LINK_STATUS, &linked); if (linked == GL_TRUE) { - return OkStatus(); + return absl::OkStatus(); } GLint info_size; glGetProgramiv(program_id, GL_INFO_LOG_LENGTH, &info_size); @@ -49,26 +49,26 @@ Status CheckProgramLinked(GLuint program_id) { errors.resize(info_size + 1 /* plus \0 */); glGetProgramInfoLog(program_id, info_size + 1, nullptr, &errors[0]); // TODO(akulik): use glValidateProgram to gather more info. 
- return UnavailableError("Program is not properly linked: " + errors); + return absl::UnavailableError("Program is not properly linked: " + errors); } struct ParameterSetter { - Status operator()(int value) { + absl::Status operator()(int value) { return TFLITE_GPU_CALL_GL(glProgramUniform1i, program_id, uniform_id, value); } - Status operator()(const int2& value) { + absl::Status operator()(const int2& value) { return TFLITE_GPU_CALL_GL(glProgramUniform2i, program_id, uniform_id, value.x, value.y); } - Status operator()(const int4& value) { + absl::Status operator()(const int4& value) { return TFLITE_GPU_CALL_GL(glProgramUniform4i, program_id, uniform_id, value.x, value.y, value.z, value.w); } - Status operator()(const std::vector& value) { + absl::Status operator()(const std::vector& value) { std::vector ints(value.size() * 2, 0); for (int i = 0; i < value.size(); ++i) { ints[i * 2] = value[i].x; @@ -78,32 +78,32 @@ struct ParameterSetter { ints.size(), ints.data()); } - Status operator()(unsigned int value) { + absl::Status operator()(unsigned int value) { return TFLITE_GPU_CALL_GL(glProgramUniform1ui, program_id, uniform_id, value); } - Status operator()(const uint4& value) { + absl::Status operator()(const uint4& value) { return TFLITE_GPU_CALL_GL(glProgramUniform4ui, program_id, uniform_id, value.x, value.y, value.z, value.w); } - Status operator()(float value) { + absl::Status operator()(float value) { return TFLITE_GPU_CALL_GL(glProgramUniform1f, program_id, uniform_id, value); } - Status operator()(const float2& value) { + absl::Status operator()(const float2& value) { return TFLITE_GPU_CALL_GL(glProgramUniform2f, program_id, uniform_id, value.x, value.y); } - Status operator()(const float4& value) { + absl::Status operator()(const float4& value) { return TFLITE_GPU_CALL_GL(glProgramUniform4f, program_id, uniform_id, value.x, value.y, value.z, value.w); } - Status operator()(const std::vector& value) { + absl::Status operator()(const std::vector& value) { std::vector floats(value.size() * 4, 0); for (int i = 0; i < value.size(); ++i) { floats[i * 4] = value[i].x; @@ -121,8 +121,8 @@ struct ParameterSetter { } // namespace -Status GlProgram::CreateWithShader(const GlShader& shader, - GlProgram* gl_program) { +absl::Status GlProgram::CreateWithShader(const GlShader& shader, + GlProgram* gl_program) { GLuint program_id; RETURN_IF_ERROR(CreateNewProgramId(&program_id)); @@ -136,11 +136,11 @@ Status GlProgram::CreateWithShader(const GlShader& shader, RETURN_IF_ERROR(CheckProgramLinked(program.id())); *gl_program = std::move(program); - return OkStatus(); + return absl::OkStatus(); } -Status GlProgram::CreateWithBinaryShader(const BinaryShader& shader, - GlProgram* gl_program) { +absl::Status GlProgram::CreateWithBinaryShader(const BinaryShader& shader, + GlProgram* gl_program) { GLuint program_id; RETURN_IF_ERROR(CreateNewProgramId(&program_id)); @@ -154,15 +154,15 @@ Status GlProgram::CreateWithBinaryShader(const BinaryShader& shader, RETURN_IF_ERROR(CheckProgramLinked(program.id())); *gl_program = std::move(program); - return OkStatus(); + return absl::OkStatus(); } -Status GlProgram::GetBinary(BinaryShader* binary_shader) { +absl::Status GlProgram::GetBinary(BinaryShader* binary_shader) { GLint size = 0; RETURN_IF_ERROR( TFLITE_GPU_CALL_GL(glGetProgramiv, id_, GL_PROGRAM_BINARY_LENGTH, &size)); if (!size) { - return InternalError("Getting binary size failed."); + return absl::InternalError("Getting binary size failed."); } // TODO(akulik): call // glProgramParameteri(id_, 
GL_PROGRAM_BINARY_RETRIEVABLE_HINT, GL_TRUE) @@ -174,10 +174,10 @@ Status GlProgram::GetBinary(BinaryShader* binary_shader) { &returned_size, &format, reinterpret_cast(&binary[0]))); if (size != returned_size) { - return InternalError("Getting binary is failed."); + return absl::InternalError("Getting binary is failed."); } *binary_shader = BinaryShader(format, std::move(binary)); - return OkStatus(); + return absl::OkStatus(); } GlProgram::GlProgram(GlProgram&& program) : id_(program.id_) { @@ -201,16 +201,16 @@ GlProgram& GlProgram::operator=(GlProgram&& program) { GlProgram::~GlProgram() { Invalidate(); } -Status GlProgram::SetParameter(const Variable& param) { +absl::Status GlProgram::SetParameter(const Variable& param) { GLint uniform_location; RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetUniformLocation, &uniform_location, id_, param.name.c_str())); return absl::visit(ParameterSetter{id_, uniform_location}, param.value); } -Status GlProgram::Dispatch(const uint3& workgroups) const { +absl::Status GlProgram::Dispatch(const uint3& workgroups) const { if (workgroups.x == 0 || workgroups.y == 0 || workgroups.z == 0) { - return InvalidArgumentError("Invalid workgroups"); + return absl::InvalidArgumentError("Invalid workgroups"); } RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glUseProgram, id_)); return TFLITE_GPU_CALL_GL(glDispatchCompute, workgroups.x, workgroups.y, diff --git a/tensorflow/lite/delegates/gpu/gl/gl_program.h b/tensorflow/lite/delegates/gpu/gl/gl_program.h index dfd6bde4c59..892cb8e0850 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_program.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_program.h @@ -40,12 +40,13 @@ class GlProgram { // a program. Thus, if this call returns a program, one may set parameters and // finally execute a program. // therefore it needs to be handled elsewhere. - static Status CreateWithShader(const GlShader& shader, GlProgram* gl_program); + static absl::Status CreateWithShader(const GlShader& shader, + GlProgram* gl_program); // Same as CreateWithShader but takes compiled shader in a binary form, // therefore compilation step is avoided. - static Status CreateWithBinaryShader(const BinaryShader& shader, - GlProgram* gl_program); + static absl::Status CreateWithBinaryShader(const BinaryShader& shader, + GlProgram* gl_program); // move-only GlProgram(GlProgram&& program); @@ -59,12 +60,12 @@ class GlProgram { // Returns a binary representation for a shader currently attached and linked // into this program. 
- Status GetBinary(BinaryShader* binary_shader); + absl::Status GetBinary(BinaryShader* binary_shader); - Status SetParameter(const Variable& param); + absl::Status SetParameter(const Variable& param); // Executes program - Status Dispatch(const uint3& workgroups) const; + absl::Status Dispatch(const uint3& workgroups) const; bool is_valid() const { return id_ != 0; } diff --git a/tensorflow/lite/delegates/gpu/gl/gl_shader.cc b/tensorflow/lite/delegates/gpu/gl/gl_shader.cc index 32391749985..e3823a24d93 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_shader.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_shader.cc @@ -42,9 +42,9 @@ GlShader& GlShader::operator=(GlShader&& shader) { GlShader::~GlShader() { Invalidate(); } -Status GlShader::CompileShader(GLenum shader_type, - const std::string& shader_source, - GlShader* gl_shader) { +absl::Status GlShader::CompileShader(GLenum shader_type, + const std::string& shader_source, + GlShader* gl_shader) { // NOTE: code compilation can fail due to gl errors happened before GLuint shader_id; RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glCreateShader, &shader_id, shader_type)); @@ -64,12 +64,12 @@ Status GlShader::CompileShader(GLenum shader_type, glGetShaderiv(shader.id(), GL_INFO_LOG_LENGTH, &info_log_len); std::string errors(info_log_len, 0); glGetShaderInfoLog(shader.id(), info_log_len, nullptr, &errors[0]); - return InternalError("Shader compilation failed: " + errors + - "\nProblem shader is:\n" + shader_source); + return absl::InternalError("Shader compilation failed: " + errors + + "\nProblem shader is:\n" + shader_source); } *gl_shader = std::move(shader); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_shader.h b/tensorflow/lite/delegates/gpu/gl/gl_shader.h index d0ec421bb16..45adc59207b 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_shader.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_shader.h @@ -33,9 +33,9 @@ class GlShader { // // @param shader_type is one of GL_VERTEX_SHADER, GL_FRAGMENT_SHADER, or // GL_COMPUTE_SHADER. - static Status CompileShader(GLenum shader_type, - const std::string& shader_source, - GlShader* gl_shader); + static absl::Status CompileShader(GLenum shader_type, + const std::string& shader_source, + GlShader* gl_shader); GlShader() : id_(0) {} diff --git a/tensorflow/lite/delegates/gpu/gl/gl_sync.cc b/tensorflow/lite/delegates/gpu/gl/gl_sync.cc index 92caaa5c78a..89d3a88d16f 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_sync.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_sync.cc @@ -25,7 +25,7 @@ namespace tflite { namespace gpu { namespace gl { -Status GlSyncWait() { +absl::Status GlSyncWait() { GlSync sync; RETURN_IF_ERROR(GlSync::NewSync(&sync)); // Flush sync and loop afterwards without it. @@ -37,16 +37,16 @@ Status GlSyncWait() { break; case GL_CONDITION_SATISFIED: case GL_ALREADY_SIGNALED: - return OkStatus(); + return absl::OkStatus(); case GL_WAIT_FAILED: return GetOpenGlErrors(); } status = glClientWaitSync(sync.sync(), 0, /* timeout ns = */ 10000000); } - return OkStatus(); + return absl::OkStatus(); } -Status GlActiveSyncWait() { +absl::Status GlActiveSyncWait() { GlSync sync; RETURN_IF_ERROR(GlSync::NewSync(&sync)); // Since creating a Sync object is itself a GL command it *must* be flushed. 
@@ -59,7 +59,7 @@ Status GlActiveSyncWait() { break; case GL_CONDITION_SATISFIED: case GL_ALREADY_SIGNALED: - return OkStatus(); + return absl::OkStatus(); case GL_WAIT_FAILED: return GetOpenGlErrors(); } @@ -69,7 +69,7 @@ Status GlActiveSyncWait() { while (true) { glGetSynciv(sync.sync(), GL_SYNC_STATUS, sizeof(GLint), nullptr, &result); if (result == GL_SIGNALED) { - return OkStatus(); + return absl::OkStatus(); } #ifdef __ARM_ACLE // Try to save CPU power by yielding CPU to another thread. @@ -78,7 +78,7 @@ } } -Status GlShaderSync::NewSync(GlShaderSync* gl_sync) { +absl::Status GlShaderSync::NewSync(GlShaderSync* gl_sync) { GlShaderSync sync; RETURN_IF_ERROR(CreatePersistentBuffer(sizeof(int), &sync.flag_buffer_)); static const std::string* kCode = new std::string(R"(#version 310 es @@ -94,16 +94,16 @@ Status GlShaderSync::NewSync(GlShaderSync* gl_sync) { RETURN_IF_ERROR(GlShader::CompileShader(GL_COMPUTE_SHADER, *kCode, &shader)); RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &sync.flag_program_)); *gl_sync = std::move(sync); - return OkStatus(); + return absl::OkStatus(); } // How it works: GPU writes a buffer and CPU checks the buffer value to be // changed. The buffer is accessible for writing by GPU and reading by CPU // simultaneously - persistent buffer or buffer across child context can be used // for that. -Status GlShaderSync::Wait() { +absl::Status GlShaderSync::Wait() { if (!flag_buffer_.is_valid()) { - return UnavailableError("GlShaderSync is not initialized."); + return absl::UnavailableError("GlShaderSync is not initialized."); } RETURN_IF_ERROR(flag_buffer_.BindToIndex(0)); volatile int* flag_ptr_ = reinterpret_cast<volatile int*>(flag_buffer_.data()); @@ -115,7 +115,7 @@ Status GlShaderSync::Wait() { // Wait for the value to be updated by the shader. while (*flag_ptr_ != 1) { } - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_sync.h b/tensorflow/lite/delegates/gpu/gl/gl_sync.h index dadb4b1192f..8b5d910910d 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_sync.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_sync.h @@ -32,12 +32,12 @@ namespace gl { // GlSync is moveable but not copyable. class GlSync { public: - static Status NewSync(GlSync* gl_sync) { + static absl::Status NewSync(GlSync* gl_sync) { GLsync sync; RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glFenceSync, &sync, GL_SYNC_GPU_COMMANDS_COMPLETE, 0)); *gl_sync = GlSync(sync); - return OkStatus(); + return absl::OkStatus(); } // Creates invalid object. @@ -75,12 +75,12 @@ class GlSync { }; // Waits until GPU is done with processing. -Status GlSyncWait(); +absl::Status GlSyncWait(); // Waits until all commands are flushed and then performs active waiting by // spinning a thread and checking sync status. It leads to shorter wait time // (up to tens of ms) but consumes more CPU. -Status GlActiveSyncWait(); +absl::Status GlActiveSyncWait(); // CPU checks the value in the buffer that is going to be written by GPU. The // persistent buffer is used for the simultaneous access to the buffer by GPU @@ -88,9 +88,9 @@ Status GlActiveSyncWait(); // is not supported by the device. 
class GlShaderSync { public: - static Status NewSync(GlShaderSync* gl_sync); + static absl::Status NewSync(GlShaderSync* gl_sync); GlShaderSync() {} - Status Wait(); + absl::Status Wait(); private: GlProgram flag_program_; diff --git a/tensorflow/lite/delegates/gpu/gl/gl_texture.cc b/tensorflow/lite/delegates/gpu/gl/gl_texture.cc index eb20deca758..0267a52e44f 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_texture.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_texture.cc @@ -120,31 +120,31 @@ void GlTexture::Invalidate() { } } -Status GlTexture::BindImage(uint32_t index, GLenum access) const { +absl::Status GlTexture::BindImage(uint32_t index, GLenum access) const { return TFLITE_GPU_CALL_GL(glBindImageTexture, index, id_, /* level = */ 0, /* layered = */ GL_TRUE, layer_, access, format_); } -Status GlTexture::BindAsReadonlyImage(uint32_t index) const { +absl::Status GlTexture::BindAsReadonlyImage(uint32_t index) const { return BindImage(index, GL_READ_ONLY); } -Status GlTexture::BindAsWriteonlyImage(uint32_t index) const { +absl::Status GlTexture::BindAsWriteonlyImage(uint32_t index) const { return BindImage(index, GL_WRITE_ONLY); } -Status GlTexture::BindAsReadWriteImage(uint32_t index) const { +absl::Status GlTexture::BindAsReadWriteImage(uint32_t index) const { return BindImage(index, GL_READ_WRITE); } -Status GlTexture::BindAsSampler2D(uint32_t index) const { +absl::Status GlTexture::BindAsSampler2D(uint32_t index) const { RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glActiveTexture, GL_TEXTURE0 + index)); return TFLITE_GPU_CALL_GL(glBindTexture, GL_TEXTURE_2D, id_); } namespace { -Status SetTextureWrapAndFilter(GLenum target, GLenum texture_format) { +absl::Status SetTextureWrapAndFilter(GLenum target, GLenum texture_format) { if (texture_format == GL_RGBA32F) { RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, GL_TEXTURE_WRAP_S, GL_REPEAT)); @@ -177,14 +177,16 @@ Status SetTextureWrapAndFilter(GLenum target, GLenum texture_format) { RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, GL_TEXTURE_MIN_FILTER, GL_LINEAR)); } - return OkStatus(); + return absl::OkStatus(); } -Status CreateReadOnlyRgba2dImageTexture(DataType data_type, const uint2& size, - const void* data, size_t byte_size, - GlTexture* gl_texture) { +absl::Status CreateReadOnlyRgba2dImageTexture(DataType data_type, + const uint2& size, + const void* data, + size_t byte_size, + GlTexture* gl_texture) { if (byte_size != /* RGBA=*/4 * SizeOf(data_type) * size.x * size.y) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Creating image texture failed. Source data size is not matching " "expected dimensions."); } @@ -202,14 +204,16 @@ Status CreateReadOnlyRgba2dImageTexture(DataType data_type, const uint2& size, 0, 0, size.x, size.y, format, type, data)); *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, 0, /*owned=*/true); - return OkStatus(); + return absl::OkStatus(); } -Status CreateReadOnlyRgba3dImageTexture(DataType data_type, const uint3& size, - const void* data, size_t byte_size, - GlTexture* gl_texture) { +absl::Status CreateReadOnlyRgba3dImageTexture(DataType data_type, + const uint3& size, + const void* data, + size_t byte_size, + GlTexture* gl_texture) { if (byte_size != /* RGBA=*/4 * SizeOf(data_type) * size.x * size.y * size.z) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Creating image texture failed. 
Source data is larger than dimensions " "product."); } @@ -228,53 +232,54 @@ Status CreateReadOnlyRgba3dImageTexture(DataType data_type, const uint3& size, type, data)); *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, 0, /*owned=*/true); - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status CreateReadOnlyImageTexture(const uint2& size, - absl::Span data, - GlTexture* gl_texture) { +absl::Status CreateReadOnlyImageTexture(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba2dImageTexture(DataType::FLOAT32, size, data.data(), data.size() * sizeof(float), gl_texture); } -Status CreateReadOnlyImageTexture(const uint3& size, - absl::Span data, - GlTexture* gl_texture) { +absl::Status CreateReadOnlyImageTexture(const uint3& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba3dImageTexture(DataType::FLOAT32, size, data.data(), data.size() * sizeof(float), gl_texture); } -Status CreateReadOnlyImageTextureU8(const uint2& size, - absl::Span data, - GlTexture* gl_texture) { +absl::Status CreateReadOnlyImageTextureU8(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba2dImageTexture(DataType::UINT8, size, data.data(), data.size() * sizeof(uint8_t), gl_texture); } -Status CreateReadOnlyImageTextureF16(const uint2& size, - absl::Span data, - GlTexture* gl_texture) { +absl::Status CreateReadOnlyImageTextureF16(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba2dImageTexture(DataType::FLOAT16, size, data.data(), data.size() * sizeof(uint16_t), gl_texture); } -Status CreateReadOnlyImageTextureF16(const uint3& size, - absl::Span data, - GlTexture* gl_texture) { +absl::Status CreateReadOnlyImageTextureF16(const uint3& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba3dImageTexture(DataType::FLOAT16, size, data.data(), data.size() * sizeof(uint16_t), gl_texture); } -Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint2& size, - GlTexture* gl_texture) { +absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, + const uint2& size, + GlTexture* gl_texture) { const GLenum kTarget = GL_TEXTURE_2D; const GLenum internal_format = ToTextureInternalFormat(data_type); gl_texture_internal::TextureId id; @@ -287,11 +292,12 @@ Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint2& size, *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, /* layer = */ 0, /* owned = */ true); - return OkStatus(); + return absl::OkStatus(); } -Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint3& size, - GlTexture* gl_texture) { +absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, + const uint3& size, + GlTexture* gl_texture) { const GLenum kTarget = GL_TEXTURE_2D_ARRAY; GLenum internal_format = ToTextureInternalFormat(data_type); gl_texture_internal::TextureId id; @@ -305,7 +311,7 @@ Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint3& size, *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, /* layer = */ 0, /* owned = */ true); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_texture.h b/tensorflow/lite/delegates/gpu/gl/gl_texture.h index 951b22f23f1..60e22b47229 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_texture.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_texture.h @@ -57,16 +57,16 @@ class GlTexture { ~GlTexture(); 
// Binds a texture as an image to the given index. - Status BindAsReadonlyImage(uint32_t index) const; + absl::Status BindAsReadonlyImage(uint32_t index) const; // Bind texture as an image for write access at given index. - Status BindAsWriteonlyImage(uint32_t index) const; + absl::Status BindAsWriteonlyImage(uint32_t index) const; // Bind texture as an image for read-write access at given index. - Status BindAsReadWriteImage(uint32_t index) const; + absl::Status BindAsReadWriteImage(uint32_t index) const; // Binds a texture as a sampler to the given index. - Status BindAsSampler2D(uint32_t index) const; + absl::Status BindAsSampler2D(uint32_t index) const; GLenum target() const { return target_; } @@ -87,7 +87,7 @@ class GlTexture { private: void Invalidate(); - Status BindImage(uint32_t index, GLenum access) const; + absl::Status BindImage(uint32_t index, GLenum access) const; GLuint id_; GLenum target_; @@ -101,53 +101,55 @@ class GlTexture { // will be used for reading. // // @param size defines 2D image texture size where each pixel is RGBA. -Status CreateReadOnlyImageTexture(const uint2& size, - absl::Span<const float> data, - GlTexture* gl_texture); +absl::Status CreateReadOnlyImageTexture(const uint2& size, + absl::Span<const float> data, + GlTexture* gl_texture); // Creates new 2D image texture that will be filled with float16 data once which // will be used for reading. // // @param size defines 2D image texture size where each pixel is RGBA. -Status CreateReadOnlyImageTextureF16(const uint2& size, - absl::Span<const uint16_t> data, - GlTexture* gl_texture); +absl::Status CreateReadOnlyImageTextureF16(const uint2& size, + absl::Span<const uint16_t> data, + GlTexture* gl_texture); // Creates new 2D image texture that will be filled with uint8 data once which // will be used for reading. // // @param size defines 2D image texture size where each pixel is RGBA. -Status CreateReadOnlyImageTextureU8(const uint2& size, - absl::Span<const uint8_t> data, - GlTexture* gl_texture); +absl::Status CreateReadOnlyImageTextureU8(const uint2& size, + absl::Span<const uint8_t> data, + GlTexture* gl_texture); // Creates new 3D RGBA image texture that will be filled with float32 data once // which will be used for reading. // // @param size defines 3D image texture size where each pixel is RGBA. -Status CreateReadOnlyImageTexture(const uint3& size, - absl::Span<const float> data, - GlTexture* gl_texture); +absl::Status CreateReadOnlyImageTexture(const uint3& size, + absl::Span<const float> data, + GlTexture* gl_texture); // Creates new 3D RGBA image texture that will be filled with float16 data once // which will be used for reading. // // @param size defines 3D image texture size where each pixel is RGBA. -Status CreateReadOnlyImageTextureF16(const uint3& size, - absl::Span<const uint16_t> data, - GlTexture* gl_texture); +absl::Status CreateReadOnlyImageTextureF16(const uint3& size, + absl::Span<const uint16_t> data, + GlTexture* gl_texture); // Creates new RGBA 2D image texture // // @param size defines 2D image texture size where each pixel is RGBA. -Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint2& size, - GlTexture* gl_texture); +absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, + const uint2& size, + GlTexture* gl_texture); // Creates new RGBA 3D image texture // // @param size defines 3D image texture size where each pixel is RGBA.
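As a hedged illustration of the read-only creation helpers declared above (not taken from this patch): the texture size, the pixel vector, and an enclosing function that returns absl::Status are assumptions.

  // Build a small RGBA float texture from host data; RETURN_IF_ERROR forwards
  // the absl::Status on failure.
  std::vector<float> pixels(8 * 8 * 4, 0.0f);  // uint2(8, 8), four floats per pixel
  GlTexture texture;
  RETURN_IF_ERROR(CreateReadOnlyImageTexture(uint2(8, 8),
                                             absl::MakeConstSpan(pixels),
                                             &texture));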
-Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint3& size, - GlTexture* gl_texture); +absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, + const uint3& size, + GlTexture* gl_texture); GLenum ToTextureFormat(DataType type); diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/add.cc b/tensorflow/lite/delegates/gpu/gl/kernels/add.cc index 12124a8cc57..135253112ba 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/add.cc @@ -34,8 +34,8 @@ namespace { class Add : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast(ctx.node->operation.attributes); auto adds = absl::get_if>(&attr.param); auto scalar = absl::get_if(&attr.param); @@ -60,13 +60,13 @@ class Add : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } std::string code = "value_0 = value_0"; for (int index = 1; index < inputs.size(); ++index) { if (inputs[index]->tensor.shape != inputs[0]->tensor.shape) { - return InvalidArgumentError("Shapes are not equal"); + return absl::InvalidArgumentError("Shapes are not equal"); } absl::StrAppend(&code, " + value_", index); } @@ -81,7 +81,7 @@ class Add : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } if (scalar) { @@ -111,7 +111,7 @@ class Add : public NodeShader { }; } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc b/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc index a97d618e0b6..43afab2922e 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc @@ -67,10 +67,10 @@ class AlignedConcatByChannels : public NodeShader { return true; } - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { if (!IsSupported(ctx)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "This case is not supported by aligned concat"); } auto inputs = ctx.graph->FindInputs(ctx.node->id); @@ -94,7 +94,7 @@ class AlignedConcatByChannels : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; @@ -127,10 +127,10 @@ class ConcatByAnyChannel : public NodeShader { return true; } - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { if (!IsSupported(ctx)) { - return UnimplementedError("This case is not supported by concat"); + return absl::UnimplementedError("This case is not supported by concat"); } auto inputs = ctx.graph->FindInputs(ctx.node->id); @@ -182,7 +182,7 @@ class ConcatByAnyChannel : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::ONLY_DEFINITIONS, }; - return OkStatus(); + return absl::OkStatus(); } private: @@ -348,8 +348,8 @@ class FlatConcatByHeight : public NodeShader { return true; } - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) 
const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto inputs = ctx.graph->FindInputs(ctx.node->id); std::string code; std::vector params; @@ -382,7 +382,7 @@ class FlatConcatByHeight : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; @@ -415,8 +415,8 @@ class FlatConcatByWidth : public NodeShader { return true; } - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto inputs = ctx.graph->FindInputs(ctx.node->id); std::string code; std::vector params; @@ -449,21 +449,22 @@ class FlatConcatByWidth : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; class FlatConcat : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { if (FlatConcatByHeight::IsSupported(ctx)) { return flat_concat_by_height_.GenerateCode(ctx, generated_code); } if (FlatConcatByWidth::IsSupported(ctx)) { return flat_concat_by_width_.GenerateCode(ctx, generated_code); } - return InvalidArgumentError("This case is not supported by flat concat"); + return absl::InvalidArgumentError( + "This case is not supported by flat concat"); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc index 0b18a4c4246..5c88402c1d1 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc @@ -37,8 +37,8 @@ namespace { class Convolution : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto attr = absl::any_cast( ctx.node->operation.attributes); @@ -139,7 +139,7 @@ class Convolution : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; @@ -160,24 +160,24 @@ int SelectMultiplier(int32_t input_width, class Convolution1x1 : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = absl::any_cast( ctx.node->operation.attributes); if (attr.weights.shape.h != 1 || attr.weights.shape.w != 1) { - return UnimplementedError("Height and width should be 1."); + return absl::UnimplementedError("Height and width should be 1."); } if (attr.dilations.h != 1 || attr.dilations.w != 1) { - return UnimplementedError("Dilations are not supported."); + return absl::UnimplementedError("Dilations are not supported."); } if (attr.strides.h != 1 || attr.strides.w != 1) { - return UnimplementedError("Strides are not supported."); + return absl::UnimplementedError("Strides are not supported."); } if (attr.padding.appended.h != 0 || 
attr.padding.appended.w != 0 || attr.padding.prepended.h != 0 || attr.padding.prepended.w != 0) { - return UnimplementedError("Padding is not supported."); + return absl::UnimplementedError("Padding is not supported."); } int multiplier = SelectMultiplier(input->tensor.shape.w, ctx); @@ -280,7 +280,7 @@ class Convolution1x1 : public NodeShader { /*output=*/multiplier == 1 ? IOStructure::AUTO : IOStructure::ONLY_DEFINITIONS, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc index 189beedf815..bc4c61075a3 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc @@ -31,11 +31,11 @@ namespace gl { namespace { // Wraps given SSBO into GlBuffer object that does not have ownership. -Status WrapSSBO(OpenGlBuffer ssbo, GlBuffer* buffer) { +absl::Status WrapSSBO(OpenGlBuffer ssbo, GlBuffer* buffer) { int64_t size_bytes; RETURN_IF_ERROR(GetSSBOSize(ssbo.id, &size_bytes)); *buffer = GlBuffer(GL_SHADER_STORAGE_BUFFER, ssbo.id, size_bytes, 0, false); - return OkStatus(); + return absl::OkStatus(); } std::string GetShaderHeader(const uint3& localsize) { @@ -49,12 +49,12 @@ class OpenGlConverterImpl : public TensorObjectConverter { explicit OpenGlConverterImpl(CommandQueue* command_queue) : command_queue_(command_queue) {} - virtual Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def) = 0; + virtual absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def) = 0; protected: - Status InitializeProgram(const uint3& workgroup_size, - const std::string& shader_source) { + absl::Status InitializeProgram(const uint3& workgroup_size, + const std::string& shader_source) { workgroup_size_ = workgroup_size; GlShader shader; RETURN_IF_ERROR(GlShader::CompileShader( @@ -63,7 +63,7 @@ class OpenGlConverterImpl : public TensorObjectConverter { return GlProgram::CreateWithShader(shader, &program_); } - Status Dispatch(const uint3& workload) { + absl::Status Dispatch(const uint3& workload) { uint3 num_workgroups = IntegralDivideRoundUp(workload, workgroup_size_); if (command_queue_) { return command_queue_->Dispatch(program_, num_workgroups); @@ -103,12 +103,12 @@ class FromTensorConverter : public OpenGlConverterImpl { input.data_layout == DataLayout::DHWC4; } - Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def) final { + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def) final { shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h, output_def.dimensions.w, output_def.dimensions.c); if (shape_.b != 1) { - return UnimplementedError( + return absl::UnimplementedError( "FromTensorConverter: Batch size != 1 is not supported."); } @@ -135,18 +135,18 @@ class FromTensorConverter : public OpenGlConverterImpl { })"); } - Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto output = absl::get_if(&output_obj); if (!output || !output->id) { - return InvalidArgumentError("Missing output in converter"); + return absl::InvalidArgumentError("Missing output in converter"); } auto input = absl::get_if(&input_obj); if (!input || !input->id) { - return InvalidArgumentError("Missing input in converter"); + return absl::InvalidArgumentError("Missing input in 
converter"); } if (input->id == output->id) { - return InvalidArgumentError("Can not execute inplace conversion"); + return absl::InvalidArgumentError("Can not execute inplace conversion"); } GlBuffer input_ssbo; RETURN_IF_ERROR(WrapSSBO(*input, &input_ssbo)); @@ -154,11 +154,11 @@ class FromTensorConverter : public OpenGlConverterImpl { RETURN_IF_ERROR(WrapSSBO(*output, &output_ssbo)); if (input_ssbo.bytes_size() != SizeInBytesDHWC4(shape_)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "FromTensorConverter: input data size does not match expected size."); } if (output_ssbo.bytes_size() != SizeInBytesBHWC(shape_)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "FromTensorConverter: output data size does not match expected " "size."); } @@ -191,12 +191,12 @@ class ToTensorConverter : public OpenGlConverterImpl { output.data_layout == DataLayout::DHWC4; } - Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def) final { + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def) final { shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h, output_def.dimensions.w, output_def.dimensions.c); if (shape_.b != 1) { - return UnimplementedError( + return absl::UnimplementedError( "FromTensorConverter: Batch size != 1 is not supported."); } @@ -230,18 +230,18 @@ class ToTensorConverter : public OpenGlConverterImpl { })"); } - Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto output = absl::get_if(&output_obj); if (!output || !output->id) { - return InvalidArgumentError("Missing output in converter"); + return absl::InvalidArgumentError("Missing output in converter"); } auto input = absl::get_if(&input_obj); if (!input || !input->id) { - return InvalidArgumentError("Missing input in converter"); + return absl::InvalidArgumentError("Missing input in converter"); } if (input->id == output->id) { - return InvalidArgumentError("Can not execute inplace conversion"); + return absl::InvalidArgumentError("Can not execute inplace conversion"); } GlBuffer input_ssbo; RETURN_IF_ERROR(WrapSSBO(*input, &input_ssbo)); @@ -249,11 +249,11 @@ class ToTensorConverter : public OpenGlConverterImpl { RETURN_IF_ERROR(WrapSSBO(*output, &output_ssbo)); if (input_ssbo.bytes_size() != SizeInBytesBHWC(shape_)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "ToTensorConverter: input data size does not match expected size."); } if (output_ssbo.bytes_size() != SizeInBytesDHWC4(shape_)) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "ToTensorConverter: output data size does not match expected size."); } auto d = IntegralDivideRoundUp(shape_.c, 4); @@ -279,19 +279,19 @@ class TrivialCopier : public TensorObjectConverter { input.data_layout == output.data_layout; } - Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto ssbo_input = absl::get_if(&input_obj); auto ssbo_output = absl::get_if(&output_obj); if (ssbo_input && ssbo_output) { return Copy(*ssbo_input, *ssbo_output); } - return InternalError("Unexpected object"); + return absl::InternalError("Unexpected object"); } - Status Copy(OpenGlBuffer input, OpenGlBuffer output) { + absl::Status Copy(OpenGlBuffer input, OpenGlBuffer output) { if 
(input.id == output.id) { - return OkStatus(); + return absl::OkStatus(); } GlBuffer input_obj; RETURN_IF_ERROR(WrapSSBO(input, &input_obj)); @@ -313,8 +313,8 @@ class CpuCopier : public TensorObjectConverter { input.object_type == ObjectType::OPENGL_SSBO)); } - Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto cpu_input = absl::get_if(&input_obj); auto cpu_output = absl::get_if(&output_obj); if (cpu_input) { @@ -335,7 +335,7 @@ class CpuCopier : public TensorObjectConverter { static_cast(cpu_output->data), cpu_output->size_bytes)); } } - return InternalError("Unexpected object"); + return absl::InternalError("Unexpected object"); } }; @@ -355,7 +355,7 @@ class TensorConverterBuilderImpl : public TensorObjectConverterBuilder { ToTensorConverter::IsSupported(input_def, output_def)); } - Status MakeConverter( + absl::Status MakeConverter( const TensorObjectDef& input, const TensorObjectDef& output, std::unique_ptr* converter) final { std::unique_ptr impl; @@ -363,20 +363,22 @@ class TensorConverterBuilderImpl : public TensorObjectConverterBuilder { const auto& output_def = output.object_def; if (TrivialCopier::IsSupported(input_def, output_def)) { *converter = absl::make_unique(); - return OkStatus(); - } else if (CpuCopier::IsSupported(input_def, output_def)) { + return absl::OkStatus(); + } + if (CpuCopier::IsSupported(input_def, output_def)) { *converter = absl::make_unique(); - return OkStatus(); - } else if (FromTensorConverter::IsSupported(input_def, output_def)) { + return absl::OkStatus(); + } + if (FromTensorConverter::IsSupported(input_def, output_def)) { impl = absl::make_unique(command_queue_); } else if (ToTensorConverter::IsSupported(input_def, output_def)) { impl = absl::make_unique(command_queue_); } else { - return UnimplementedError("Unsupported conversion"); + return absl::UnimplementedError("Unsupported conversion"); } RETURN_IF_ERROR(impl->Init(input, output)); *converter = std::move(impl); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc index daba2f6d9ef..5f14f093c55 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc @@ -45,7 +45,7 @@ Dimensions ToDimensions(const BHWC& shape) { return Dimensions(shape.b, shape.h, shape.w, shape.c); } -Status RunFromTensorTest(const BHWC& shape) { +absl::Status RunFromTensorTest(const BHWC& shape) { // Create random input and calculate expected output for it. std::vector input = GenerateFloats(0.01, GetElementsSizeForPHWC4(shape)); @@ -85,9 +85,9 @@ Status RunFromTensorTest(const BHWC& shape) { RETURN_IF_ERROR(output_buffer.Read( absl::MakeSpan(converted_output.data(), converted_output.size()))); if (output != converted_output) { - return InternalError("Outputs don't match"); + return absl::InternalError("Outputs don't match"); } - return OkStatus(); + return absl::OkStatus(); } TEST(FromTensor, Smoke) { @@ -103,7 +103,7 @@ TEST(FromTensor, Smoke) { } } -Status RunToTensorTest(const BHWC& shape) { +absl::Status RunToTensorTest(const BHWC& shape) { // Create random input and calculate expected output for it. 
std::vector input = GenerateFloats(0.01, shape.DimensionsProduct()); std::vector output(GetElementsSizeForPHWC4(shape), 0); @@ -142,9 +142,9 @@ Status RunToTensorTest(const BHWC& shape) { RETURN_IF_ERROR(output_buffer.Read( absl::MakeSpan(converted_output.data(), converted_output.size()))); if (output != converted_output) { - return InternalError("Outputs don't match"); + return absl::InternalError("Outputs don't match"); } - return OkStatus(); + return absl::OkStatus(); } TEST(ToTensor, Smoke) { diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc index a8d71a943b7..38ddbf361b4 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc @@ -36,8 +36,8 @@ namespace { class DepthwiseConvolution : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto attr = absl::any_cast( ctx.node->operation.attributes); @@ -146,7 +146,7 @@ class DepthwiseConvolution : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc index 35b233cbdcc..aa254770535 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc @@ -31,8 +31,8 @@ class ElementwiseOneArgument : public NodeShader { public: explicit ElementwiseOneArgument(OperationType operation_type) : operation_type_(operation_type) {} - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::string source; switch (operation_type_) { case OperationType::ABS: @@ -89,7 +89,8 @@ class ElementwiseOneArgument : public NodeShader { source = "value_0 = tanh(value_0);"; break; default: - return InvalidArgumentError("Incorrect elementwise operation type."); + return absl::InvalidArgumentError( + "Incorrect elementwise operation type."); } *generated_code = { /*parameters=*/{}, @@ -101,7 +102,7 @@ class ElementwiseOneArgument : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } private: @@ -144,8 +145,8 @@ class ElementwiseTwoArguments : public NodeShader { return true; } - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::vector parameters; std::vector> objects; std::string argument0, argument1; @@ -159,7 +160,7 @@ class ElementwiseTwoArguments : public NodeShader { const ElementwiseAttributes* attr = absl::any_cast( &ctx.node->operation.attributes); if (!attr) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Couldn't read attributes for the scalar of const vector case."); } auto* tensor = @@ -167,7 +168,7 @@ class ElementwiseTwoArguments : public NodeShader { &attr->param); auto* scalar = absl::get_if(&attr->param); if (!tensor && !scalar) { - return InvalidArgumentError( + return 
absl::InvalidArgumentError( "Couldn't read scalar of const vector data from the attributes."); } @@ -208,7 +209,7 @@ class ElementwiseTwoArguments : public NodeShader { break; } default: - return InvalidArgumentError( + return absl::InvalidArgumentError( "Incorrect elementwise with scalar operation type."); } source = absl::Substitute(source, argument0, argument1); @@ -222,7 +223,7 @@ class ElementwiseTwoArguments : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc index f4ad5b8cc0a..a8246515247 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc @@ -34,8 +34,8 @@ namespace { class FullyConnectedBuffers : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast( ctx.node->operation.attributes); @@ -106,7 +106,7 @@ class FullyConnectedBuffers : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::ONLY_DEFINITIONS, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc index e248cdfb31a..7179ba00581 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc @@ -43,8 +43,8 @@ namespace { // class LstmNodeShader : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::string code = R"( vec4 prev_state = $input_data_1[gid.x, gid.y, gid.z]$; @@ -80,7 +80,7 @@ class LstmNodeShader : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc index 2e977625489..c8961eee087 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc @@ -33,8 +33,8 @@ namespace { class MaxUnpooling : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast( ctx.node->operation.attributes); std::vector parameters = { @@ -66,7 +66,7 @@ class MaxUnpooling : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc index 9328351f169..e94c952ffaa 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc @@ -32,11 +32,11 @@ namespace { class Mean : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status 
GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast(ctx.node->operation.attributes); if (attr.dims != std::set({Axis::HEIGHT, Axis::WIDTH})) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Mean calculation is supported only for height and width."); } @@ -72,7 +72,7 @@ class Mean : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc index 7de4caea81d..6e825dc862d 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc @@ -52,8 +52,8 @@ bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) { return shape1.h == 1 && shape1.w == 1 && shape0.c == shape1.c; } -Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { +absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { const auto inputs = ctx.graph->FindInputs(ctx.node->id); const auto& shape0 = inputs[0]->tensor.shape; const auto& shape1 = inputs[1]->tensor.shape; @@ -80,11 +80,11 @@ Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } -Status GenerateMultiplyScalarCode(const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { +absl::Status GenerateMultiplyScalarCode( + const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) { auto attr = absl::any_cast(ctx.node->operation.attributes); auto muls = absl::get_if>(&attr.param); @@ -103,7 +103,7 @@ Status GenerateMultiplyScalarCode(const NodeShader::GenerationContext& ctx, }; } else { if (!muls) { - return InvalidArgumentError("Empty parameters for Multiplication."); + return absl::InvalidArgumentError("Empty parameters for Multiplication."); } auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape; *generated_code = { @@ -120,13 +120,13 @@ Status GenerateMultiplyScalarCode(const NodeShader::GenerationContext& ctx, }; } - return OkStatus(); + return absl::OkStatus(); } class Multiply : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { if (IsApplyMaskSupported(ctx)) { return GenerateApplyMaskCode(ctx, generated_code); } else { diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc b/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc index 14fe55d943a..3fc84aa675e 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc @@ -34,22 +34,22 @@ namespace { class Pad : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto attr = absl::any_cast(ctx.node->operation.attributes); if (attr.type != PaddingContentType::ZEROS && attr.type != PaddingContentType::REFLECT) { - return UnimplementedError( + return absl::UnimplementedError( "Only ZERO and REFLECT padding types are supported."); } if 
(attr.appended.h < 0 || attr.appended.w < 0 || attr.appended.c < 0 || attr.prepended.h < 0 || attr.prepended.w < 0 || attr.prepended.c < 0) { - return UnimplementedError("Negative padding is not supported."); + return absl::UnimplementedError("Negative padding is not supported."); } if (attr.appended.b != 0 || attr.prepended.b != 0) { - return UnimplementedError("Padding for BATCH is not supported."); + return absl::UnimplementedError("Padding for BATCH is not supported."); } std::vector parameters = { {"input_data_0_h", input->tensor.shape.h}, @@ -130,7 +130,7 @@ class Pad : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc index 8f140c33fca..5c6aefcde1c 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc @@ -31,14 +31,14 @@ namespace gpu { namespace gl { namespace { -Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr, - const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { +absl::Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr, + const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; if (attr.padding.prepended.h > attr.kernel.h || attr.padding.prepended.w > attr.kernel.w) { - return InvalidArgumentError("Padding is bigger than kernel."); + return absl::InvalidArgumentError("Padding is bigger than kernel."); } std::vector parameters = { @@ -94,12 +94,12 @@ Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr, /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } -Status GenerateAveragePoolingCode(const Pooling2DAttributes& attr, - const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { +absl::Status GenerateAveragePoolingCode( + const Pooling2DAttributes& attr, const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; std::vector parameters = { @@ -136,13 +136,13 @@ Status GenerateAveragePoolingCode(const Pooling2DAttributes& attr, /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } class Pooling : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { const auto& attr = absl::any_cast(ctx.node->operation.attributes); switch (attr.type) { @@ -151,7 +151,7 @@ class Pooling : public NodeShader { case PoolingType::MAX: return GenerateMaxPoolingCode(attr, ctx, generated_code); default: - return InvalidArgumentError("Incorrect attributes' type."); + return absl::InvalidArgumentError("Incorrect attributes' type."); } } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc index 88078935ee2..28f8551f530 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc @@ -35,17 +35,17 @@ namespace { class PReLULinearAlpha : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status 
GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = absl::any_cast(ctx.node->operation.attributes); auto alpha = absl::get_if>(&attr.alpha); if (!alpha) { - return InvalidArgumentError("Alpha is missing"); + return absl::InvalidArgumentError("Alpha is missing"); } if (alpha->shape.v != output->tensor.shape.c) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Alpha shape does not match the number of channels."); } @@ -79,25 +79,26 @@ class PReLULinearAlpha : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; class PReLUFull : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = absl::any_cast(ctx.node->operation.attributes); auto alpha = absl::get_if>(&attr.alpha); if (!alpha) { - return InvalidArgumentError("Alpha is missing"); + return absl::InvalidArgumentError("Alpha is missing"); } if (alpha->shape.h != output->tensor.shape.h || alpha->shape.w != output->tensor.shape.w || alpha->shape.c != output->tensor.shape.c) { - return InvalidArgumentError("Alpha shape does not match input shape."); + return absl::InvalidArgumentError( + "Alpha shape does not match input shape."); } auto shape = output->tensor.shape; @@ -141,14 +142,14 @@ class PReLUFull : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; class PReLU : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast(ctx.node->operation.attributes); auto alpha = absl::get_if>(&attr.alpha); diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc index 3f21124aee9..1d45e07aeee 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc @@ -31,8 +31,8 @@ namespace { class QuantizeAndDequantize : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::string code; // Constants code += "vec4 scale = vec4($quant_scale$);"; @@ -59,7 +59,7 @@ class QuantizeAndDequantize : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc index 6903abc0b26..8f6de92acd8 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc @@ -120,19 +120,19 @@ class Registry : public NodeShader { ~Registry() final = default; - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const 
final { std::vector errors; auto it = shaders_.find(ctx.node->operation.type); if (it != shaders_.end()) { for (auto& shader : it->second) { const auto status = shader->GenerateCode(ctx, generated_code); if (status.ok()) return status; - errors.push_back(status.error_message()); + errors.push_back(std::string(status.message())); } } - return NotFoundError(absl::StrCat("Suitable node shader is not found: ", - absl::StrJoin(errors, ", "))); + return absl::NotFoundError(absl::StrCat( + "Suitable node shader is not found: ", absl::StrJoin(errors, ", "))); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc b/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc index a8e006ed151..a9357968a90 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc @@ -33,8 +33,8 @@ namespace { class ReLU : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast(ctx.node->operation.attributes); // clamp(value, min(0, alpha * value), clip) std::vector params; @@ -62,7 +62,7 @@ class ReLU : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc b/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc index cd01417cff5..9734ff14a1e 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc @@ -32,19 +32,19 @@ namespace { class Reshape : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; if (input->tensor.shape.DimensionsProduct() != output->tensor.shape.DimensionsProduct()) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Number of elements in input & output tensors don't match."); } auto attr = absl::any_cast(ctx.node->operation.attributes); if (attr.new_shape != output->tensor.shape) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Dimensions for output does not match new_shape attribute"); } @@ -80,7 +80,7 @@ class Reshape : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc index 33d59518987..004ae14fe8b 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc @@ -33,10 +33,8 @@ namespace { class Resize : public NodeShader { public: - Resize() {} - - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = @@ -44,15 +42,15 @@ class Resize : public NodeShader { if (input->tensor.shape.w > output->tensor.shape.w || input->tensor.shape.h > output->tensor.shape.h) { - return 
InvalidArgumentError("Output size is less than input size."); + return absl::InvalidArgumentError("Output size is less than input size."); } if (output->tensor.shape.w != attr.new_shape.w || output->tensor.shape.h != attr.new_shape.h) { - return InvalidArgumentError( + return absl::InvalidArgumentError( "Output size does not match new_size in attributes."); } if (input->tensor.shape.c != output->tensor.shape.c) { - return InvalidArgumentError("Input/output channels mismatch."); + return absl::InvalidArgumentError("Input/output channels mismatch."); } if (input->tensor.shape.h == 1 && input->tensor.shape.w == 1) { // Copy a single element from input. @@ -66,7 +64,7 @@ class Resize : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } std::vector parameters = { {"input_data_0_h", input->tensor.shape.h}, @@ -107,7 +105,7 @@ class Resize : public NodeShader { value_0 = $input_data_0[coord.x, coord.y, gid.z]$; )"; } else { - return InvalidArgumentError("Unknown sampling type"); + return absl::InvalidArgumentError("Unknown sampling type"); } *generated_code = { /*parameters=*/std::move(parameters), @@ -119,7 +117,7 @@ class Resize : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc b/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc index d0fe1923d4e..ab4497c4b62 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc @@ -33,8 +33,8 @@ namespace { class Slice : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = @@ -107,7 +107,7 @@ class Slice : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc index e59343df7b6..b6c8e144a09 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc @@ -41,17 +41,19 @@ float4 GetMask(int num_channels) { class Softmax : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { const auto* input = ctx.graph->FindInputs(ctx.node->id)[0]; const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0]; const auto& attr = absl::any_cast( ctx.node->operation.attributes); if (input->tensor.shape != output->tensor.shape) { - return InvalidArgumentError("Input and output shapes do not match."); + return absl::InvalidArgumentError( + "Input and output shapes do not match."); } if (attr.axis != Axis::CHANNELS) { - return UnimplementedError("Softmax is only supported for channels axis."); + return absl::UnimplementedError( + "Softmax is only supported for channels axis."); } return input->tensor.shape.h == 1 && input->tensor.shape.w == 1 ? 
GenerateCodeFor1x1(ctx, generated_code) @@ -59,8 +61,8 @@ class Softmax : public NodeShader { } private: - Status GenerateCodeFor1x1(const GenerationContext& ctx, - GeneratedCode* generated_code) const { + absl::Status GenerateCodeFor1x1(const GenerationContext& ctx, + GeneratedCode* generated_code) const { const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0]; const int depth = IntegralDivideRoundUp(output->tensor.shape.c, 4); std::vector shared_variables = { @@ -133,11 +135,11 @@ class Softmax : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::ONLY_DEFINITIONS, }; - return OkStatus(); + return absl::OkStatus(); } - Status GenerateCodeGeneral(const GenerationContext& ctx, - GeneratedCode* generated_code) const { + absl::Status GenerateCodeGeneral(const GenerationContext& ctx, + GeneratedCode* generated_code) const { const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0]; std::vector parameters = { {"src_depth", IntegralDivideRoundUp(output->tensor.shape.c, 4)}, @@ -172,7 +174,7 @@ class Softmax : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::ONLY_DEFINITIONS, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.cc b/tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.cc index 1d49da0e3fa..b1e650a1ffc 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.cc @@ -31,8 +31,8 @@ namespace { class SpaceToDepth : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { const auto attr = absl::any_cast(ctx.node->operation.attributes); const auto& input_data_0 = ctx.graph->FindInputs(ctx.node->id)[0]->tensor; @@ -60,7 +60,7 @@ class SpaceToDepth : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; } // namespace diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc index de6e324017d..e9abec7eec6 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc @@ -68,9 +68,9 @@ bool SingleOpModel::PopulateTensor(int index, std::vector&& data) { return true; } -Status SingleOpModel::Invoke(const CompilationOptions& compile_options, - const RuntimeOptions& runtime_options, - const NodeShader& shader) { +absl::Status SingleOpModel::Invoke(const CompilationOptions& compile_options, + const RuntimeOptions& runtime_options, + const NodeShader& shader) { std::unique_ptr env; RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env)); @@ -125,10 +125,10 @@ Status SingleOpModel::Invoke(const CompilationOptions& compile_options, CopyFromPHWC4Buffer(*objects.FindBuffer(output->id), &tensor)); outputs_.push_back(std::move(tensor)); } - return OkStatus(); + return absl::OkStatus(); } -Status SingleOpModel::Invoke(const NodeShader& shader) { +absl::Status SingleOpModel::Invoke(const NodeShader& shader) { return Invoke(CompilationOptions(), RuntimeOptions(), shader); } diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h index c917220d075..42a789020df 100644 --- 
a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h +++ b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h @@ -48,10 +48,10 @@ class SingleOpModel { bool PopulateTensor(int index, std::vector&& data); - Status Invoke(const NodeShader& shader); - Status Invoke(const CompilationOptions& compile_options, - const RuntimeOptions& runtime_options, - const NodeShader& shader); + absl::Status Invoke(const NodeShader& shader); + absl::Status Invoke(const CompilationOptions& compile_options, + const RuntimeOptions& runtime_options, + const NodeShader& shader); const std::vector& GetOutput(int index) const { return outputs_[index].data; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc index 7fcfde4f92a..eb28672d49f 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc @@ -35,8 +35,8 @@ namespace { class ConvolutionTransposedBuffers : public NodeShader { public: - Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto attr = absl::any_cast( ctx.node->operation.attributes); @@ -63,10 +63,10 @@ class ConvolutionTransposedBuffers : public NodeShader { ivec2 p0 = ($padding$ + $stride$ - gid.xy % $stride$) % $stride$; for (int y = p0.y; y < $kernel_size.y$; y += $stride.y$) { for (int x = p0.x; x < $kernel_size.x$; x += $stride.x$) { - - int i = int(float(y * $kernel_size.x$) + float(x)); + + int i = int(float(y * $kernel_size.x$) + float(x)); ivec2 idx = ivec2(vec2(gid.xy + ivec2(x, y)) - vec2($padding$)); - + if (IN_BOUNDS(idx, ivec2(0), ivec2($input_data_0_w$, $input_data_0_h$) * $stride$)) { ivec2 coord = idx / $stride$; for (int l = 0; l < $src_depth$; ++l) { @@ -94,7 +94,7 @@ class ConvolutionTransposedBuffers : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/node_shader.h b/tensorflow/lite/delegates/gpu/gl/node_shader.h index 38364656b7a..d98bdbf8914 100644 --- a/tensorflow/lite/delegates/gpu/gl/node_shader.h +++ b/tensorflow/lite/delegates/gpu/gl/node_shader.h @@ -101,8 +101,8 @@ class NodeShader { }; // Generates shader code for a node. The code should be just a function body. 
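A minimal sketch of a node shader under the migrated signature, assuming the GeneratedCode field order used by the kernels elsewhere in this patch; the Identity class and its source string are illustrative only, not part of the change.

  class Identity : public NodeShader {
   public:
    absl::Status GenerateCode(const GenerationContext& ctx,
                              GeneratedCode* generated_code) const final {
      *generated_code = {
          /*parameters=*/{},
          /*objects=*/{},
          /*shared_variables=*/{},
          /*workload=*/uint3(),
          /*workgroup=*/uint3(),
          /*source_code=*/"value_0 = value_0;",  // just a function body
          /*input=*/IOStructure::AUTO,
          /*output=*/IOStructure::AUTO,
      };
      return absl::OkStatus();
    }
  };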
- virtual Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const = 0; + virtual absl::Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const = 0; // Limit the size of the const offsets array static constexpr int kMaxConstArraySize = 9; diff --git a/tensorflow/lite/delegates/gpu/gl/object_manager.cc b/tensorflow/lite/delegates/gpu/gl/object_manager.cc index 4eca794a20a..c37be507b2b 100644 --- a/tensorflow/lite/delegates/gpu/gl/object_manager.cc +++ b/tensorflow/lite/delegates/gpu/gl/object_manager.cc @@ -24,21 +24,22 @@ namespace tflite { namespace gpu { namespace gl { -Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, - GlBuffer* gl_buffer) { +absl::Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, + GlBuffer* gl_buffer) { std::vector transposed(GetElementsSizeForPHWC4(tensor.shape)); RETURN_IF_ERROR( ConvertToPHWC4(tensor.data, tensor.shape, absl::MakeSpan(transposed))); return CreateReadOnlyShaderStorageBuffer(transposed, gl_buffer); } -Status CreatePHWC4BufferFromTensorRef(const TensorRef& tensor_ref, - GlBuffer* gl_buffer) { +absl::Status CreatePHWC4BufferFromTensorRef(const TensorRef& tensor_ref, + GlBuffer* gl_buffer) { return CreateReadWriteShaderStorageBuffer( GetElementsSizeForPHWC4(tensor_ref.shape), gl_buffer); } -Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor) { +absl::Status CopyFromPHWC4Buffer(const GlBuffer& buffer, + TensorFloat32* tensor) { return buffer.MappedRead( [tensor, &buffer](absl::Span data) { tensor->data.resize(tensor->shape.DimensionsProduct()); @@ -47,12 +48,12 @@ Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor) { }); } -Status ObjectManager::RegisterBuffer(uint32_t id, GlBuffer buffer) { +absl::Status ObjectManager::RegisterBuffer(uint32_t id, GlBuffer buffer) { if (id >= buffers_.size()) { buffers_.resize(id + 1); } buffers_[id] = absl::make_unique(std::move(buffer)); - return OkStatus(); + return absl::OkStatus(); } void ObjectManager::RemoveBuffer(uint32_t id) { @@ -65,12 +66,12 @@ GlBuffer* ObjectManager::FindBuffer(uint32_t id) const { return id >= buffers_.size() ? nullptr : buffers_[id].get(); } -Status ObjectManager::RegisterTexture(uint32_t id, GlTexture texture) { +absl::Status ObjectManager::RegisterTexture(uint32_t id, GlTexture texture) { if (id >= textures_.size()) { textures_.resize(id + 1); } textures_[id] = absl::make_unique(std::move(texture)); - return OkStatus(); + return absl::OkStatus(); } void ObjectManager::RemoveTexture(uint32_t id) { diff --git a/tensorflow/lite/delegates/gpu/gl/object_manager.h b/tensorflow/lite/delegates/gpu/gl/object_manager.h index 8fa82871b50..0a7de28e1dc 100644 --- a/tensorflow/lite/delegates/gpu/gl/object_manager.h +++ b/tensorflow/lite/delegates/gpu/gl/object_manager.h @@ -41,7 +41,7 @@ namespace gl { class ObjectManager { public: // Moves ownership over the given buffer to the manager. - Status RegisterBuffer(uint32_t id, GlBuffer buffer); + absl::Status RegisterBuffer(uint32_t id, GlBuffer buffer); void RemoveBuffer(uint32_t id); @@ -49,7 +49,7 @@ class ObjectManager { GlBuffer* FindBuffer(uint32_t id) const; // Moves ownership over the given texture to the manager. - Status RegisterTexture(uint32_t id, GlTexture texture); + absl::Status RegisterTexture(uint32_t id, GlTexture texture); void RemoveTexture(uint32_t id); @@ -67,17 +67,17 @@ class ObjectManager { // Creates read-only buffer from the given tensor. Tensor data is converted to // PHWC4 layout. 
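A short sketch of how the two helpers above might be combined, assuming a TensorFloat32 whose shape and data are already populated and an enclosing function that returns absl::Status; the buffer id is made up.

  // Convert the tensor to PHWC4, wrap it in a GlBuffer, and hand ownership to
  // the ObjectManager under id 0.
  ObjectManager objects;
  GlBuffer buffer;
  RETURN_IF_ERROR(CreatePHWC4BufferFromTensor(tensor, &buffer));
  RETURN_IF_ERROR(objects.RegisterBuffer(/*id=*/0, std::move(buffer)));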
-Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, - GlBuffer* gl_buffer); +absl::Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, + GlBuffer* gl_buffer); // Creates read-write buffer for the given tensor shape, where data layout is // supposed to be PHWC4. -Status CreatePHWC4BufferFromTensorRef(const TensorRef& tensor_ref, - GlBuffer* gl_buffer); +absl::Status CreatePHWC4BufferFromTensorRef(const TensorRef& tensor_ref, + GlBuffer* gl_buffer); // Copies data from a buffer that holds data in PHWC4 layout to the given // tensor. -Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor); +absl::Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc index 7134fc010d0..0769a5014b4 100644 --- a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc @@ -28,7 +28,7 @@ namespace tflite { namespace gpu { namespace gl { -Status RequestGpuInfo(GpuInfo* gpu_info) { +absl::Status RequestGpuInfo(GpuInfo* gpu_info) { GpuInfo info; const GLubyte* renderer_name = glGetString(GL_RENDERER); @@ -73,7 +73,7 @@ Status RequestGpuInfo(GpuInfo* gpu_info) { glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS, &info.max_array_texture_layers); RETURN_IF_ERROR(GetOpenGlErrors()); *gpu_info = info; - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h index 4eba7a55c2a..f9d203e2325 100644 --- a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h +++ b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h @@ -28,7 +28,7 @@ namespace gl { // This method performs multiple GL calls, therefore, egl context needs to be // created upfront. 
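A hedged sketch of the intended call order, reusing the EglEnvironment setup seen in test_util.cc in this patch; variable names are illustrative and the surrounding function is assumed to return absl::Status.

  // RequestGpuInfo issues GL calls, so an EGL context must exist first.
  std::unique_ptr<EglEnvironment> env;
  RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env));
  GpuInfo gpu_info;
  RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));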
-Status RequestGpuInfo(GpuInfo* gpu_info); +absl::Status RequestGpuInfo(GpuInfo* gpu_info); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.cc b/tensorflow/lite/delegates/gpu/gl/runtime.cc index 14e30389cf0..2a48b59c8d9 100644 --- a/tensorflow/lite/delegates/gpu/gl/runtime.cc +++ b/tensorflow/lite/delegates/gpu/gl/runtime.cc @@ -41,13 +41,13 @@ namespace gl { namespace { struct TextureF16Maker { - Status operator()(const uint3& size) const { + absl::Status operator()(const uint3& size) const { return CreateReadOnlyImageTextureF16(size, data, gl_texture); } - Status operator()(const uint2& size) const { + absl::Status operator()(const uint2& size) const { return CreateReadOnlyImageTextureF16(size, data, gl_texture); } - Status operator()(const size_t& size) const { + absl::Status operator()(const size_t& size) const { return CreateReadOnlyImageTextureF16(uint2(static_cast(size), 1U), data, gl_texture); } @@ -56,13 +56,13 @@ struct TextureF16Maker { }; struct TextureF32Maker { - Status operator()(const uint3& size) const { + absl::Status operator()(const uint3& size) const { return CreateReadOnlyImageTexture(size, data, gl_texture); } - Status operator()(const uint2& size) const { + absl::Status operator()(const uint2& size) const { return CreateReadOnlyImageTexture(size, data, gl_texture); } - Status operator()(const size_t& size) const { + absl::Status operator()(const size_t& size) const { return CreateReadOnlyImageTexture(uint2(static_cast(size), 1U), data, gl_texture); } @@ -70,20 +70,21 @@ struct TextureF32Maker { GlTexture* gl_texture; }; -Status MakeGlTexture(const Object& object, const ObjectData& data, - GlTexture* gl_texture) { +absl::Status MakeGlTexture(const Object& object, const ObjectData& data, + GlTexture* gl_texture) { if (object.access == AccessType::READ_WRITE || object.access == AccessType::WRITE) { - return InvalidArgumentError("Read-write textures are not supported"); + return absl::InvalidArgumentError("Read-write textures are not supported"); } if (object.data_type != DataType::FLOAT16 && object.data_type != DataType::FLOAT32) { - return InvalidArgumentError("Textures support float16 or float32 only."); + return absl::InvalidArgumentError( + "Textures support float16 or float32 only."); } switch (object.data_type) { case DataType::FLOAT16: { if (data.size() % 2 != 0) { - return InvalidArgumentError("Texture size is not aligned"); + return absl::InvalidArgumentError("Texture size is not aligned"); } return absl::visit( TextureF16Maker{ @@ -96,7 +97,7 @@ Status MakeGlTexture(const Object& object, const ObjectData& data, } case DataType::FLOAT32: { if (data.size() % sizeof(float) != 0) { - return InvalidArgumentError("Texture size is not aligned"); + return absl::InvalidArgumentError("Texture size is not aligned"); } return absl::visit( TextureF32Maker{ @@ -108,18 +109,18 @@ Status MakeGlTexture(const Object& object, const ObjectData& data, object.size); } default: - return InvalidArgumentError("Unsupported textures data type."); + return absl::InvalidArgumentError("Unsupported textures data type."); } } struct TextureRefMaker { - Status operator()(const uint3& size) const { + absl::Status operator()(const uint3& size) const { return CreateReadWriteRgbaImageTexture(type, size, gl_texture); } - Status operator()(const uint2& size) const { + absl::Status operator()(const uint2& size) const { return CreateReadWriteRgbaImageTexture(type, size, gl_texture); } - Status operator()(const size_t& size) const { + absl::Status 
operator()(const size_t& size) const { return CreateReadWriteRgbaImageTexture( type, uint2(static_cast(size), 1U), gl_texture); } @@ -128,37 +129,38 @@ struct TextureRefMaker { }; // Makes read-write gl texture -Status MakeGlTextureRef(const Object& object, GlTexture* gl_texture) { +absl::Status MakeGlTextureRef(const Object& object, GlTexture* gl_texture) { return absl::visit(TextureRefMaker{object.data_type, gl_texture}, object.size); } -Status MakeGlBuffer(const Object& object, const ObjectData& data, - GlBuffer* gl_buffer) { +absl::Status MakeGlBuffer(const Object& object, const ObjectData& data, + GlBuffer* gl_buffer) { if (data.size() % SizeOf(object.data_type) != 0) { - return InvalidArgumentError("Buffer size is not aligned"); + return absl::InvalidArgumentError("Buffer size is not aligned"); } return CreateReadOnlyShaderStorageBuffer(absl::MakeConstSpan(data), gl_buffer); } // Looks up an object with the given id. If found, makes a binding function. -Status MakeBindingFunc(const Object& object, uint32_t id, - const ObjectManager& objects, - std::function* binding_func) { +absl::Status MakeBindingFunc(const Object& object, uint32_t id, + const ObjectManager& objects, + std::function* binding_func) { const uint32_t binding = object.binding; switch (object.object_type) { case ObjectType::BUFFER: { auto ptr = objects.FindBuffer(id); if (!ptr) { - return NotFoundError(absl::StrCat("Buffer ", id, " is not found")); + return absl::NotFoundError( + absl::StrCat("Buffer ", id, " is not found")); } // Validate buffer. size_t size_in_bytes = ByteSizeOf(object); // TODO(akulik): make comparison != instead of < if (ptr->bytes_size() < size_in_bytes) { - return FailedPreconditionError( + return absl::FailedPreconditionError( absl::StrCat("Buffer ", id, " size in bytes ", ptr->bytes_size(), " < requested size_in_bytes ", size_in_bytes)); } @@ -168,15 +170,16 @@ Status MakeBindingFunc(const Object& object, uint32_t id, case ObjectType::TEXTURE: { auto ptr = objects.FindTexture(id); if (!ptr) { - return NotFoundError(absl::StrCat("Texture ", id, " is not found")); + return absl::NotFoundError( + absl::StrCat("Texture ", id, " is not found")); } *binding_func = [=]() { return ptr->BindAsReadWriteImage(binding); }; break; } case ObjectType::UNKNOWN: - return InvalidArgumentError("Unknown object type"); + return absl::InvalidArgumentError("Unknown object type"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -194,10 +197,10 @@ Runtime::Runtime(const RuntimeOptions& options, const GpuInfo& gpu_info, } } -Status Runtime::AddProgram(const GlShader& shader, - const std::vector& parameters, - const std::vector& objects, - const uint3& num_workgroups) { +absl::Status Runtime::AddProgram(const GlShader& shader, + const std::vector& parameters, + const std::vector& objects, + const uint3& num_workgroups) { GlProgram program; RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); @@ -217,10 +220,10 @@ Status Runtime::AddProgram(const GlShader& shader, // Reference object could be provided externally as a model input/output // but also for debugging purposes. Otherwise all references are collected // and allocated later. 
-      Status status = MakeBindingFunc(object, GetRef(object),
-                                      *external_objects_, &binding_func);
+      absl::Status status = MakeBindingFunc(object, GetRef(object),
+                                            *external_objects_, &binding_func);
       if (!status.ok()) {
-        if (status.code() == StatusCode::kNotFound) {
+        if (absl::IsNotFound(status)) {
           program.refs.push_back(object);
           continue;  // don't add to binding.
         }
@@ -238,10 +241,10 @@ Status Runtime::AddProgram(const GlShader& shader,
 
   // All parameters once set stay with program, therefore, we only need to keep
   // program and bindings for execution.
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status Runtime::AllocateInternalObject(const Object& object) {
+absl::Status Runtime::AllocateInternalObject(const Object& object) {
   const ObjectRef ref = GetRef(object);
   switch (object.object_type) {
     case ObjectType::BUFFER: {
@@ -260,15 +263,16 @@ Status Runtime::AllocateInternalObject(const Object& object) {
       break;
     }
     default:
-      return InternalError("Unexpected internal object type");
+      return absl::InternalError("Unexpected internal object type");
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status Runtime::AllocateConstObject(const Object& object, uint32_t* id) {
+absl::Status Runtime::AllocateConstObject(const Object& object, uint32_t* id) {
   const ObjectData* data = GetData(object);
   if (data == nullptr) {
-    return InternalError("Unable to allocate reference as a const object");
+    return absl::InternalError(
+        "Unable to allocate reference as a const object");
   }
   *id = next_const_id_++;
   switch (object.object_type) {
@@ -289,12 +293,12 @@ Status Runtime::AllocateConstObject(const Object& object, uint32_t* id) {
       break;
     }
     case ObjectType::UNKNOWN:
-      return InternalError("Unknown object type");
+      return absl::InternalError("Unknown object type");
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status Runtime::PrepareForExecution() {
+absl::Status Runtime::PrepareForExecution() {
   if (shared_readonly_buffer_ && !shared_readonly_buffer_->empty()) {
     GlBuffer shared_buffer;
     RETURN_IF_ERROR(
@@ -320,11 +324,10 @@ Status Runtime::PrepareForExecution() {
       // Check whether it is created already.
      BindFunc binding;
       ObjectRef ref = GetRef(object);
-      Status status = MakeBindingFunc(object, ref, internal_objects_, &binding);
+      absl::Status status =
+          MakeBindingFunc(object, ref, internal_objects_, &binding);
       if (!status.ok()) {
-        if (status.code() != StatusCode::kNotFound) {
-          return status;
-        }
+        if (!absl::IsNotFound(status)) return status;
         RETURN_IF_ERROR(AllocateInternalObject(object));
         RETURN_IF_ERROR(
             MakeBindingFunc(object, ref, internal_objects_, &binding));
@@ -333,7 +336,7 @@ Status Runtime::PrepareForExecution() {
     }
     program.refs.clear();
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 namespace {
@@ -399,8 +402,8 @@ struct AddUsageRecordForTextureFunc {
 
 // We assume that AddUsageRecord for different objects is called in order of
 // program_id.
-Status AddUsageRecord(CombinedUsageRecords* usage_records, const Object& object, - const size_t program_id) { +absl::Status AddUsageRecord(CombinedUsageRecords* usage_records, + const Object& object, const size_t program_id) { auto ref = GetRef(object); if (ref >= usage_records->usage_refs.size()) { usage_records->usage_refs.resize(ref + 1, kNotAssigned); @@ -416,17 +419,17 @@ Status AddUsageRecord(CombinedUsageRecords* usage_records, const Object& object, } else { UpdateUsageRecord(&usage_records->buffers[usage_ref], program_id); } - return OkStatus(); + return absl::OkStatus(); } if (object.object_type == ObjectType::TEXTURE) { absl::visit(AddUsageRecordForTextureFunc{usage_records, ref, program_id}, object.size); - return OkStatus(); + return absl::OkStatus(); } - return InternalError("Unexpected object type"); + return absl::InternalError("Unexpected object type"); } -Status ApplyBuffersAssignment( +absl::Status ApplyBuffersAssignment( const ObjectsAssignment& assignment, const std::vector& global_ref_to_usage_rec, const std::vector& global_ref_to_object_ptr, @@ -462,11 +465,11 @@ Status ApplyBuffersAssignment( } (*global_ref_to_shared_ref)[global_ref] = shared_ref; } - return OkStatus(); + return absl::OkStatus(); } template -Status ApplyTexturesAssignment( +absl::Status ApplyTexturesAssignment( const ObjectsAssignment& assignment, const std::vector& global_ref_to_usage_rec, const std::vector& global_ref_to_object_ptr, @@ -504,7 +507,7 @@ Status ApplyTexturesAssignment( } (*global_ref_to_shared_ref)[global_ref] = shared_ref; } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -512,7 +515,8 @@ Status ApplyTexturesAssignment( // Assign shared objects to internal objects, using memory allocation // algorithms. Usage records for the algorithms are calculated separately for // each data type and object type. -Status Runtime::AssignInternalObjects(std::vector* shared_objects) { +absl::Status Runtime::AssignInternalObjects( + std::vector* shared_objects) { // Build tensor usage records, clusterized by object type and data type. std::map usage_records_by_data_type; std::vector global_ref_to_object_ptr; @@ -579,10 +583,10 @@ Status Runtime::AssignInternalObjects(std::vector* shared_objects) { object.object = global_ref_to_shared_ref[GetRef(object)]; } } - return OkStatus(); + return absl::OkStatus(); } -Status Runtime::Execute() { +absl::Status Runtime::Execute() { for (const auto& descriptor : programs_) { for (auto& b : descriptor.bindings) { RETURN_IF_ERROR(b()); @@ -590,7 +594,7 @@ Status Runtime::Execute() { RETURN_IF_ERROR(command_queue_->Dispatch(descriptor.program, descriptor.num_workgroups)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.h b/tensorflow/lite/delegates/gpu/gl/runtime.h index b66a7fdfaa4..97f0f732834 100644 --- a/tensorflow/lite/delegates/gpu/gl/runtime.h +++ b/tensorflow/lite/delegates/gpu/gl/runtime.h @@ -44,17 +44,17 @@ class Runtime { CommandQueue* command_queue, const ObjectManager* external_objects); // Takes parameters and objects and prepares GL program. - Status AddProgram(const GlShader& shader, - const std::vector& parameters, - const std::vector& objects, - const uint3& num_workgroups); + absl::Status AddProgram(const GlShader& shader, + const std::vector& parameters, + const std::vector& objects, + const uint3& num_workgroups); // Needs to be called once all programs and shaders has been added to runtime. 
- Status PrepareForExecution(); + absl::Status PrepareForExecution(); // Executes all compiled programs. // TODO(akulik): add more controls over execution. Execution policy? - Status Execute(); + absl::Status Execute(); // Gets access to objects created while executing generated code. const ObjectManager* internal_objects() const { return &internal_objects_; } @@ -72,14 +72,14 @@ class Runtime { } private: - Status AllocateInternalObject(const Object& object); + absl::Status AllocateInternalObject(const Object& object); - Status AllocateConstObject(const Object& object, uint32_t* id); + absl::Status AllocateConstObject(const Object& object, uint32_t* id); // Goes over objects in programs and decides how to allocate them to // minimize total allocated memory. Returns a collection of objects to be // allocated and shared by internal objects. - Status AssignInternalObjects(std::vector* objects); + absl::Status AssignInternalObjects(std::vector* objects); const RuntimeOptions options_; const GpuInfo gpu_info_; @@ -92,7 +92,7 @@ class Runtime { std::unique_ptr shared_readonly_buffer_; - using BindFunc = std::function; + using BindFunc = std::function; // Encapsulates a program and all object to bind before dispatch. struct CompiledProgramDescriptor { diff --git a/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h b/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h index d4f49d1952c..11b094637f2 100644 --- a/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h +++ b/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h @@ -55,7 +55,7 @@ class SharedBufferData { bool empty() const { return shared_data_.empty(); } // Returns a single GlBuffer that owns entire shared data. - Status CreateSharedGlBuffer(GlBuffer* gl_buffer) { + absl::Status CreateSharedGlBuffer(GlBuffer* gl_buffer) { // Upload data to a buffer gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, buffer_id_.id()); @@ -64,7 +64,7 @@ class SharedBufferData { GL_STATIC_READ)); *gl_buffer = GlBuffer(GL_SHADER_STORAGE_BUFFER, buffer_id_.Release(), shared_data_.size(), 0, /*has_ownership=*/true); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/serialization.cc b/tensorflow/lite/delegates/gpu/gl/serialization.cc index 17db339fa98..7e15cf2d271 100644 --- a/tensorflow/lite/delegates/gpu/gl/serialization.cc +++ b/tensorflow/lite/delegates/gpu/gl/serialization.cc @@ -390,15 +390,15 @@ absl::Span SerializedCompiledModelBuilder::Finalize( namespace { -Status ParseParameter(const data::UniformParameter& fb_parameter, - Variable* parameter) { +absl::Status ParseParameter(const data::UniformParameter& fb_parameter, + Variable* parameter) { parameter->name = fb_parameter.name()->str(); switch (fb_parameter.type()) { case data::ParameterType::INT32: { auto* ptr = fb_parameter.data_as_DataInt32(); if (ptr == nullptr) { - return InvalidArgumentError("Unexpected data type '" + parameter->name + - "'"); + return absl::InvalidArgumentError("Unexpected data type '" + + parameter->name + "'"); } switch (ptr->data()->size()) { case 1: @@ -412,16 +412,16 @@ Status ParseParameter(const data::UniformParameter& fb_parameter, (*ptr->data())[2], (*ptr->data())[3]); break; default: - return InvalidArgumentError("Unexpected size for parameter '" + - parameter->name + "'"); + return absl::InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); } break; } case data::ParameterType::UINT32: { auto* ptr = fb_parameter.data_as_DataUint32(); if (ptr == 
nullptr) { - return InvalidArgumentError("Unexpected data type '" + parameter->name + - "'"); + return absl::InvalidArgumentError("Unexpected data type '" + + parameter->name + "'"); } switch (ptr->data()->size()) { case 1: @@ -432,16 +432,16 @@ Status ParseParameter(const data::UniformParameter& fb_parameter, (*ptr->data())[2], (*ptr->data())[3]); break; default: - return InvalidArgumentError("Unexpected size for parameter '" + - parameter->name + "'"); + return absl::InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); } break; } case data::ParameterType::FLOAT32: { auto* ptr = fb_parameter.data_as_DataFloat(); if (ptr == nullptr) { - return InvalidArgumentError("Unexpected data type '" + parameter->name + - "'"); + return absl::InvalidArgumentError("Unexpected data type '" + + parameter->name + "'"); } switch (ptr->data()->size()) { case 1: @@ -455,21 +455,21 @@ Status ParseParameter(const data::UniformParameter& fb_parameter, (*ptr->data())[2], (*ptr->data())[3]); break; default: - return InvalidArgumentError("Unexpected size for parameter '" + - parameter->name + "'"); + return absl::InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); } break; } case data::ParameterType::INT32_2: { auto* ptr = fb_parameter.data_as_DataInt32(); if (ptr == nullptr) { - return InvalidArgumentError("Unexpected data type '" + parameter->name + - "'"); + return absl::InvalidArgumentError("Unexpected data type '" + + parameter->name + "'"); } if (ptr->data()->size() % 2 != 0) { - return InvalidArgumentError("Unexpected size for parameter '" + - parameter->name + "'"); + return absl::InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); } std::vector values(ptr->data()->size() / 2); @@ -480,7 +480,7 @@ Status ParseParameter(const data::UniformParameter& fb_parameter, break; } } - return OkStatus(); + return absl::OkStatus(); } DataType ToEnum(data::DataType type) { @@ -520,7 +520,7 @@ AccessType ToEnum(data::AccessType type) { } } -Status ParseObject(const data::Object& fb_object, Object* object) { +absl::Status ParseObject(const data::Object& fb_object, Object* object) { object->access = ToEnum(fb_object.access()); object->binding = fb_object.binding(); object->object_type = ToEnum(fb_object.type()); @@ -543,7 +543,7 @@ Status ParseObject(const data::Object& fb_object, Object* object) { break; } case data::ObjectSize::NONE: - return InvalidArgumentError("Texture size is not set"); + return absl::InvalidArgumentError("Texture size is not set"); } switch (fb_object.object_type()) { @@ -560,10 +560,10 @@ Status ParseObject(const data::Object& fb_object, Object* object) { break; } case data::ObjectVariant::NONE: { - return InvalidArgumentError("Object is not set"); + return absl::InvalidArgumentError("Object is not set"); } } - return OkStatus(); + return absl::OkStatus(); } CompiledModelOptions ParseParameters(const data::Parameters& fb_parameters) { @@ -574,11 +574,11 @@ CompiledModelOptions ParseParameters(const data::Parameters& fb_parameters) { } // namespace -Status DeserializeCompiledModel(absl::Span serialized, - DeserializationHandler* handler) { +absl::Status DeserializeCompiledModel(absl::Span serialized, + DeserializationHandler* handler) { flatbuffers::Verifier verifier(serialized.data(), serialized.size()); if (!data::VerifyCompiledModelBuffer(verifier)) { - return InvalidArgumentError("Serialized model is corrupted."); + return absl::InvalidArgumentError("Serialized model is corrupted."); } auto model = 
data::GetCompiledModel(serialized.data()); @@ -612,7 +612,7 @@ Status DeserializeCompiledModel(absl::Span serialized, program->shader_index())); } handler->OnOptions(ParseParameters(*model->parameters())); - return OkStatus(); + return absl::OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/serialization.h b/tensorflow/lite/delegates/gpu/gl/serialization.h index c3c88b4c462..82b76a475f5 100644 --- a/tensorflow/lite/delegates/gpu/gl/serialization.h +++ b/tensorflow/lite/delegates/gpu/gl/serialization.h @@ -67,19 +67,19 @@ class DeserializationHandler { public: virtual ~DeserializationHandler() = default; - virtual Status OnShader(absl::Span shader_src) = 0; + virtual absl::Status OnShader(absl::Span shader_src) = 0; - virtual Status OnProgram(const std::vector& parameters, - const std::vector& objects, - const uint3& workgroup_size, - const uint3& num_workgroups, - size_t shader_index) = 0; + virtual absl::Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, + const uint3& num_workgroups, + size_t shader_index) = 0; virtual void OnOptions(const CompiledModelOptions& options) = 0; }; -Status DeserializeCompiledModel(absl::Span serialized, - DeserializationHandler* handler); +absl::Status DeserializeCompiledModel(absl::Span serialized, + DeserializationHandler* handler); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/serialization_test.cc b/tensorflow/lite/delegates/gpu/gl/serialization_test.cc index 25aa9be73b2..37c08129139 100644 --- a/tensorflow/lite/delegates/gpu/gl/serialization_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/serialization_test.cc @@ -45,18 +45,19 @@ struct ProgramDesc { }; struct Handler : public DeserializationHandler { - Status OnShader(absl::Span shader_src) final { + absl::Status OnShader(absl::Span shader_src) final { shaders.push_back(std::string(shader_src.data(), shader_src.size())); - return OkStatus(); + return absl::OkStatus(); } - Status OnProgram(const std::vector& parameters, - const std::vector& objects, - const uint3& workgroup_size, const uint3& num_workgroups, - size_t shader_index) final { + absl::Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, + const uint3& num_workgroups, + size_t shader_index) final { programs.push_back( {parameters, objects, workgroup_size, num_workgroups, shader_index}); - return OkStatus(); + return absl::OkStatus(); } void OnOptions(const CompiledModelOptions& o) final { options = o; } diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc index 16aaafa5c94..5ebefb4a6eb 100644 --- a/tensorflow/lite/delegates/gpu/gl_delegate.cc +++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc @@ -93,7 +93,8 @@ class Delegate { } } - Status CopyFromBufferHandle(TfLiteBufferHandle handle, TfLiteTensor* tensor) { + absl::Status CopyFromBufferHandle(TfLiteBufferHandle handle, + TfLiteTensor* tensor) { ValueRef ref; RETURN_IF_ERROR(FindObject(handle, &ref)); auto buffer = phwc4_objects_.FindBuffer(handle); @@ -105,8 +106,8 @@ class Delegate { }); } - Status CopyToBufferHandle(TfLiteBufferHandle handle, - TfLiteTensor* tensor) const { + absl::Status CopyToBufferHandle(TfLiteBufferHandle handle, + TfLiteTensor* tensor) const { ValueRef ref; RETURN_IF_ERROR(FindObject(handle, &ref)); auto buffer = phwc4_objects_.FindBuffer(handle); @@ -117,7 +118,7 @@ class Delegate { }); } - Status BindBufferToTensor(GLuint ssbo, int 
tensor_index) { + absl::Status BindBufferToTensor(GLuint ssbo, int tensor_index) { int64_t bytes_size; RETURN_IF_ERROR(GetSSBOSize(ssbo, &bytes_size)); return bhwc_objects_.RegisterBuffer( @@ -126,8 +127,8 @@ class Delegate { /* has_ownership = */ false)); } - Status Prepare(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params) { + absl::Status Prepare(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { // Extract TFLite delegate execution plan from the context and convert it // into FlowGraph32. GraphFloat32 graph; @@ -137,7 +138,7 @@ class Delegate { NullTransformationReporter reporter; ModelTransformer transformer(&graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return InternalError("Graph general transformations failed"); + return absl::InternalError("Graph general transformations failed"); } if (!env_) RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env_)); @@ -176,7 +177,7 @@ class Delegate { tflite_graph_io.insert(tensor_index); const auto* input = find_value(tensor_index); if (!input || tensor->type != TfLiteType::kTfLiteFloat32) { - return NotFoundError("Input tensor is not found in the graph."); + return absl::NotFoundError("Input tensor is not found in the graph."); } inputs_.push_back(input->id); @@ -215,7 +216,8 @@ class Delegate { tflite_graph_io.insert(tensor_index); const auto* output = find_value(tensor_index); if (!output || tensor->type != TfLiteType::kTfLiteFloat32) { - return NotFoundError("Output tensor is not found in the graph."); + return absl::NotFoundError( + "Output tensor is not found in the graph."); } outputs_.push_back(output->id); @@ -270,14 +272,14 @@ class Delegate { RETURN_IF_ERROR(compiled_model->NewRun(runtime_options, &phwc4_objects_, command_queue_.get(), &inference_context_)); - return OkStatus(); + return absl::OkStatus(); } - Status Invoke(TfLiteContext* context) { + absl::Status Invoke(TfLiteContext* context) { const EGLContext egl_context_at_delegate_init = env_->context().context(); const EGLContext egl_context_at_delegate_invoke = eglGetCurrentContext(); if (egl_context_at_delegate_init != egl_context_at_delegate_invoke) { - return FailedPreconditionError( + return absl::FailedPreconditionError( "Delegate should run on the same thread where it was initialized."); } @@ -330,18 +332,18 @@ class Delegate { RETURN_IF_ERROR(CopyFromBufferHandle(id, &tensor)); } } - return OkStatus(); + return absl::OkStatus(); } TfLiteDelegate* tflite_delegate() { return &delegate_; } private: - Status FindObject(ValueId id, ValueRef* ref) const { + absl::Status FindObject(ValueId id, ValueRef* ref) const { if (id >= tensors_.size()) { - return InvalidArgumentError("Invalid buffer id"); + return absl::InvalidArgumentError("Invalid buffer id"); } *ref = tensors_[id]; - return OkStatus(); + return absl::OkStatus(); } TfLiteDelegate delegate_ = { @@ -387,7 +389,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = gpu_delegate->Prepare(context, params); if (status.ok()) return gpu_delegate; context->ReportError(context, "TfLiteGpuDelegate Prepare: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return nullptr; }, // .free @@ -401,7 +403,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = GetGpuDelegate(node)->Invoke(context); if (status.ok()) return kTfLiteOk; context->ReportError(context, "TfLiteGpuDelegate Invoke: %s", - status.error_message().c_str()); + 
std::string(status.message()).c_str()); return kTfLiteError; }, nullptr, // .profiling_string @@ -425,7 +427,7 @@ TfLiteStatus DelegateCopyFromBufferHandle(TfLiteContext* context, const auto status = gpu_delegate->CopyFromBufferHandle(buffer_handle, tensor); if (status.ok()) return kTfLiteOk; context->ReportError(context, "TfLiteGpuDelegate CopyFromBufferHandle: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return kTfLiteError; } @@ -438,7 +440,7 @@ TfLiteStatus DelegateCopyToBufferHandle(TfLiteContext* context, const auto status = gpu_delegate->CopyToBufferHandle(buffer_handle, tensor); if (status.ok()) return kTfLiteOk; context->ReportError(context, "TfLiteGpuDelegate CopyToBufferHandle: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return kTfLiteError; } diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index 4c0af17090e..b2887e523a5 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -122,13 +122,14 @@ std::vector SelectSpaceToDepth( return SpaceToDepth(id, input_id, output_id, attr); } -Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, - const std::vector& inputs, - const std::vector& outputs, - const RuntimeOptions& options, - std::vector* tasks) { +absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, + const std::vector& inputs, + const std::vector& outputs, + const RuntimeOptions& options, + std::vector* tasks) { if (!IsBatchMatchesForAllValues(graph)) { - return InvalidArgumentError("Only identical batch dimension is supported"); + return absl::InvalidArgumentError( + "Only identical batch dimension is supported"); } int node_id = static_cast(node->id); auto op_type = OperationTypeFromString(node->operation.type); @@ -199,7 +200,7 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, case OperationType::PAD: { auto attr = absl::any_cast(node->operation.attributes); if (attr.appended.b != 0 || attr.prepended.b != 0) { - return UnimplementedError("Padding for BATCH is not supported."); + return absl::UnimplementedError("Padding for BATCH is not supported."); } *tasks = Padding(node_id, inputs[0], outputs[0], attr); break; @@ -236,7 +237,8 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, case OperationType::SOFTMAX: { auto attr = absl::any_cast(node->operation.attributes); if (attr.axis != Axis::CHANNELS) { - return UnimplementedError("Softmax supports only CHANNELS dimension"); + return absl::UnimplementedError( + "Softmax supports only CHANNELS dimension"); } *tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]); break; @@ -278,15 +280,16 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, case OperationType::SPACE_TO_BATCH: case OperationType::TRANSPOSE: case OperationType::UNKNOWN: - return UnimplementedError("Unsupported op: " + node->operation.type); + return absl::UnimplementedError("Unsupported op: " + + node->operation.type); } - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, - CompiledModel* compiled_model) { +absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, + CompiledModel* compiled_model) { for (const auto& node : graph.nodes()) { std::vector inputs; for (auto& input : graph.FindInputs(node->id)) { @@ -303,11 +306,11 @@ Status Compile(const GraphFloat32& graph, 
const RuntimeOptions& options, auto primary_status = RegisterPrimaryOps(graph, node, inputs, outputs, options, &tasks); if (!primary_status.ok()) { - return UnimplementedError(absl::Substitute( - "Unsupported op type: $0; custom registry error: " - "$1; primary registry error: $2;", - node->operation.type, custom_status.error_message(), - primary_status.error_message())); + return absl::UnimplementedError( + absl::Substitute("Unsupported op type: $0; custom registry error: " + "$1; primary registry error: $2;", + node->operation.type, custom_status.message(), + primary_status.message())); } } for (auto task : tasks) { @@ -315,7 +318,7 @@ Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, } compiled_model->insert(compiled_model->end(), tasks.begin(), tasks.end()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/api.h b/tensorflow/lite/delegates/gpu/metal/api.h index dd3c423a612..c1c7648638c 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.h +++ b/tensorflow/lite/delegates/gpu/metal/api.h @@ -26,8 +26,8 @@ namespace gpu { namespace metal { // Builds CompiledModel out of GraphFloat32 graph using provided RuntimeOptions. -Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, - CompiledModel* compiled_model); +absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, + CompiledModel* compiled_model); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/common.h b/tensorflow/lite/delegates/gpu/metal/common.h index 9d7d66176f6..6f4e94ed2e7 100644 --- a/tensorflow/lite/delegates/gpu/metal/common.h +++ b/tensorflow/lite/delegates/gpu/metal/common.h @@ -39,10 +39,9 @@ id GetBestSupportedMetalDevice(); /// both. /// @discussion The function autoselects the maximum shader language version supported by the target /// OS. FastMath is enabled. -::tflite::gpu::Status CreateComputeProgram(id device, NSString* code, - NSString* functionName, - NSDictionary* macros, - id* program); +absl::Status CreateComputeProgram(id device, NSString* code, NSString* functionName, + NSDictionary* macros, + id* program); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/common.mm b/tensorflow/lite/delegates/gpu/metal/common.mm index 7167430a343..cc5a98dfffc 100644 --- a/tensorflow/lite/delegates/gpu/metal/common.mm +++ b/tensorflow/lite/delegates/gpu/metal/common.mm @@ -34,9 +34,9 @@ namespace metal { id GetBestSupportedMetalDevice() { return MTLCreateSystemDefaultDevice(); } -Status CreateComputeProgram(id device, NSString* code, NSString* functionName, - NSDictionary* macros, - id* program) { +absl::Status CreateComputeProgram(id device, NSString* code, NSString* functionName, + NSDictionary* macros, + id* program) { MTLCompileOptions* options = [[MTLCompileOptions alloc] init]; // Runtime checks for the iOS version independently of minimum target iOS. 
@@ -70,14 +70,14 @@ Status CreateComputeProgram(id device, NSString* code, NSString* func
   if (!library) {
     NSString* errorString =
         [NSString stringWithFormat:@"newLibraryWithSource: %@", [error localizedDescription]];
-    return InternalError([errorString UTF8String]);
+    return absl::InternalError([errorString UTF8String]);
   }
 
   id function = [library newFunctionWithName:functionName];
   if (!function) {
     NSString* errorString =
         [NSString stringWithFormat:@"newFunctionWithName: %@", [error localizedDescription]];
-    return InternalError([errorString UTF8String]);
+    return absl::InternalError([errorString UTF8String]);
   }
 
   *program = [device newComputePipelineStateWithFunction:function error:&error];
@@ -85,9 +85,9 @@ Status CreateComputeProgram(id device, NSString* code, NSString* func
     NSString* errorString =
         [NSString stringWithFormat:@"newComputePipelineStateWithFunction error: %@",
                                    [error localizedDescription]];
-    return InternalError([errorString UTF8String]);
+    return absl::InternalError([errorString UTF8String]);
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace metal
diff --git a/tensorflow/lite/delegates/gpu/metal/common_test.mm b/tensorflow/lite/delegates/gpu/metal/common_test.mm
index 18a495ebd18..7cedac0f799 100644
--- a/tensorflow/lite/delegates/gpu/metal/common_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/common_test.mm
@@ -17,6 +17,7 @@ limitations under the License.
 
 #import 
 
+#include 
 #include 
 #include 
 
@@ -25,7 +26,6 @@ limitations under the License.
 
 using ::tflite::gpu::metal::GetBestSupportedMetalDevice;
 using ::tflite::gpu::metal::CreateComputeProgram;
-using ::tflite::gpu::Status;
 
 @interface CommonTest : XCTestCase
 
@@ -53,16 +53,16 @@ kernel void FunctionName(device TYPE* const src_buffer[[buffer(0)]],
   XCTAssertNotNil(device, @"The Metal device must exists on real device");
   NSString* functionName = @"FunctionName";
   id program;
-  Status status;
+  absl::Status status;
   NSDictionary* macrosFloat4 = @{@"TYPE" : @"float4"};
   status = CreateComputeProgram(device, code, functionName, macrosFloat4, &program);
-  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
   XCTAssertNotNil(program);
 
   NSDictionary* macrosHalf4 = @{@"TYPE" : @"half4"};
   status = CreateComputeProgram(device, code, functionName, macrosHalf4, &program);
-  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
   XCTAssertNotNil(program);
 
   // This compilation is intended to be incorrect
diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc
index 9608aaddeb4..711ed9fed88 100644
--- a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc
+++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc
@@ -558,10 +558,10 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) {
 
 }  // namespace
 
-Status ValidateOptimizeModel(const std::vector& input_buffers,
-                             const std::vector& output_buffers,
-                             const CompiledModel& input_vector,
-                             CompiledModel* output) {
+absl::Status ValidateOptimizeModel(const std::vector& input_buffers,
+                                   const std::vector& output_buffers,
+                                   const CompiledModel& input_vector,
+                                   CompiledModel* output) {
   std::list input;
   input.insert(input.end(), input_vector.begin(), input_vector.end());
   OptimizationInfo info;
@@ -600,10 +600,10 @@ Status ValidateOptimizeModel(const std::vector& input_buffers,
         std::to_string(info.unused_input_buffer_ids.size())
+ "\nMissing output buffers " + std::to_string(info.missing_output_buffer_ids.size()); - return InternalError(message); + return absl::InternalError(message); } for (const auto& chain : sorted_chains) output->push_back(FuseChain(chain)); - return OkStatus(); + return absl::OkStatus(); } } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.h b/tensorflow/lite/delegates/gpu/metal/compiled_model.h index 5f9982d0a66..222534402d9 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model.h +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.h @@ -31,9 +31,10 @@ using CompiledModel = std::vector; // Receives input CompiledModel, validates, optimizes it and returns output // CompiledModel. No shader compilation or memory allocation happen here, this // function just does high-level operations fusion. -Status ValidateOptimizeModel(const std::vector& input_buffers, - const std::vector& output_buffers, - const CompiledModel& input, CompiledModel* output); +absl::Status ValidateOptimizeModel(const std::vector& input_buffers, + const std::vector& output_buffers, + const CompiledModel& input, + CompiledModel* output); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm b/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm index 59827ce2c08..83870123321 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm @@ -183,7 +183,7 @@ static std::vector Add2Linkable(int id, ValueId input_ auto nodes = MulLinkable(1, 1, 2); std::vector model; auto status = ValidateOptimizeModel({1}, {2}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } // Outputs: one missing, one unused. @@ -195,8 +195,8 @@ static std::vector Add2Linkable(int id, ValueId input_ std::vector errorMessages = {"Input operations count 1", "Unused operations 1", "Unused inputs 1", "Missing output buffers 1"}; for (const std::string& message : errorMessages) { - bool doesContainMessage = status.error_message().find(message) != std::string::npos; - XCTAssertTrue(doesContainMessage, @"%s", status.error_message().c_str()); + bool doesContainMessage = std::string(status.message()).find(message) != std::string::npos; + XCTAssertTrue(doesContainMessage, @"%s", std::string(status.message()).c_str()); } } @@ -205,7 +205,7 @@ static std::vector Add2Linkable(int id, ValueId input_ auto nodes = MulLinkable(1, 1, 2); std::vector model; auto status = ValidateOptimizeModel({1}, {2, 3}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } // Unused input => empty graph, missing output. 
@@ -216,8 +216,8 @@ static std::vector Add2Linkable(int id, ValueId input_ std::vector errorMessages = {"Input operations count 1", "Unused operations 0", "Unused inputs 1", "Missing output buffers 1"}; for (const std::string& message : errorMessages) { - bool doesContainMessage = status.error_message().find(message) != std::string::npos; - XCTAssertTrue(doesContainMessage, @"%s", status.error_message().c_str()); + bool doesContainMessage = std::string(status.message()).find(message) != std::string::npos; + XCTAssertTrue(doesContainMessage, @"%s", std::string(status.message()).c_str()); } } @@ -228,7 +228,7 @@ static std::vector Add2Linkable(int id, ValueId input_ nodes.insert(nodes.end(), nodes2.begin(), nodes2.end()); std::vector model; auto status = ValidateOptimizeModel({1}, {3}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } // Two sequential operations. Not fused. @@ -238,14 +238,14 @@ static std::vector Add2Linkable(int id, ValueId input_ nodes.insert(nodes.end(), nodes2.begin(), nodes2.end()); std::vector model; auto status = ValidateOptimizeModel({1}, {3}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testAddOperationSuccess { auto nodes = Add2(1, 1, 2, 3); std::vector model; auto status = ValidateOptimizeModel({1, 2}, {3}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testAddOperationFused { @@ -254,7 +254,7 @@ static std::vector Add2Linkable(int id, ValueId input_ graph.insert(graph.end(), graph2.begin(), graph2.end()); std::vector model; auto status = ValidateOptimizeModel({1, 2}, {4}, graph, &model); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); XCTAssertTrue(model.size() == 1, @"Not fused, more than one task descriptor."); } @@ -266,7 +266,7 @@ static std::vector Add2Linkable(int id, ValueId input_ graph.insert(graph.end(), graph3.begin(), graph3.end()); std::vector model; auto status = ValidateOptimizeModel({1, 2}, {5}, graph, &model); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.h b/tensorflow/lite/delegates/gpu/metal/compute_task.h index 611185b8fc1..b03a8436077 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.h +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.h @@ -31,12 +31,12 @@ limitations under the License. @interface TFLComputeTask : NSObject /// Returns empty string or error if shader can't be compiled. 
-- (::tflite::gpu::Status)compileWithDevice:(id)device - taskDescriptor:(::tflite::gpu::metal::ComputeTaskDescriptorPtr)desc - runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; +- (absl::Status)compileWithDevice:(id)device + taskDescriptor:(::tflite::gpu::metal::ComputeTaskDescriptorPtr)desc + runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; /// Updates dimensions for inputs/outputs/intermediate tensors -- (::tflite::gpu::Status) +- (absl::Status) setInputDimensionsWithDevice:(id)device dimensions:(std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions; @@ -50,12 +50,11 @@ limitations under the License. /// @param sharedBufferIds contain shared buffer id for each tensor usage record id. /// @param sharedBuffers contain metal handles to the allocated buffers for each shared buffer id. /// TODO(ypisarchyk): probably we can decrease the number of parameters here -- (::tflite::gpu::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers - outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds - usageRecordIds: - (const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds - sharedBufferIds:(const std::vector&)sharedBufferIds - sharedBuffers:(const std::vector>&)sharedBuffers; +- (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers + outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds + usageRecordIds:(const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds + sharedBufferIds:(const std::vector&)sharedBufferIds + sharedBuffers:(const std::vector>&)sharedBuffers; - (void)encodeWithEncoder:(id)encoder inputOutputBuffers: diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.mm b/tensorflow/lite/delegates/gpu/metal/compute_task.mm index 24b89c1b11c..d3e3466ca6f 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.mm +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.mm @@ -29,8 +29,6 @@ limitations under the License. 
using ::tflite::gpu::AlignByN; using ::tflite::gpu::BHWC; -using ::tflite::gpu::InternalError; -using ::tflite::gpu::InvalidArgumentError; using ::tflite::gpu::HalfBits; using ::tflite::gpu::metal::ComputeTaskDescriptorPtr; using ::tflite::gpu::metal::CreateComputeProgram; @@ -38,8 +36,6 @@ using ::tflite::gpu::metal::DispatchParamsFunction; using ::tflite::gpu::metal::OutputDimensions; using ::tflite::gpu::metal::RuntimeOptions; using ::tflite::gpu::metal::UniformsFunction; -using ::tflite::gpu::OkStatus; -using ::tflite::gpu::Status; using ::tflite::gpu::uint3; using ::tflite::gpu::ValueId; @@ -70,9 +66,9 @@ using ::tflite::gpu::ValueId; std::string _description; } -- (Status)compileWithDevice:(id)device - taskDescriptor:(ComputeTaskDescriptorPtr)desc - runtimeOptions:(const RuntimeOptions&)options { +- (absl::Status)compileWithDevice:(id)device + taskDescriptor:(ComputeTaskDescriptorPtr)desc + runtimeOptions:(const RuntimeOptions&)options { NSString* barrier; // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0 if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) { @@ -123,7 +119,7 @@ using ::tflite::gpu::ValueId; id program; RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program)); if (!program) { - return InternalError("Unknown shader compilation error"); + return absl::InternalError("Unknown shader compilation error"); } for (auto& buffer : desc->input_buffers) { _inputBuffers.emplace_back(InputBuffer{buffer.id, nil}); @@ -148,12 +144,13 @@ using ::tflite::gpu::ValueId; _resizeFunction = desc->resize_function; _program = program; _description = desc->description; - return OkStatus(); + return absl::OkStatus(); } -- (Status)setInputDimensionsWithDevice:(id)device - dimensions: - (std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions { +- (absl::Status)setInputDimensionsWithDevice:(id)device + dimensions: + (std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*) + dimensions { // Re-calculate output buffers dimensions for (auto& buffer : _outputBuffers) { auto outputDimensions = buffer.dimensionsFunction(*dimensions); @@ -180,23 +177,23 @@ using ::tflite::gpu::ValueId; error += "is larger than the MTLDevice can support: "; error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) + ", " + std::to_string(threadsPerGroup.depth); - return InvalidArgumentError(error); + return absl::InvalidArgumentError(error); } _groupsCount = workGroups.second; - return OkStatus(); + return absl::OkStatus(); } -- (Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers - outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds - usageRecordIds:(const std::map&)usageRecordIds - sharedBufferIds:(const std::vector&)sharedBufferIds - sharedBuffers:(const std::vector>&)sharedBuffers { +- (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers + outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds + usageRecordIds:(const std::map&)usageRecordIds + sharedBufferIds:(const std::vector&)sharedBufferIds + sharedBuffers:(const std::vector>&)sharedBuffers { for (auto& buffer : _outputBuffers) { // If the buffer is intermediate: set its metalHandle from sharedBuffers if (std::find(outputIds.begin(), outputIds.end(), buffer.uid) == outputIds.end()) { auto usageRecordIt = usageRecordIds.find(buffer.uid); if (usageRecordIt == usageRecordIds.end()) { - return InternalError("TensorUsageRecord for intermediate tensor is not found."); + return 
absl::InternalError("TensorUsageRecord for intermediate tensor is not found."); } buffer.metalHandle = sharedBuffers.at(sharedBufferIds.at(usageRecordIt->second)); (*buffers)[buffer.uid] = buffer.metalHandle; @@ -207,7 +204,7 @@ using ::tflite::gpu::ValueId; for (auto& buffer : _inputBuffers) { buffer.metalHandle = (*buffers)[buffer.uid]; } - return OkStatus(); + return absl::OkStatus(); } - (void)encodeWithEncoder:(id)encoder diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h index 8569a4ed009..97a6f3b3b18 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.h +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h @@ -50,12 +50,12 @@ limitations under the License. /// @return Status signals whether model is compiled successfully or not. /// @discussion Previously added operations are distilled into sorted list of sets of /// ComputeTaskDescriptors, which can be fused into a single GPU task. -- (::tflite::gpu::Status) - compileModelWithDevice:(id)device - taskDescriptors: - (const std::vector<::tflite::gpu::metal::ComputeTaskDescriptorPtr>&)taskDescriptors - outputBufferIDs:(const std::vector<::tflite::gpu::ValueId>&)outputBufferIDs - runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; +- (absl::Status)compileModelWithDevice:(id)device + taskDescriptors: + (const std::vector<::tflite::gpu::metal::ComputeTaskDescriptorPtr>&) + taskDescriptors + outputBufferIDs:(const std::vector<::tflite::gpu::ValueId>&)outputBufferIDs + runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; /// Creates intermediate buffers. The model is ready to be used after this call. /// @param inputDimensions Used to create resources: shaders, buffers. @@ -63,7 +63,7 @@ limitations under the License. /// @return Status signals whether intermediate buffers are successfully created or not. /// @discussion The operation is intended to be lightweight with minimum overhead. A preceding call /// compileModelWithDevice() must be made with the proper device parameter set. -- (::tflite::gpu::Status) +- (absl::Status) setInputDimensions:(const std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>&)inputDimensions outputDimensions:(std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)outputDimensions taskDescriptors: diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.mm b/tensorflow/lite/delegates/gpu/metal/inference_context.mm index fb3a51f4694..d5589ae8ab4 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.mm +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.mm @@ -32,9 +32,6 @@ limitations under the License. 
using ::tflite::gpu::BHWC; using ::tflite::gpu::metal::ComputeTaskDescriptorPtr; using ::tflite::gpu::metal::RuntimeOptions; -using ::tflite::gpu::InternalError; -using ::tflite::gpu::OkStatus; -using ::tflite::gpu::Status; using ::tflite::gpu::ValueId; using ::tflite::gpu::AlignByN; using ::tflite::gpu::HalfBits; @@ -48,10 +45,10 @@ using ::tflite::gpu::TensorUsageRecord; RuntimeOptions _options; } -- (Status)compileModelWithDevice:(id)device - taskDescriptors:(const std::vector&)taskDescriptors - outputBufferIDs:(const std::vector&)requestedOutputBufferIDs - runtimeOptions:(const RuntimeOptions&)options { +- (absl::Status)compileModelWithDevice:(id)device + taskDescriptors:(const std::vector&)taskDescriptors + outputBufferIDs:(const std::vector&)requestedOutputBufferIDs + runtimeOptions:(const RuntimeOptions&)options { _device = device; _outputIds = requestedOutputBufferIDs; _options = options; @@ -61,12 +58,12 @@ using ::tflite::gpu::TensorUsageRecord; RETURN_IF_ERROR([task compileWithDevice:_device taskDescriptor:node runtimeOptions:_options]); _computeTasks.emplace_back(task); } - return OkStatus(); + return absl::OkStatus(); } -- (Status)setInputDimensions:(const std::map&)inputDimensions - outputDimensions:(std::map*)outputDimensions - taskDescriptors:(const std::vector&)taskDescriptors { +- (absl::Status)setInputDimensions:(const std::map&)inputDimensions + outputDimensions:(std::map*)outputDimensions + taskDescriptors:(const std::vector&)taskDescriptors { // These maps contain all input/output/intermediate buffers shared across model. std::map dimensions = inputDimensions; std::map> buffers; @@ -97,7 +94,7 @@ using ::tflite::gpu::TensorUsageRecord; if (!usageRecordIds.count(outputId)) { const auto it = dimensions.find(outputId); if (it == dimensions.end()) { - return InternalError("Dimensions for intermediate tensor not found."); + return absl::InternalError("Dimensions for intermediate tensor not found."); } usageRecordIds[outputId] = usageRecords.size(); usageRecords.emplace_back(it->second.w * it->second.h * AlignByN(it->second.c, 4), i, i); @@ -133,14 +130,14 @@ using ::tflite::gpu::TensorUsageRecord; error += std::to_string(assignment.object_ids[i]) + " with size: " + std::to_string(bufferSize) + " exceeds MTLDevice maxBufferLength: " + std::to_string([_device maxBufferLength]); - return ::tflite::gpu::ResourceExhaustedError(error); + return absl::ResourceExhaustedError(error); } #endif #if defined(__MAC_10_12) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_12 if ([_device currentAllocatedSize] + bufferSize > [_device recommendedMaxWorkingSetSize]) { std::string error("Out of memory in MTLBuffer allocation. Currently allocated: "); error += std::to_string([_device currentAllocatedSize]); - return ::tflite::gpu::ResourceExhaustedError(error); + return absl::ResourceExhaustedError(error); } #endif @@ -154,7 +151,7 @@ using ::tflite::gpu::TensorUsageRecord; sharedBufferIds:assignment.object_ids sharedBuffers:sharedBuffers]); } - return OkStatus(); + return absl::OkStatus(); } - (void)encodeWithEncoder:(id)commandEncoder diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context_test.mm b/tensorflow/lite/delegates/gpu/metal/inference_context_test.mm index 14ea40c68b4..4d9e54a0ca0 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/inference_context_test.mm @@ -17,6 +17,8 @@ limitations under the License. 
#import +#include + #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" @@ -170,9 +172,9 @@ static std::vector MulArrayLinkable( std::map inputs{{inputBufferID, input}}; std::map outputs{{outputBufferID, {}}}; auto status = RunGraph(graph, _device, inputs, &outputs); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2.2f, 3.3f, 4.4f}, outputs[outputBufferID].data, 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testImmutableShaderOutput { @@ -187,9 +189,9 @@ static std::vector MulArrayLinkable( std::map inputs{{inputBufferID, input}}; std::map outputs{{outputBufferID, {}}}; auto status = RunGraph(graph, _device, inputs, &outputs); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 4, 9, 16, 25, 36, 49}, outputs[outputBufferID].data, 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testUniformShaderOutput { @@ -203,9 +205,9 @@ static std::vector MulArrayLinkable( std::map inputs{{inputBufferID, input}}; std::map outputs{{outputBufferID, {}}}; auto status = RunGraph(graph, _device, inputs, &outputs); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 4, 6}, outputs[outputBufferID].data, 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testUniformAndImmutableShaderOutput { @@ -222,9 +224,9 @@ static std::vector MulArrayLinkable( std::map inputs{{inputBufferID, input}}; std::map outputs{{outputBufferID, {}}}; auto status = RunGraph(graph, _device, inputs, &outputs); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 6, 12, 20, 26, 38, 52}, outputs[outputBufferID].data, 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm index 10481b2a867..540308f23b4 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -65,9 +66,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8})); XCTAssertTrue(model.PopulateTensor(1, {0.1, 0.2, 0.3, 0.5})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({-1.9, 0.4, 1.0, 1.3}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testInputTensorAndScalar { @@ -85,9 +86,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testInputTensorWithConstantBroadcast { @@ -112,10 +113,10 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({11.0, 22.0, 13.0, 24.0, 15.0, 26.0, 17.0, 28.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm index b67c1ca839c..195a2986628 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -66,9 +67,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 3, 5, 7})); XCTAssertTrue(model.PopulateTensor(1, {2, 4, 6, 8})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6, 7, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTwoInputTensorsByAlignedChannel { @@ -92,9 +93,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(1, {5, 6, 7, 8})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6, 7, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTwoInputTensorsByHeight { @@ -118,9 +119,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 2})); XCTAssertTrue(model.PopulateTensor(1, {3, 4, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTwoInputTensorsByWidth { @@ -144,8 +145,8 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 4})); XCTAssertTrue(model.PopulateTensor(1, {2, 3, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm index 8f1b24a4735..a74b22cf13e 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -82,9 +83,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({4, 8, 4, 8, 2, 4, 2, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testO1H2W2I1Stride1x1Dilation2x2 { @@ -120,9 +121,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1, 1, 1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({10}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testO1H3W3I1Stride1x1Dilation1x1 { @@ -158,9 +159,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({11}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testO2H1W1I2Stride1x1Dilation1x1 { @@ -196,9 +197,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({4, 8, 4, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testO1H1W1I1Stride2x2Dilation1x1 { @@ -235,9 +236,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 0, 2, 0, 0, 0, 4, 0, 8})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 4, 8, 16}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc index 228583c6e30..620a4581c52 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc @@ -26,12 +26,12 @@ namespace tflite { namespace gpu { namespace metal { -Status RegisterCustomOps(const GraphFloat32& graph, const Node* node, - const std::vector& inputs, - const std::vector& outputs, - const RuntimeOptions& options, - std::vector* tasks) { - return UnimplementedError("Unsupported op: " + node->operation.type); +absl::Status 
RegisterCustomOps(const GraphFloat32& graph, const Node* node, + const std::vector& inputs, + const std::vector& outputs, + const RuntimeOptions& options, + std::vector* tasks) { + return absl::UnimplementedError("Unsupported op: " + node->operation.type); } } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h index bef2ba20def..eee1632a644 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h @@ -28,11 +28,11 @@ namespace gpu { namespace metal { // Registers custom operations. -Status RegisterCustomOps(const GraphFloat32& graph, const Node* node, - const std::vector& inputs, - const std::vector& outputs, - const RuntimeOptions& options, - std::vector* tasks); +absl::Status RegisterCustomOps(const GraphFloat32& graph, const Node* node, + const std::vector& inputs, + const std::vector& outputs, + const RuntimeOptions& options, + std::vector* tasks); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm index 5f262238464..d76507253a9 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm @@ -17,6 +17,7 @@ limitations under the License. #import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -83,9 +84,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 3})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 4, 12, 16}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testO2H1W1I1Strides2x2Dilation1x1 { @@ -122,9 +123,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 0, 1, 1, 0, 1, 1, 0, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 3, 1, 3, 1, 3, 1, 3}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testO2H2W2I1Strides1x1Dilation2x2 { @@ -161,9 +162,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 0, 1, 1, 0, 1, 1, 0, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({10, 26}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm index 4baa4573909..d8521ba76b1 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm +++ 
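RegisterCustomOps now reports unsupported operations through the canonical UNIMPLEMENTED code (absl::UnimplementedError) rather than a project-specific error type, so a caller can branch on that code to fall back to the built-in kernels. The sketch below is a hypothetical caller-side policy, not code from the patch; absl::IsUnimplemented is the real Abseil predicate for this check.

#include "absl/status/status.h"

// Hypothetical fallback policy: treat UNIMPLEMENTED from the custom registry as
// "this is not a custom op, keep going", and propagate every other error as-is.
absl::Status HandleCustomOpResult(const absl::Status& registry_status) {
  if (registry_status.ok() || absl::IsUnimplemented(registry_status)) {
    return absl::OkStatus();
  }
  return registry_status;
}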
b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm @@ -17,6 +17,7 @@ limitations under the License. #import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -59,9 +60,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, 6.2, 2.0, 4.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testCos { @@ -72,9 +73,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, -1.0, -1.0, 0.540302}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testDiv { @@ -86,9 +87,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, -3.1, -4.0, 1.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testExp { @@ -99,11 +100,11 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0f, 1.0f, -1.0f, 100.0f, -100.0f, 0.01f, -0.01f})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({std::exp(0.0f), std::exp(1.0f), std::exp(-1.0f), std::exp(100.0f), std::exp(-100.0f), std::exp(0.01f), std::exp(-0.01f)}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testHardSwish { @@ -114,10 +115,10 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {-4.5f, -3.0f, -1.5f, 0.0f, 1.5f, 3.0f, 4.5f})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0f, 0.0f, -0.375f, 0.0f, 1.125f, 3.f, 4.5f}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testLog { @@ -128,9 +129,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 3.1415926, 1.0, 1.0})); 
auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, 1.14473, 0.0, 0.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testMaximum { @@ -142,9 +143,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 2.0, 3.0, -2.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testMaximumWithScalar { @@ -157,9 +158,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, -1.0, 2.0, -1.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testMinimum { @@ -171,9 +172,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, -6.2, 2.0, -3.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testMinimumWithScalar { @@ -186,9 +187,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({-1.0, -6.2, -1.0, -3.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testPow { @@ -200,9 +201,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, 1.0, 8.0, 256.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testRsqrt { @@ -213,9 +214,9 @@ TensorRef GetTensorRef(int ref, const BHWC& 
shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 0.707106, 0.5, 0.333333}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSigmoid { @@ -226,9 +227,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.5, 0.002473, 0.880797, 0.982014}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSin { @@ -239,9 +240,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, 0.0, 0.0, 0.841471}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSqrt { @@ -252,9 +253,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, 1.0, 1.414213, 2.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSquare { @@ -265,9 +266,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 4.0, 0.25, 9.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSquaredDiff { @@ -279,9 +280,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, 2.0, 2.0, 4.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 1.0, 5.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 1.0, 9.0, 0.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSub { 
@@ -293,9 +294,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({-1.0, -8.2, -1.0, 0.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTanh { @@ -306,9 +307,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, -0.999987, 0.964027, 0.999329}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm index 6d3a3e697b8..e57f9aa84e2 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm @@ -17,6 +17,7 @@ limitations under the License. #import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -75,9 +76,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::FULLY_CONNECTED), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({6, 13, 20, 27}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm index cacd501f0bd..cf4aacf724f 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -72,10 +73,10 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(1, {0, 0, 0, 0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm index 69eed7d86b0..67325c1adb7 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm @@ -17,6 +17,7 @@ limitations under the License. #import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -62,9 +63,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::MEAN), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2.5}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm index 279fd1e4fea..f69598bad5b 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -63,9 +64,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 4, 6, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testMulLinear { @@ -89,9 +90,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 6, 6, 12}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @@ -115,9 +116,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(1, {2, 3})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 4, 9, 12}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testApplyMaskEqualsToInputChannel { @@ -140,9 +141,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm index 22fa11a89fb..9c55cfc45b0 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -78,9 +79,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PAD), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors(expected, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)runPrepending:(const HWC&)prepend @@ -164,9 +165,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PAD), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testMirrorPadChannelsOperation { @@ -188,9 +189,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PAD), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm index f79d53c7bd3..d2d95b30af2 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -73,11 +74,11 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output, indices}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 1, 2, 3, 4, 3, 4, 7, 8, 7, 8, 5, 6, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({4, 4, 8, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({3, 3, 1, 1}, model.GetOutput(1), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testPoolingMaxKernel2x2Stride2x2WithoutIndices { @@ -101,9 +102,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 1, 2, 3, 4, 3, 4, 7, 8, 7, 8, 5, 6, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({4, 4, 8, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testPoolingAverageKernel2x2Stride2x2 { @@ -127,9 +128,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm index b805ed81c76..1df08be61db 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -69,9 +70,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PRELU), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {-1.0, -2.0, 1.0, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({-2, -4, 1, 2}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testPReluLinearAlphaWithClip { @@ -96,9 +97,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PRELU), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {-1.0, -2.0, 1.0, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({-2, -4, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testPRelu3DAlphaNoClip { @@ -124,9 +125,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(op_type), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -1.0, 2.0, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0, -2, 2, -6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testPRelu3DAlphaWithClip { @@ -152,9 +153,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(op_type), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -1.0, 2.0, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0, -2, 1, -6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm index 3687c0ecd65..52de77e0ee4 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -60,9 +61,9 @@ TensorRef GetTensorRef(int ref) { SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, {GetTensorRef(1)}); XCTAssertTrue(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, 0.0, 2.0, 8.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testReluClipOnly { @@ -73,9 +74,9 @@ TensorRef GetTensorRef(int ref) { SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, {GetTensorRef(1)}); XCTAssertTrue(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0, 0.0, 2.0, 6.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testReluAlphaOnly { @@ -86,9 +87,9 @@ TensorRef GetTensorRef(int ref) { SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, {GetTensorRef(1)}); XCTAssertTrue(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({-3.0, 0.0, 2.0, 8.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testReluClipAndAlpha { @@ -99,9 +100,9 @@ TensorRef GetTensorRef(int ref) { SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, {GetTensorRef(1)}); XCTAssertTrue(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({-3.0, 0.0, 2.0, 6.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm index 48d292e2a1b..684e83b2db1 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm @@ -62,9 +62,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESHAPE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testReshape3x1x2To2x1x3 { @@ -84,9 +84,9 @@ using 
::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESHAPE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testReshape1x1x4To2x2x1 { @@ -106,9 +106,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESHAPE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testReshapeBatchIsUnsupported { @@ -128,9 +128,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESHAPE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.error_message().find("Only identical batch dimension is supported") != + XCTAssertTrue(std::string(status.message()).find("Only identical batch dimension is supported") != std::string::npos, - @"%s", status.error_message().c_str()); + @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm index 49febc1d4c6..f00b2766bdc 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm @@ -17,6 +17,7 @@ limitations under the License. 
#import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -65,9 +66,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testResizeBilinear1x2x1To1x4x1 { @@ -89,9 +90,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 2.5, 4.0, 4.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testResizeBilinear2x2x1To4x4x1 { @@ -113,11 +114,11 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 4.0, 6.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors( {1.0, 2.5, 4.0, 4.0, 3.5, 4.75, 6.0, 6.0, 6.0, 7.0, 8.0, 8.0, 6.0, 7.0, 8.0, 8.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testResizeBilinear2x2x1To3x3x1WithoutHalfPixel { @@ -140,10 +141,10 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 1.666666, 2.0, 2.333333, 3.0, 3.333333, 3.0, 3.666666, 4.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testResizeBilinear2x2x1To3x3x1WithHalfPixel { @@ -166,9 +167,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 1.5, 2.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testResizeNearest1x2x1To2x4x1 { @@ -190,9 +191,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel 
model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm index 827f85fe00a..e0c29561f9b 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm @@ -17,6 +17,7 @@ limitations under the License. #import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -64,9 +65,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 2, 3, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSliceNoStrides { @@ -88,9 +89,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 3}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSliceNoStridesStartOffset { @@ -112,9 +113,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({3, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSliceStridesByHeight { @@ -136,9 +137,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 3}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSliceStridesByWidth { @@ -160,9 +161,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, 
{input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSliceStridesByChannels { @@ -184,9 +185,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm index f5c4770bd8b..9196e9fe094 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm @@ -17,6 +17,7 @@ limitations under the License. #import +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -62,9 +63,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.1, 0.2, 0.1, 0.2})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 1, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSoftmaxDoesNotWorkForHeightAxis { @@ -84,7 +85,7 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4})); auto status = model.Invoke(); - XCTAssertFalse(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertFalse(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSoftmaxDoesNotWorkForWidthAxis { @@ -104,7 +105,7 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4})); auto status = model.Invoke(); - XCTAssertFalse(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertFalse(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testSoftmax1x1 { @@ -126,11 +127,11 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.1f, 0.2f, 0.3f, 0.4f})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors( {std::exp(0.1f) / sum, std::exp(0.2f) / sum, std::exp(0.3f) / sum, std::exp(0.4f) / sum}, model.GetOutput(0), 1e-6f); - 
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm index 6e82ebe0361..17e398817b2 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm @@ -51,7 +51,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTFail(@"PopulateTensor()"); } const auto status = model.Invoke(); - if (!status.ok()) XCTFail(@"%s", status.error_message().c_str()); + if (!status.ok()) XCTFail(@"%s", std::string(status.message()).c_str()); const std::vector& actual = model.GetOutput(0); const std::vector expected = {1.0f, 2.0f, 3.0f, 4.0f}; XCTAssertEqual(actual[0], expected[0]); @@ -69,7 +69,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTFail(@"PopulateTensor()"); } const auto status = model.Invoke(); - if (!status.ok()) XCTFail(@"%s", status.error_message().c_str()); + if (!status.ok()) XCTFail(@"%s", std::string(status.message()).c_str()); const std::vector& actual = model.GetOutput(0); const std::vector expected = {1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f}; XCTAssertEqual(actual[0], expected[0]); @@ -94,7 +94,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTFail(@"PopulateTensor()"); } const auto status = model.Invoke(); - if (!status.ok()) XCTFail(@"%s", status.error_message().c_str()); + if (!status.ok()) XCTFail(@"%s", std::string(status.message()).c_str()); const std::vector& actual = model.GetOutput(0); const std::vector expected = {1.0f, 2.0f, 3.0f, // 4.0f, 5.0f, 6.0f, // @@ -126,7 +126,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTFail(@"PopulateTensor()"); } const auto status = model.Invoke(); - if (!status.ok()) XCTFail(@"%s", status.error_message().c_str()); + if (!status.ok()) XCTFail(@"%s", std::string(status.message()).c_str()); const std::vector& actual = model.GetOutput(0); const std::vector expected = {1.0f, 2.0f, 3.0f, 4.0f, // 5.0f, 6.0f, 7.0f, 8.0f, // diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h index 7a4066fea0a..ffa567a5a9d 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h @@ -45,7 +45,7 @@ class SingleOpModel { return true; } - Status Invoke(); + absl::Status Invoke(); const std::vector& GetOutput(int index) const { return outputs_[index].data; @@ -57,16 +57,16 @@ class SingleOpModel { std::vector outputs_; }; -Status CompareVectors(const std::vector& reference, - const std::vector& output, float max_error); +absl::Status CompareVectors(const std::vector& reference, + const std::vector& output, float max_error); /// Helper function that compiles previously configured graph (with added /// tasks), initializes graph with specified inputs, invokes and fills specified /// outputs -Status RunGraph(const std::vector& graph, - id device, - const std::map& inputs, - std::map* outputs); +absl::Status RunGraph(const std::vector& graph, + id device, + const std::map& inputs, + std::map* outputs); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm index 3edc8669f2c..80c0e2457af 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm +++ 
b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm @@ -65,7 +65,7 @@ SingleOpModel::SingleOpModel(Operation&& operation, const std::vector input_ids; input_ids.reserve(inputs_.size()); for (const auto& input : inputs_) { @@ -143,16 +143,16 @@ Status SingleOpModel::Invoke() { RETURN_IF_ERROR(ConvertFromPHWC4(absl::MakeConstSpan(output_pointer, elements_count), output.shape, absl::MakeSpan(output.data))); } - return OkStatus(); + return absl::OkStatus(); } -Status CompareVectors(const std::vector& reference, const std::vector& output, - float max_error) { +absl::Status CompareVectors(const std::vector& reference, const std::vector& output, + float max_error) { if (reference.size() != output.size()) { const std::string message = "CompareVectors: vectors size does not match for reference: " + std::to_string(reference.size()) + " vs. output: " + std::to_string(output.size()); - return tflite::gpu::InternalError(message); + return absl::InternalError(message); } for (int i = 0; i < reference.size(); i++) { float error = std::abs(reference[i] - output[i]); @@ -160,15 +160,15 @@ Status CompareVectors(const std::vector& reference, const std::vector& nodes, id device, - const std::map& inputs, - std::map* outputs) { +absl::Status RunGraph(const std::vector& nodes, id device, + const std::map& inputs, + std::map* outputs) { std::vector inputBufferIDs; inputBufferIDs.reserve(inputs.size()); for (const auto& input : inputs) { @@ -251,7 +251,7 @@ Status RunGraph(const std::vector& nodes, id +#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -81,10 +82,10 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 4, 2, 4, 1, 1, 4, 8, 4, 8, 1, 1, 3, 5, 3, 5, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTransposeConvO1H2W2I1Stride1x1Adjacent2x2 { @@ -120,11 +121,11 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1, 1, 1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({1, 3, 3, 2, 0, 0, 4, 10, 10, 6, 0, 0, 4, 10, 10, 6, 0, 0, 3, 7, 7, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTransposeConvO1H3W3I1Stride1x1Adjacent1x1 { @@ -160,10 +161,10 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({7, 11, 7, 1, 7, 11, 7, 1, 4, 6, 4, 1, 1, 1, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTransposeConvO2H1W1I2Stride1x1Dilation1x1 { @@ -199,9 +200,9 @@ using 
::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({4, 8, 1, 1, 4, 8, 1, 1, 1, 1, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTransposeConvO1H1W1I1Stride2x2Dilation1x1 { @@ -238,11 +239,11 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 0, 2, 0, 0, 0, 4, 0, 8})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - (void)testTransposeConv4x4 { @@ -277,13 +278,13 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {0.0f, 1.0f, 2.0f, 3.0f})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); status = CompareVectors({0.0f, 0.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 2.0f, 4.0f, 6.0f, 12.0f, 6.0f, 12.0f, 4.0f, 8.0f, 2.0f, 4.0f, 6.0f, 12.0f, 6.0f, 12.0f, 4.0f, 8.0f, 2.0f, 4.0f, 5.0f, 10.0f, 5.0f, 10.0f, 3.0f, 6.0f}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm index f7f08b273ae..4c6bb140a96 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.mm +++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm @@ -198,13 +198,13 @@ class Delegate { } } - Status BindBufferToTensor(id buffer, int tensor_index) { + absl::Status BindBufferToTensor(id buffer, int tensor_index) { for (auto& input : graph_inputs_) { if (input.tensor_id == tensor_index) { input_output_buffers_[input.id] = buffer; bphwc4_buffers_[input.id] = buffer; input.set_externally = true; - return OkStatus(); + return absl::OkStatus(); } } for (auto& output : graph_outputs_) { @@ -212,10 +212,10 @@ class Delegate { input_output_buffers_[output.id] = buffer; bphwc4_buffers_[output.id] = buffer; output.set_externally = true; - return OkStatus(); + return absl::OkStatus(); } } - return NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index)); + return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index)); } void SetCommandEncoder( @@ -225,7 +225,7 @@ class Delegate { external_command_encoder_ = encoder; } - Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { + absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { // Extract TFLite delegate execution plan from the context and convert it into FlowGraph32. 
GraphFloat32 graph; RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph)); @@ -234,7 +234,7 @@ class Delegate { NullTransformationReporter reporter; ModelTransformer transformer(&graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return InternalError("Graph general transformations failed"); + return absl::InternalError("Graph general transformations failed"); } // TODO(impjdi): Remove code duplication. @@ -265,7 +265,7 @@ class Delegate { if (tensor->allocation_type == TfLiteAllocationType::kTfLiteMmapRo) continue; const auto* input = find_value(tensor_index); if (!input || tensor->type != TfLiteType::kTfLiteFloat32) { - return NotFoundError("Input tensor is not found in the graph."); + return absl::NotFoundError("Input tensor is not found in the graph."); } inputs_.push_back(input->id); @@ -283,7 +283,7 @@ class Delegate { auto* tensor = context->tensors + tensor_index; const auto* output = find_value(tensor_index); if (!output || tensor->type != TfLiteType::kTfLiteFloat32) { - return NotFoundError("Output tensor is not found in the graph."); + return absl::NotFoundError("Output tensor is not found in the graph."); } outputs_.push_back(output->id); @@ -323,7 +323,9 @@ class Delegate { const auto& input_tensor = tensors_[input]; const auto tensor_id = input_tensor.tensor_id; input_ids.push_back(input); - if (input_tensor.shape.b != 1) return UnimplementedError("Batching is not supported yet."); + if (input_tensor.shape.b != 1) { + return absl::UnimplementedError("Batching is not supported yet."); + } input_dimensions[input] = input_tensor.shape; graph_inputs_.push_back({ input, // .id @@ -346,7 +348,7 @@ class Delegate { isFloat16:options_.allow_precision_loss convertToPBHWC4:true]; if (converter_to_BPHWC4_ == nil) { - return InternalError("Error initialization of input buffer converter"); + return absl::InternalError("Error initialization of input buffer converter"); } } } else { @@ -383,7 +385,7 @@ class Delegate { isFloat16:options_.allow_precision_loss convertToPBHWC4:false]; if (converter_from_BPHWC4_ == nil) { - return InternalError("Error initialization of output buffer converter"); + return absl::InternalError("Error initialization of output buffer converter"); } } } else { @@ -406,10 +408,10 @@ class Delegate { RETURN_IF_ERROR([inference_context_ setInputDimensions:input_dimensions outputDimensions:&output_dimensions taskDescriptors:optimized_model]); - return OkStatus(); + return absl::OkStatus(); } - Status Invoke(TfLiteContext* context) { + absl::Status Invoke(TfLiteContext* context) { if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) gpu_alarm_clock_->Stop(); // We need only synchronization so volatile works better than atomic which reads from global @@ -514,11 +516,11 @@ class Delegate { // External command encoder is assigned so all output buffers are controlled by a user. for (const auto& output : graph_outputs_) { if (!output.set_externally) { - return InternalError( + return absl::InternalError( "External command encoder is used, but not all output buffers are bound."); } } - return OkStatus(); + return absl::OkStatus(); } // Retrieve data from GPU and convert from PHWC4 to HWC. 
@@ -529,7 +531,7 @@ class Delegate { const void* gpu_ptr = [input_output_buffers_[output.id] contents]; std::memcpy(tensor->data.f, gpu_ptr, output.shape.DimensionsProduct() * sizeof(float)); } - return OkStatus(); + return absl::OkStatus(); } TfLiteDelegate* tflite_delegate() { return &delegate_; } @@ -596,7 +598,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = metal_delegate->Prepare(context, params); if (status.ok()) return metal_delegate; context->ReportError(context, "TfLiteGpuDelegate Prepare: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return nullptr; }, // .free @@ -610,7 +612,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = GetMetalDelegate(node)->Invoke(context); if (status.ok()) return kTfLiteOk; context->ReportError(context, "TfLiteMetalDelegate Invoke: %s", - status.error_message().c_str()); + std::string(status.message()).c_str()); return kTfLiteError; }, nullptr, // .profiling_string diff --git a/tensorflow/lite/delegates/gpu/spi.h b/tensorflow/lite/delegates/gpu/spi.h index c7f041f3db1..a70f8dbb326 100644 --- a/tensorflow/lite/delegates/gpu/spi.h +++ b/tensorflow/lite/delegates/gpu/spi.h @@ -33,8 +33,8 @@ class TensorObjectConverter { public: virtual ~TensorObjectConverter() = default; - virtual Status Convert(const TensorObject& input, - const TensorObject& output) = 0; + virtual absl::Status Convert(const TensorObject& input, + const TensorObject& output) = 0; }; class TensorObjectConverterBuilder { @@ -44,7 +44,7 @@ class TensorObjectConverterBuilder { virtual bool IsSupported(const TensorObjectDef& input, const TensorObjectDef& output) const = 0; - virtual Status MakeConverter( + virtual absl::Status MakeConverter( const TensorObjectDef& input, const TensorObjectDef& output, std::unique_ptr* converter) = 0; }; @@ -66,13 +66,13 @@ class TensorTie { virtual ~TensorTie() = default; - virtual Status SetExternalObject(TensorObject obj) = 0; + virtual absl::Status SetExternalObject(TensorObject obj) = 0; virtual TensorObject GetExternalObject() = 0; - virtual Status CopyToExternalObject() = 0; + virtual absl::Status CopyToExternalObject() = 0; - virtual Status CopyFromExternalObject() = 0; + virtual absl::Status CopyFromExternalObject() = 0; const TensorTieDef& def() const { return def_; } From 8c84097861cf3d003aaed8457e296213898aeba6 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Mon, 23 Mar 2020 13:32:25 -0700 Subject: [PATCH 439/492] Add tests for TPUStrategy for compilation corner cases. 
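The new tests cover TPU compilation corner cases: calling a tf.function across devices, returning int32 from a TPU-placed function, control outputs inside a while-loop body, launching a TPU cluster from inside a graph while loop, and reusing one function across two clusters. A rough sketch of the looped-cluster pattern the tests exercise is below; the TPU resolver setup is assumed for illustration (the empty address is a placeholder, not part of this change):

    import tensorflow as tf

    # Hypothetical TPU setup; replace the empty address with a real TPU target.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    @tf.function
    def train_step():
      def step_fn(prev):
        return prev + 1.0
      prev = strategy.run(lambda: tf.zeros(shape=()))
      # The range loop becomes a tf.while_loop whose body launches a TPU cluster.
      for _ in tf.range(10):
        prev = strategy.run(step_fn, args=(prev,))
      return strategy.reduce(tf.distribute.ReduceOp.SUM, prev, axis=None)
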
PiperOrigin-RevId: 302505555 Change-Id: I2b6b5acc4249661caad901dc06fc7c53b98eb22b --- .../python/distribute/tpu_strategy_test.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 4b88ae7134a..c44a621ed77 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -18,13 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import tpu_strategy as tpu_lib from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver from tensorflow.python.eager import def_function +from tensorflow.python.eager import function from tensorflow.python.eager import remote from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -214,6 +219,95 @@ class TPUStrategyTest(test.TestCase): self.assertEndsWith(second_core_strategy.extended.worker_devices[0], "device:TPU:1") + def test_tpu_tf_function_same_device(self): + with ops.device("/device:TPU:0"): + a = variables.Variable(1) + + @function.defun_with_attributes(attributes={"_noinline": True}) + def get_a_plus_one(): + return a + 1 + + @def_function.function( + input_signature=[tensor_spec.TensorSpec([], dtypes.int32)]) + def foo(x): + with ops.device("/device:TPU:0"): + b = x + get_a_plus_one() + return b + 1 + + result = foo(a) + self.assertAllEqual(4, result) + + def test_tpu_return_int32(self): + with ops.device("/device:TPU:0"): + a = variables.Variable(0) + + @def_function.function + def foo(): + return a + 1 + + @def_function.function + def bar(): + with ops.device("/device:TPU:1"): + return foo() + + with ops.device("/device:CPU:0"): + result = bar() + 1 + self.assertAllEqual(result, 2) + + def test_control_output_in_while_body_fn(self): + strategy = get_tpu_strategy() + + with strategy.scope(): + v = variables.Variable( + 0.0, aggregation=variables.VariableAggregation.MEAN) + + @def_function.function + def train_step(): + + def step_fn(): + v.assign_add(1) + + for _ in math_ops.range(2): + strategy.run(step_fn) + + train_step() + self.assertEqual(2.0, v.numpy()) + + def test_cluster_in_graph_and_while_body_fn(self): + strategy = get_tpu_strategy() + + @def_function.function + def train_step(): + + def step_fn(prev): + s = prev + 1 + return s + + def init_fn(): + return array_ops.zeros(shape=()) + + prev = strategy.run(init_fn) + for _ in math_ops.range(10): + prev = strategy.run(step_fn, args=(prev,)) + return strategy.reduce(reduce_util.ReduceOp.SUM, prev, axis=None) + + sum_val = train_step().numpy().astype(float) + self.assertEqual(sum_val, strategy.num_replicas_in_sync * 10) + + def test_two_clusters_with_same_fn(self): + strategy = get_tpu_strategy() + + @def_function.function + def foo(x): + return strategy.run(lambda x: x + 1, (x,)) + + @def_function.function + def bar(x): + foo(x) + return foo(x) + + bar(1) + if __name__ == "__main__": test.main() From 8591cedd08f9a57fe5799ab3dc99e3cb55c62cbc Mon Sep 17 00:00:00 2001 From: Ken Franko 
Date: Mon, 23 Mar 2020 13:57:30 -0700 Subject: [PATCH 440/492] Add additional TPUStrategy tests for external variable and keras metric. PiperOrigin-RevId: 302510831 Change-Id: I82dcc1f955fae533af01c82f7b5ad0f441a14b14 --- .../python/distribute/tpu_strategy_test.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index c44a621ed77..f0429ab07ef 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import tpu_strategy as tpu_lib from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver @@ -308,6 +310,42 @@ class TPUStrategyTest(test.TestCase): bar(1) + def test_using_external_variable_inside_tf_function(self): + strategy = get_tpu_strategy() + dataset = dataset_ops.Dataset.range(10, output_type=dtypes.float32).batch(2) + input_iterator = iter(strategy.experimental_distribute_dataset(dataset)) + + v = variables.Variable(2.0) + + @def_function.function + def train_step(data): + def computation(inputs): + return inputs + v + return strategy.run(computation, args=(data,)) + + expected_result = [[x + 2.] for x in range(0, strategy.num_replicas_in_sync) + ] + self.assertAllEqual( + expected_result, + strategy.experimental_local_results(train_step(next(input_iterator)))) + + def test_keras_metric_outside_strategy_scope_per_replica(self): + strategy = get_tpu_strategy() + metric = keras.metrics.Mean("test_metric", dtype=dtypes.float32) + + dataset = dataset_ops.Dataset.range(10).batch(2) + dataset = strategy.experimental_distribute_dataset(dataset) + + @def_function.function + def step_fn(i): + metric.update_state(i) + + with self.assertRaisesRegex(ValueError, "Trying to run metric.update_state " + "in replica context"): + with strategy.scope(): + for i in dataset: + strategy.run(step_fn, args=(i,)) + if __name__ == "__main__": test.main() From 48598dddf7a35471fb9dff431652df56306e1ad0 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Mon, 23 Mar 2020 14:01:28 -0700 Subject: [PATCH 441/492] Always log the latency of the very first inference. 
PiperOrigin-RevId: 302511687 Change-Id: Idffdadddcae1b58a624a338d33351bef30fc6b50 --- .../lite/tools/benchmark/benchmark_main.cc | 2 -- .../lite/tools/benchmark/benchmark_model.cc | 35 ++++++++++++------- .../lite/tools/benchmark/benchmark_model.h | 1 + .../benchmark/benchmark_plus_flex_main.cc | 2 -- ...nchmark_tflite_performance_options_main.cc | 3 -- 5 files changed, 23 insertions(+), 20 deletions(-) diff --git a/tensorflow/lite/tools/benchmark/benchmark_main.cc b/tensorflow/lite/tools/benchmark/benchmark_main.cc index ad7ec23a55c..4eebc31cab8 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_main.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_main.cc @@ -24,8 +24,6 @@ namespace benchmark { int Main(int argc, char** argv) { TFLITE_LOG(INFO) << "STARTING!"; BenchmarkTfLiteModel benchmark; - BenchmarkLoggingListener listener; - benchmark.AddListener(&listener); if (benchmark.Run(argc, argv) != kTfLiteOk) { TFLITE_LOG(ERROR) << "Benchmarking failed."; return EXIT_FAILURE; diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc index 8dc3efb4a00..854b777dccc 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc @@ -47,10 +47,21 @@ void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults& results) { auto inference_us = results.inference_time_us(); auto init_us = results.startup_latency_us(); auto warmup_us = results.warmup_time_us(); - TFLITE_LOG(INFO) << "Average inference timings in us: " - << "Warmup: " << warmup_us.avg() << ", " + auto init_mem_usage = results.init_mem_usage(); + auto overall_mem_usage = results.overall_mem_usage(); + TFLITE_LOG(INFO) << "Inference timings in us: " << "Init: " << init_us << ", " - << "Inference: " << inference_us.avg(); + << "First inference: " << warmup_us.first() << ", " + << "Warmup (avg): " << warmup_us.avg() << ", " + << "Inference (avg): " << inference_us.avg(); + + TFLITE_LOG(INFO) + << "Note: as the benchmark tool itself affects memory footprint, the " + "following is only APPROXIMATE to the actual memory footprint of the " + "model at runtime. Take the information at your discretion."; + TFLITE_LOG(INFO) << "Peak memory footprint (MB): init=" + << init_mem_usage.max_rss_kb / 1024.0 + << " overall=" << overall_mem_usage.max_rss_kb / 1024.0; } std::vector BenchmarkModel::GetFlags() { @@ -193,18 +204,16 @@ TfLiteStatus BenchmarkModel::Run() { params_.Get("max_secs"), REGULAR, &status); const auto overall_mem_usage = profiling::memory::GetMemoryUsage() - start_mem_usage; - listeners_.OnBenchmarkEnd({model_size_mb, startup_latency_us, input_bytes, - warmup_time_us, inference_time_us, init_mem_usage, - overall_mem_usage}); - TFLITE_LOG(INFO) - << "Note: as the benchmark tool itself affects memory footprint, the " - "following is only APPROXIMATE to the actual memory footprint of the " - "model at runtime. Take the information at your discretion."; - TFLITE_LOG(INFO) << "Peak memory footprint (MB): init=" - << init_mem_usage.max_rss_kb / 1024.0 - << " overall=" << overall_mem_usage.max_rss_kb / 1024.0; + const BenchmarkResults final_results( + model_size_mb, startup_latency_us, input_bytes, warmup_time_us, + inference_time_us, init_mem_usage, overall_mem_usage); + listeners_.OnBenchmarkEnd(final_results); + // We always TFLITE_LOG the benchmark result regardless whether a + // BenchmarkListener is registered or not. 
+ BenchmarkLoggingListener log_output; + log_output.OnBenchmarkEnd(final_results); return status; } diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.h b/tensorflow/lite/tools/benchmark/benchmark_model.h index 977bda7d010..8a207a6fd45 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_model.h @@ -151,6 +151,7 @@ class BenchmarkListeners : public BenchmarkListener { // Benchmark listener that just logs the results of benchmark run. class BenchmarkLoggingListener : public BenchmarkListener { + public: void OnBenchmarkEnd(const BenchmarkResults& results) override; }; diff --git a/tensorflow/lite/tools/benchmark/benchmark_plus_flex_main.cc b/tensorflow/lite/tools/benchmark/benchmark_plus_flex_main.cc index 8b7564f3c17..8dd2db11dd8 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_plus_flex_main.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_plus_flex_main.cc @@ -24,8 +24,6 @@ int Main(int argc, char** argv) { ::tflite::InitTensorFlow(); TFLITE_LOG(INFO) << "STARTING!"; BenchmarkTfLiteModel benchmark; - BenchmarkLoggingListener listener; - benchmark.AddListener(&listener); benchmark.Run(argc, argv); return 0; } diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc index f4271e35cc4..6bf4cf4e193 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc @@ -23,9 +23,6 @@ namespace benchmark { int Main(int argc, char** argv) { TFLITE_LOG(INFO) << "STARTING!"; BenchmarkTfLiteModel benchmark; - BenchmarkLoggingListener listener; - benchmark.AddListener(&listener); - BenchmarkPerformanceOptions all_options_benchmark(&benchmark); all_options_benchmark.Run(argc, argv); return EXIT_SUCCESS; From 197f27691d0175aaefa799a3133d0e35e05ccc6a Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 23 Mar 2020 14:05:58 -0700 Subject: [PATCH 442/492] [XLA][MLIR] Rename Handle arg names to 'instr' to reduce confusion. 
PiperOrigin-RevId: 302512882 Change-Id: I9a18fde407705270a074a5ce29b37e0ef4d41f69 --- .../service/mlir_gpu/hlo_dialect_emitter.cc | 68 ++++++------ .../service/mlir_gpu/hlo_dialect_emitter.h | 12 +-- .../service/mlir_gpu/lhlo_dialect_emitter.cc | 100 +++++++++--------- .../service/mlir_gpu/lhlo_dialect_emitter.h | 22 ++-- 4 files changed, 101 insertions(+), 101 deletions(-) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index 1c2fc1962cf..c12418a0c49 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -122,58 +122,58 @@ Status HloDialectEmitter::DefaultAction(HloInstruction* instr) { return Status::OK(); } -Status HloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) { +Status HloDialectEmitter::HandleBroadcast(HloInstruction* instr) { mlir::DenseIntElementsAttr broadcast_dim = - CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_); + CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_); TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType( - broadcast->shape(), builder_)); + instr->shape(), builder_)); - instruction_to_values_[broadcast] = builder_.create( - getLocation(broadcast), llvm::makeArrayRef(res_type), - instruction_to_values_[broadcast->operand(0)], broadcast_dim); + instruction_to_values_[instr] = builder_.create( + getLocation(instr), llvm::makeArrayRef(res_type), + instruction_to_values_[instr->operand(0)], broadcast_dim); return Status::OK(); } -Status HloDialectEmitter::HandleParameter(HloInstruction* param) { - auto argValue = arguments_[param->parameter_number()]; - instruction_to_values_[param] = argValue; +Status HloDialectEmitter::HandleParameter(HloInstruction* instr) { + auto argValue = arguments_[instr->parameter_number()]; + instruction_to_values_[instr] = argValue; return Status::OK(); } -Status HloDialectEmitter::HandleConstant(HloInstruction* constant) { - auto shape = constant->shape(); +Status HloDialectEmitter::HandleConstant(HloInstruction* instr) { + auto shape = instr->shape(); if (!shape.IsArray() || shape.rank() != 0) { return Unimplemented("non-scalar constants are not supported yet"); } TF_ASSIGN_OR_RETURN(auto type, ConvertTensorShapeToType( - constant->shape(), builder_)); + instr->shape(), builder_)); TF_ASSIGN_OR_RETURN(auto value, CreateDenseElementsAttrFromLiteral( - constant->literal(), builder_)); + instr->literal(), builder_)); auto const_value = - builder_.create(getLocation(constant), type, value); - instruction_to_values_[constant] = const_value; + builder_.create(getLocation(instr), type, value); + instruction_to_values_[instr] = const_value; return Status::OK(); } -Status HloDialectEmitter::HandleReduce(HloInstruction* reduce) { +Status HloDialectEmitter::HandleReduce(HloInstruction* instr) { llvm::SmallVector operands; - for (auto operand : reduce->operands()) { + for (auto operand : instr->operands()) { operands.push_back(instruction_to_values_.at(operand)); } const unsigned num_inputs = operands.size() / 2; TF_ASSIGN_OR_RETURN( const auto return_type, - ConvertTensorShapeToType(reduce->shape(), builder_)); + ConvertTensorShapeToType(instr->shape(), builder_)); const auto dimensions_attr = - CreateDenseIntElementsAttrFromVector(reduce->dimensions(), builder_); + CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_); auto reduceOp = builder_.create( - getLocation(reduce), return_type, 
+ getLocation(instr), return_type, llvm::makeArrayRef(operands).take_front(num_inputs), llvm::makeArrayRef(operands).take_back(num_inputs), dimensions_attr); { - auto computation = reduce->to_apply(); + auto computation = instr->to_apply(); auto block = new mlir::Block(); llvm::SmallVector arguments; arguments.reserve(computation->num_parameters()); @@ -188,38 +188,38 @@ Status HloDialectEmitter::HandleReduce(HloInstruction* reduce) { TF_ASSIGN_OR_RETURN(auto result, emitter.EmitComputation(*computation)); OpBuilder body_builder(block); body_builder.setInsertionPointToEnd(block); - body_builder.create(getLocation(reduce), + body_builder.create(getLocation(instr), ArrayRef{result}); } // TODO(b/137624192) Add support for multiple results. - instruction_to_values_[reduce] = reduceOp.getResult(0); + instruction_to_values_[instr] = reduceOp.getResult(0); return Status::OK(); } -Status HloDialectEmitter::HandleCompare(HloInstruction* compare) { +Status HloDialectEmitter::HandleCompare(HloInstruction* instr) { TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType( - compare->shape(), builder_)); + instr->shape(), builder_)); auto comparison_direction_attr = builder_.getNamedAttr( "comparison_direction", builder_.getStringAttr( - ComparisonDirectionToString(compare->comparison_direction()))); + ComparisonDirectionToString(instr->comparison_direction()))); llvm::SmallVector arguments; - for (auto operand : compare->operands()) { + for (auto operand : instr->operands()) { arguments.push_back(instruction_to_values_[operand]); } - instruction_to_values_[compare] = builder_.create( - getLocation(compare), llvm::makeArrayRef(res_type), arguments, + instruction_to_values_[instr] = builder_.create( + getLocation(instr), llvm::makeArrayRef(res_type), arguments, comparison_direction_attr); return Status::OK(); } -Status HloDialectEmitter::HandleIota(HloInstruction* iota) { +Status HloDialectEmitter::HandleIota(HloInstruction* instr) { mlir::IntegerAttr iota_dim = builder_.getI64IntegerAttr( - static_cast(iota)->iota_dimension()); + static_cast(instr)->iota_dimension()); TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType( - iota->shape(), builder_)); - instruction_to_values_[iota] = - builder_.create(getLocation(iota), res_type, iota_dim); + instr->shape(), builder_)); + instruction_to_values_[instr] = + builder_.create(getLocation(instr), res_type, iota_dim); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h index 20d2d1418ca..9590a947734 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h @@ -52,12 +52,12 @@ class HloDialectEmitter : public DfsHloVisitorWithDefault { StatusOr EmitComputation(const HloComputation& computation); Status DefaultAction(HloInstruction* instr) override; - Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleCompare(HloInstruction* compare) override; - Status HandleConstant(HloInstruction* constant) override; - Status HandleIota(HloInstruction* iota) override; - Status HandleParameter(HloInstruction* param) override; - Status HandleReduce(HloInstruction* reduce) override; + Status HandleBroadcast(HloInstruction* instr) override; + Status HandleCompare(HloInstruction* instr) override; + Status HandleConstant(HloInstruction* instr) override; + Status HandleIota(HloInstruction* instr) override; + Status HandleParameter(HloInstruction* instr) 
override; + Status HandleReduce(HloInstruction* instr) override; private: mlir::Location getLocation(const HloInstruction* instr) const; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 3f17694af1d..93ca91a9670 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -242,23 +242,23 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { return Status::OK(); } -Status LhloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) { +Status LhloDialectEmitter::HandleBroadcast(HloInstruction* instr) { DenseIntElementsAttr broadcast_dim = - CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_); + CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_); - TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*broadcast)); + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); OpBuilder func_builder(function.getBody()); func_builder.create( - getLocation(broadcast), function.getArgument(0), function.getArgument(1), + getLocation(instr), function.getArgument(0), function.getArgument(1), broadcast_dim); return Status::OK(); } -Status LhloDialectEmitter::HandleFusion(HloInstruction* fusion) { - TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*fusion)); +Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); OpBuilder func_builder(function.getBody()); auto fusion_op = - func_builder.create(getLocation(fusion), llvm::None); + func_builder.create(getLocation(instr), llvm::None); // Load the HLO argument tensors from the corresponding buffers. The last // argument is for the result, so no need to load it. @@ -266,63 +266,63 @@ Status LhloDialectEmitter::HandleFusion(HloInstruction* fusion) { llvm::SmallVector arg_values; for (int i = 0, e = function.getNumArguments() - 1; i < e; ++i) { arg_values.push_back(body_builder.create<::mlir::TensorLoadOp>( - getLocation(fusion), function.getArgument(i))); + getLocation(instr), function.getArgument(i))); } HloDialectEmitter hlo_emitter(emission_context_, body_builder, arg_values); TF_ASSIGN_OR_RETURN( auto result, - hlo_emitter.EmitComputation(*fusion->fused_instructions_computation())); + hlo_emitter.EmitComputation(*instr->fused_instructions_computation())); // Insert the write-back from the HLO computation to the result argument // buffer. 
body_builder.setInsertionPoint(fusion_op.region().back().getTerminator()); Value result_memref = function.getArgument(function.getNumArguments() - 1); - body_builder.create<::mlir::TensorStoreOp>(getLocation(fusion), result, + body_builder.create<::mlir::TensorStoreOp>(getLocation(instr), result, result_memref); return Status::OK(); } -Status LhloDialectEmitter::HandleReduce(HloInstruction* reduce) { - TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*reduce)); +Status LhloDialectEmitter::HandleReduce(HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); llvm::SmallVector arg_values{function.args_begin(), function.args_end()}; OpBuilder builder(function.getBody()); - auto loc = getLocation(reduce); - int input_count = reduce->operand_count() / 3; + auto loc = getLocation(instr); + int input_count = instr->operand_count() / 3; auto inputs = llvm::makeArrayRef(arg_values).slice(input_count); auto init_values = llvm::makeArrayRef(arg_values).slice(input_count, input_count); auto results = llvm::makeArrayRef(arg_values).slice(2 * input_count, input_count); auto dimensions_attr = - CreateDenseIntElementsAttrFromVector(reduce->dimensions(), builder_); + CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_); auto reduce_op = builder.create(loc, inputs, init_values, results, dimensions_attr); - reduce_op.ensureTerminator(reduce_op.body(), builder, getLocation(reduce)); + reduce_op.ensureTerminator(reduce_op.body(), builder, getLocation(instr)); return SpliceHloComputation(OpBuilder{&reduce_op.body()}, loc, - *reduce->to_apply(), emission_context_); + *instr->to_apply(), emission_context_); } -Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* reduce_window) { - TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*reduce_window)); +Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); llvm::SmallVector arg_values{function.args_begin(), function.args_end()}; OpBuilder builder(function.getBody()); - auto loc = getLocation(reduce_window); + auto loc = getLocation(instr); // Collect attribute values. 
llvm::SmallVector window_dimensions, window_strides, base_dilations, window_dilations; llvm::SmallVector padding; - int64 rank = reduce_window->window().dimensions_size(); + int64 rank = instr->window().dimensions_size(); window_dimensions.reserve(rank); window_strides.reserve(rank); base_dilations.reserve(rank); window_dilations.reserve(rank); padding.reserve(2 * rank); - for (const auto& window : reduce_window->window().dimensions()) { + for (const auto& window : instr->window().dimensions()) { window_dimensions.push_back(window.size()); window_strides.push_back(window.stride()); base_dilations.push_back(window.base_dilation()); @@ -341,23 +341,23 @@ Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* reduce_window) { CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2})); reduce_window_op.ensureTerminator(reduce_window_op.body(), builder, loc); return SpliceHloComputation(OpBuilder{&reduce_window_op.body()}, loc, - *reduce_window->to_apply(), emission_context_); + *instr->to_apply(), emission_context_); } -Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* hlo) { - TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*hlo)); +Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); llvm::SmallVector arg_values{function.args_begin(), function.args_end()}; OpBuilder builder(function.getBody()); - auto loc = getLocation(hlo); + auto loc = getLocation(instr); // Collect attribute values. llvm::SmallVector window_dimensions, window_strides, padding; - int64 rank = hlo->window().dimensions_size(); + int64 rank = instr->window().dimensions_size(); window_dimensions.reserve(rank); window_strides.reserve(rank); padding.reserve(2 * rank); - for (const auto& window : hlo->window().dimensions()) { + for (const auto& window : instr->window().dimensions()) { window_dimensions.push_back(window.size()); window_strides.push_back(window.stride()); padding.push_back(window.padding_low()); @@ -376,75 +376,75 @@ Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* hlo) { builder.createBlock(&select_scatter_op.select()); OpBuilder select_builder{&select_scatter_op.select()}; select_builder.create(loc); - TF_RETURN_IF_ERROR(SpliceHloComputation(select_builder, loc, *hlo->select(), + TF_RETURN_IF_ERROR(SpliceHloComputation(select_builder, loc, *instr->select(), emission_context_)); // Convert `scatter` computation. 
builder.createBlock(&select_scatter_op.scatter()); OpBuilder scatter_builder{&select_scatter_op.scatter()}; scatter_builder.create(loc); - TF_RETURN_IF_ERROR(SpliceHloComputation(scatter_builder, loc, *hlo->scatter(), - emission_context_)); + TF_RETURN_IF_ERROR(SpliceHloComputation( + scatter_builder, loc, *instr->scatter(), emission_context_)); return Status::OK(); } -Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) { - return ThunkEmitter(this).HandleCustomCall(custom_call); +Status LhloDialectEmitter::HandleCustomCall(HloInstruction* instr) { + return ThunkEmitter(this).HandleCustomCall(instr); } -Status LhloDialectEmitter::HandleParameter(HloInstruction* parameter) { +Status LhloDialectEmitter::HandleParameter(HloInstruction* instr) { return Status::OK(); } -Status LhloDialectEmitter::HandleCompare(HloInstruction* compare) { +Status LhloDialectEmitter::HandleCompare(HloInstruction* instr) { auto comparison_direction_attr = builder_.getNamedAttr( "comparison_direction", builder_.getStringAttr( - ComparisonDirectionToString(compare->comparison_direction()))); + ComparisonDirectionToString(instr->comparison_direction()))); - TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*compare)); + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); OpBuilder func_builder(function.getBody()); llvm::SmallVector arg_values{function.args_begin(), function.args_end()}; - func_builder.create(getLocation(compare), llvm::None, + func_builder.create(getLocation(instr), llvm::None, arg_values, comparison_direction_attr); return Status::OK(); } -Status LhloDialectEmitter::HandleConstant(HloInstruction* constant) { - auto shape = constant->shape(); +Status LhloDialectEmitter::HandleConstant(HloInstruction* instr) { + auto shape = instr->shape(); if (!shape.IsArray() || shape.rank() != 0) { return Unimplemented("non-scalar constants are not supported yet"); } - TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*constant)); + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); OpBuilder func_builder(function.getBody()); TF_ASSIGN_OR_RETURN(auto value, CreateDenseElementsAttrFromLiteral( - constant->literal(), func_builder)); - func_builder.create(getLocation(constant), value, + instr->literal(), func_builder)); + func_builder.create(getLocation(instr), value, *function.args_begin()); return Status::OK(); } -Status LhloDialectEmitter::HandleIota(HloInstruction* iota) { +Status LhloDialectEmitter::HandleIota(HloInstruction* instr) { mlir::IntegerAttr iota_dim = builder_.getI64IntegerAttr( - static_cast(iota)->iota_dimension()); + static_cast(instr)->iota_dimension()); - TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*iota)); + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); OpBuilder func_builder(function.getBody()); - func_builder.create(getLocation(iota), iota_dim, + func_builder.create(getLocation(instr), iota_dim, function.getArgument(0)); return Status::OK(); } -Status LhloDialectEmitter::HandleTuple(HloInstruction* tuple) { +Status LhloDialectEmitter::HandleTuple(HloInstruction* instr) { // For the root node of the entry computation we can elide writing the tuple // buffer. We can always figure out the contents of the tuples from buffer // assignment because we insert copies to ensure non-ambiguous output buffers. // GpuExecutable never reads the tuple buffer. 
- if (tuple == - tuple->parent()->parent()->entry_computation()->root_instruction()) { + if (instr == + instr->parent()->parent()->entry_computation()->root_instruction()) { return Status::OK(); } return Unimplemented("handling of typles not yet implemented"); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h index f39d20efe2f..5164591b055 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -53,17 +53,17 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, // Default action which emits code for most operations. Operations which are // special in some way are handled explicitly in HandleFoo methods. Status DefaultAction(HloInstruction* instr) override; - Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleCompare(HloInstruction* compare) override; - Status HandleConstant(HloInstruction* constant) override; - Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleIota(HloInstruction* iota) override; - Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce) override; - Status HandleReduceWindow(HloInstruction* reduce_window) override; - Status HandleSelectAndScatter(HloInstruction* hlo) override; - Status HandleTuple(HloInstruction* tuple) override; + Status HandleBroadcast(HloInstruction* instr) override; + Status HandleCompare(HloInstruction* instr) override; + Status HandleConstant(HloInstruction* instr) override; + Status HandleCustomCall(HloInstruction* instr) override; + Status HandleFusion(HloInstruction* instr) override; + Status HandleIota(HloInstruction* instr) override; + Status HandleParameter(HloInstruction* instr) override; + Status HandleReduce(HloInstruction* instr) override; + Status HandleReduceWindow(HloInstruction* instr) override; + Status HandleSelectAndScatter(HloInstruction* instr) override; + Status HandleTuple(HloInstruction* instr) override; Status FinishVisit(HloInstruction* root) override; From dab09e44d0be5633eefd63a075ca9d8e0ce5a7cb Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Mon, 23 Mar 2020 14:08:09 -0700 Subject: [PATCH 443/492] Temporarily disable tsan remote_cluster_test. PiperOrigin-RevId: 302513373 Change-Id: I0449055a7ecb675e22b0c1320b42d98de4c42080 --- tensorflow/python/eager/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 315e85feb3d..9df6113b95f 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -894,6 +894,7 @@ cuda_py_test( shard_count = 16, tags = [ "no_oss", # This test launches local server + "notsan", # TODO(b/152075365) ], deps = [ "//tensorflow/python:array_ops", From 8e19ad854a69dcd98a8b5661990bf1f2613e0fe5 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Mon, 23 Mar 2020 14:12:47 -0700 Subject: [PATCH 444/492] Temporarily disable flaky tests. 
PiperOrigin-RevId: 302514325 Change-Id: Ic23fe1dfcd62f6afda1947bdcc7bfafcdb6adb82 --- tensorflow/python/eager/remote_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index e0a9523ef57..a210ae0419a 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -98,7 +98,8 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase): self.assertAllEqual( remote_output(constant_op.constant([1]))[0].numpy(), 2) - def testMultiDeviceFunctionAmbiguousDevice(self): + # TODO(b/148235520): Re-enable this test. + def DISABLED_testMultiDeviceFunctionAmbiguousDevice(self): @def_function.function def ambiguous_device(i): @@ -452,8 +453,9 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase): with ops.device('/job:my_worker/task:1/device:CPU:0'): self.assertAllEqual(worker_fn(), 8) + # TODO(b/152224115): Re-enable this test. @test_util.eager_lazy_remote_copy_on_and_off - def testSimpleParameterServerWithDeviceFilters(self): + def DISABLED_testSimpleParameterServerWithDeviceFilters(self): cluster_device_filters = server_lib.ClusterDeviceFilters() for i in range(2): cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps']) From a92ffc1b75f4b38953798266e98705f38bf22783 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Mon, 23 Mar 2020 14:18:17 -0700 Subject: [PATCH 445/492] Make ModelCheckpoint non-blocking and able to support multiple steps_per_execution. PiperOrigin-RevId: 302515633 Change-Id: I857267c32dc277466a1e60baebdb023229efc372 --- tensorflow/python/keras/callbacks.py | 43 +++++++++++++++-------- tensorflow/python/keras/callbacks_test.py | 41 +++++++++++++++++++++ 2 files changed, 70 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index e0b6ba52239..7c5124e923e 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1069,10 +1069,12 @@ class ModelCheckpoint(Callback): (`model.save(filepath)`). save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves the model after each epoch. When using integer, the callback saves the - model at end of this many batches. Note that if the saving isn't aligned - to epochs, the monitored metric may potentially be less reliable (it + model at end of this many batches. If the `Model` is compiled with + `experimental_steps_per_execution=N`, then the saving criteria will be + checked every Nth batch. Note that if the saving isn't aligned to + epochs, the monitored metric may potentially be less reliable (it could reflect as little as 1 batch, since the metrics get reset every - epoch). Defaults to `'epoch'` + epoch). Defaults to `'epoch'`. **kwargs: Additional arguments for backwards compatibility. Possible key is `period`. 
""" @@ -1087,6 +1089,7 @@ class ModelCheckpoint(Callback): save_freq='epoch', **kwargs): super(ModelCheckpoint, self).__init__() + self._supports_tf_logs = True self.monitor = monitor self.verbose = verbose self.filepath = filepath @@ -1095,6 +1098,7 @@ class ModelCheckpoint(Callback): self.save_freq = save_freq self.epochs_since_last_save = 0 self._batches_seen_since_last_saving = 0 + self._last_batch_seen = 0 # Deprecated field `load_weights_on_restart` is for loading the checkpoint # file from `filepath` at the start of `model.fit()` @@ -1197,13 +1201,9 @@ class ModelCheckpoint(Callback): del self._training_state self.model._training_state = None - def on_batch_end(self, batch, logs=None): - if self._implements_train_batch_hooks(): - logs = logs or {} - self._batches_seen_since_last_saving += 1 - if self._batches_seen_since_last_saving >= self.save_freq: - self._save_model(epoch=self._current_epoch, logs=logs) - self._batches_seen_since_last_saving = 0 + def on_train_batch_end(self, batch, logs=None): + if self._should_save_on_batch(batch): + self._save_model(epoch=self._current_epoch, logs=logs) def on_epoch_begin(self, epoch, logs=None): self._current_epoch = epoch @@ -1224,6 +1224,23 @@ class ModelCheckpoint(Callback): # TODO(rchao): Call `back_up` at finer period such as N steps. self._training_state.back_up(epoch) + def _should_save_on_batch(self, batch): + """Handles batch-level saving logic, supports steps_per_execution.""" + if self.save_freq == 'epoch': + return False + + if batch <= self._last_batch_seen: # New epoch. + add_batches = batch + 1 # batches are zero-indexed. + else: + add_batches = batch - self._last_batch_seen + self._batches_seen_since_last_saving += add_batches + self._last_batch_seen = batch + + if self._batches_seen_since_last_saving >= self.save_freq: + self._batches_seen_since_last_saving = 0 + return True + return False + def _save_model(self, epoch, logs): """Saves the model. @@ -1235,6 +1252,8 @@ class ModelCheckpoint(Callback): if isinstance(self.save_freq, int) or self.epochs_since_last_save >= self.period: + # Block only when saving interval is reached. + logs = tf_utils.to_numpy_or_python_type(logs) self.epochs_since_last_save = 0 filepath = self._get_file_path(epoch, logs) @@ -1400,10 +1419,6 @@ class ModelCheckpoint(Callback): # the file path with the largest file name. return file_path_with_largest_file_name - def _implements_train_batch_hooks(self): - # If save_freq="epoch", batch-level hooks don't need to be run. - return isinstance(self.save_freq, int) - @keras_export('keras.callbacks.EarlyStopping') class EarlyStopping(Callback): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 2b7f7c038c6..5de4cacfa8a 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -832,6 +832,47 @@ class KerasCallbacksTest(keras_parameterized.TestCase): 'filepath.*'): model.fit(train_ds, epochs=1, callbacks=[callback]) + def test_ModelCheckpoint_nonblocking(self): + filepath = self.get_temp_dir() + # Should only cause a sync block when saving is actually performed. + callback = keras.callbacks.ModelCheckpoint(filepath=filepath, save_freq=100) + self.assertTrue(callback._supports_tf_logs) + + model = keras.Sequential([keras.layers.Dense(1)]) + cb_list = keras.callbacks.CallbackList([callback], + model=model, + epochs=1, + steps=10, + verbose=0) + + with context.eager_mode(): + tensor = ops.convert_to_tensor(1.) 
+ + def mock_numpy(): + raise RuntimeError( + 'If this error is seen, ModelCheckpoint is causing a blocking ' + 'NumPy conversion even when not checkpointing.') + + with test.mock.patch.object(tensor, 'numpy', mock_numpy): + logs = {'metric': tensor} + + cb_list.on_train_begin(logs) + cb_list.on_epoch_begin(0, logs) + cb_list.on_train_batch_begin(0, logs) + cb_list.on_train_batch_end(0, logs) + cb_list.on_epoch_end(0, logs) + cb_list.on_train_end(logs) + + cb_list.on_test_begin(logs) + cb_list.on_test_batch_begin(0, logs) + cb_list.on_test_batch_end(0, logs) + cb_list.on_test_end(logs) + + cb_list.on_predict_begin(logs) + cb_list.on_predict_batch_begin(logs) + cb_list.on_predict_batch_end(logs) + cb_list.on_predict_end(logs) + def test_EarlyStopping(self): with self.cached_session(): np.random.seed(123) From 4cbd9f0a86c783ad705ee98967967078a53772cc Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 23 Mar 2020 14:34:13 -0700 Subject: [PATCH 446/492] Update tf_generated_ops.td to remove some of the manual changes * Update supported dtypes * Add AllTypesMatch<["x", "out"] to Cumsum PiperOrigin-RevId: 302519530 Change-Id: Iaa5e210a4dbc90365ffcabc11a7500ff4c53e537 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index e5bda71323e..9feeee87374 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1259,7 +1259,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect]> { +def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> { let summary = "Compute the cumulative sum of the tensor `x` along `axis`."; let description = [{ @@ -2082,12 +2082,12 @@ with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. }]; let arguments = (ins - TF_FpOrI32OrI64Tensor:$x, - TF_FpOrI32OrI64Tensor:$y + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Uint64]>:$x, + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Uint64]>:$y ); let results = (outs - TF_FpOrI32OrI64Tensor:$z + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Uint64]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3254,7 +3254,7 @@ cublas. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, AllTypesMatch<["input", "band"]>]> { +def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> { let summary = [{ Copy a tensor setting everything outside a central band in each innermost matrix to zero. 
@@ -6887,12 +6887,12 @@ def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>, }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint8]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -8373,25 +8373,6 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape]>, let hasCanonicalizer = 1; } -def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> { - let summary = [{ -An op which shards the input based on the given sharding attribute. - }]; - - let description = [{ - }]; - - let arguments = (ins - TF_Tensor:$input - ); - - let results = (outs - TF_Tensor:$output - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - def TF_XlaDynamicUpdateSliceOp : TF_Op<"XlaDynamicUpdateSlice", [NoSideEffect]> { let summary = "Wraps the XLA DynamicUpdateSlice operator, documented at"; @@ -8421,6 +8402,25 @@ Handling of out-of-bounds slice indices is implementation-defined. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> { + let summary = [{ +An op which shards the input based on the given sharding attribute. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_Tensor:$input + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a tensor of zeros with the same shape and type as x."; From aba2ca46032fcf09ddf925d664d1bb935cf2cd7d Mon Sep 17 00:00:00 2001 From: Peng Wang Date: Mon, 23 Mar 2020 14:37:05 -0700 Subject: [PATCH 447/492] Adds tf.random.experimental.stateless_split and tf.random.experimental.stateless_fold_in to manage seeds for stateless RNGs. 
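These helpers derive many statistically independent seeds from a single master seed without any stateful RNG, e.g. to give each replica or step its own seed while keeping the whole run reproducible. A minimal usage sketch based on the docstrings added below (the replica count of 4 is arbitrary, chosen for illustration):

    import tensorflow as tf

    master_seed = tf.constant([1, 2], dtype=tf.int32)

    # One seed per replica, derived deterministically from the master seed.
    replica_seeds = tf.random.experimental.stateless_split(master_seed, num=4)
    draws = [tf.random.stateless_normal([3], seed=replica_seeds[i]) for i in range(4)]

    # Alternatively, fold a replica id into the master seed.
    replica_seed = tf.random.experimental.stateless_fold_in(master_seed, 3)
    noise = tf.random.stateless_uniform([2, 2], seed=replica_seed)
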
PiperOrigin-RevId: 302520177 Change-Id: I43c3d2e2aa5dd26a64e978027df0ecb1f6095140 --- .../random/stateless_random_ops_test.py | 37 +++++++- tensorflow/python/ops/stateless_random_ops.py | 95 +++++++++++++++++-- .../v1/tensorflow.random.experimental.pbtxt | 8 ++ .../v2/tensorflow.random.experimental.pbtxt | 8 ++ 4 files changed, 139 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py index 38325805d76..0b9fbab716c 100644 --- a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py +++ b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import functools +from absl.testing import parameterized import numpy as np from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op @@ -27,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import stateless_random_ops as stateless from tensorflow.python.platform import test @@ -49,7 +51,7 @@ def invert_philox(key, value): return np.array(value) -class StatelessOpsTest(test.TestCase): +class StatelessOpsTest(test.TestCase, parameterized.TestCase): def _test_match(self, cases): # Stateless ops should be the same as stateful ops on the first call @@ -194,6 +196,39 @@ class StatelessOpsTest(test.TestCase): def testDeterminismPoisson(self): self._test_determinism(self._poisson_cases()) + def assertDTypeEqual(self, a, b): + self.assertEqual(dtypes.as_dtype(a), dtypes.as_dtype(b)) + + def assertNoEqualPair(self, ls): + for i in range(len(ls)): + for j in range(i + 1, len(ls)): + self.assertFalse(math_ops.reduce_all(ls[i] == ls[j])) + + @parameterized.parameters(['int32', 'int64']) + @test_util.run_v2_only + def testSplit(self, dtype): + """Test for `split`.""" + seed = constant_op.constant([1, 2], dtype=dtype) + new_seed = stateless.split(seed, 3) + self.assertEqual(new_seed.shape, [3, 2]) + self.assertDTypeEqual(new_seed.dtype, dtype) + self.assertNoEqualPair([seed] + array_ops.unstack(new_seed)) + + @parameterized.parameters(['int32', 'int64']) + @test_util.run_v2_only + def testFoldIn(self, dtype): + """Test for `fold_in`.""" + orig_seed = constant_op.constant([1, 2], dtype='int32') + seed = stateless.fold_in(orig_seed, constant_op.constant(3, dtype=dtype)) + new_seeds = [] + new_seeds.append(seed) + seed = stateless.fold_in(seed, constant_op.constant(4, dtype=dtype)) + new_seeds.append(seed) + for s in new_seeds: + self.assertEqual(s.shape, [2]) + self.assertDTypeEqual(s.dtype, dtype) + self.assertNoEqualPair([math_ops.cast(orig_seed, dtype)] + new_seeds) + @test_util.run_v2_only def testErrors(self): """Tests that proper errors are raised. 
diff --git a/tensorflow/python/ops/stateless_random_ops.py b/tensorflow/python/ops/stateless_random_ops.py index eb3a2d18b3a..2bf53d3a0f7 100644 --- a/tensorflow/python/ops/stateless_random_ops.py +++ b/tensorflow/python/ops/stateless_random_ops.py @@ -39,6 +39,77 @@ ops.NotDifferentiable("StatelessRandomUniformFullInt") ops.NotDifferentiable("StatelessTruncatedNormal") +@tf_export("random.experimental.stateless_split") +def split(seed, num=2): + """Splits an RNG seed into `num` new seeds by adding a leading axis. + + Example: + + >>> seed = [1, 2] + >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3) + >>> print(new_seeds) + tf.Tensor( + [[1105988140 1738052849] + [-335576002 370444179] + [ 10670227 -246211131]], shape=(3, 2), dtype=int32) + >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :]) + + + Args: + seed: an RNG seed (a tensor with shape [2] and dtype `int32` or + `int64`). (When using XLA, only `int32` is allowed.) + num: optional, a positive integer or scalar tensor indicating the number of + seeds to produce (default 2). + + Returns: + A tensor with shape [num, 2] representing `num` new seeds. It will have the + same dtype as `seed` (if `seed` doesn't have an explict dtype, the dtype + will be determined by `tf.convert_to_tensor`). + """ + seed = ops.convert_to_tensor(seed) + return stateless_random_uniform(shape=[num, 2], seed=seed, dtype=seed.dtype, + minval=None, maxval=None) + + +@tf_export("random.experimental.stateless_fold_in") +def fold_in(seed, data): + """Folds in data to an RNG seed to form a new RNG seed. + + For example, in a distributed-training setting, suppose we have a master seed + and a replica ID. We want to fold the replica ID into the master seed to + form a "replica seed" to be used by that replica later on, so that different + replicas will generate different random numbers but the reproducibility of the + whole system can still be controlled by the master seed: + + >>> master_seed = [1, 2] + >>> replica_id = 3 + >>> replica_seed = tf.random.experimental.stateless_fold_in( + ... master_seed, replica_id) + >>> print(replica_seed) + tf.Tensor([1105988140 3], shape=(2,), dtype=int32) + >>> tf.random.stateless_normal(shape=[3], seed=replica_seed) + + + Args: + seed: an RNG seed (a tensor with shape [2] and dtype `int32` or + `int64`). (When using XLA, only `int32` is allowed.) + data: an `int32` or `int64` scalar representing data to be folded in to the + seed. + + Returns: + A new RNG seed that is a deterministic function of the inputs and is + statistically safe for producing a stream of new pseudo-random values. It + will have the same dtype as `data` (if `data` doesn't have an explict dtype, + the dtype will be determined by `tf.convert_to_tensor`). + """ + data = ops.convert_to_tensor(data) + seed1 = stateless_random_uniform(shape=[], seed=seed, dtype=data.dtype, + minval=None, maxval=None) + return array_ops.stack([seed1, data]) + + @tf_export("random.stateless_uniform") def stateless_random_uniform(shape, seed, @@ -77,7 +148,8 @@ def stateless_random_uniform(shape, Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. - seed: A shape [2] integer Tensor of seeds to the random number generator. + seed: A shape [2] Tensor, the seed to the random number generator. Must have + dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) minval: A Tensor or Python value of type `dtype`, broadcastable with `shape` (for integer types, broadcasting is not supported, so it needs to be a scalar). 
The lower bound on the range of random values to @@ -170,7 +242,8 @@ def stateless_random_binomial(shape, Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. - seed: A shape [2] integer Tensor of seeds to the random number generator. + seed: A shape [2] Tensor, the seed to the random number generator. Must have + dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) counts: Tensor. The counts of the binomial distribution. Must be broadcastable with `probs`, and broadcastable with the rightmost dimensions of `shape`. @@ -264,7 +337,8 @@ def stateless_random_gamma(shape, Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. - seed: A shape [2] integer Tensor of seeds to the random number generator. + seed: A shape [2] Tensor, the seed to the random number generator. Must have + dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) alpha: Tensor. The concentration parameter of the gamma distribution. Must be broadcastable with `beta`, and broadcastable with the rightmost dimensions of `shape`. @@ -336,7 +410,8 @@ def stateless_random_poisson(shape, Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. - seed: A shape [2] integer Tensor of seeds to the random number generator. + seed: A shape [2] Tensor, the seed to the random number generator. Must have + dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape must match the rightmost dimensions of `shape`. dtype: Dtype of the samples (int or float dtypes are permissible, as samples @@ -375,7 +450,8 @@ def stateless_random_normal(shape, Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. - seed: A shape [2] integer Tensor of seeds to the random number generator. + seed: A shape [2] Tensor, the seed to the random number generator. Must have + dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal distribution. stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation @@ -419,7 +495,8 @@ def stateless_truncated_normal(shape, Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. - seed: A shape [2] integer Tensor of seeds to the random number generator. + seed: A shape [2] Tensor, the seed to the random number generator. Must have + dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) mean: A 0-D Tensor or Python value of type `dtype`. The mean of the truncated normal distribution. stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation @@ -471,7 +548,8 @@ def stateless_multinomial(logits, logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` represents the unnormalized log-probabilities for all classes. num_samples: 0-D. Number of independent samples to draw for each row slice. - seed: A shape [2] integer Tensor of seeds to the random number generator. + seed: A shape [2] Tensor, the seed to the random number generator. Must have + dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) output_dtype: integer type to use for the output. Defaults to int64. name: Optional name for the operation. @@ -510,7 +588,8 @@ def stateless_categorical(logits, logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` represents the unnormalized log-probabilities for all classes. num_samples: 0-D. 
Number of independent samples to draw for each row slice.
-    seed: A shape [2] integer Tensor of seeds to the random number generator.
+    seed: A shape [2] Tensor, the seed to the random number generator. Must have
+      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
     dtype: integer type to use for the output. Defaults to int64.
     name: Optional name for the operation.
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.pbtxt
index 73f7497934e..0b4ab252d1b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.pbtxt
@@ -20,4 +20,12 @@ tf_module {
     name: "set_global_generator"
     argspec: "args=[\'generator\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "stateless_fold_in"
+    argspec: "args=[\'seed\', \'data\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "stateless_split"
+    argspec: "args=[\'seed\', \'num\'], varargs=None, keywords=None, defaults=[\'2\'], "
+  }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.pbtxt
index 73f7497934e..0b4ab252d1b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.pbtxt
@@ -20,4 +20,12 @@ tf_module {
     name: "set_global_generator"
     argspec: "args=[\'generator\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "stateless_fold_in"
+    argspec: "args=[\'seed\', \'data\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "stateless_split"
+    argspec: "args=[\'seed\', \'num\'], varargs=None, keywords=None, defaults=[\'2\'], "
+  }
 }

From 79b8c700d416c29ea886f6f27cb78f27db810e1a Mon Sep 17 00:00:00 2001
From: Smit Hinsu
Date: Mon, 23 Mar 2020 14:47:18 -0700
Subject: [PATCH 448/492] Roll forward change to use the MLIR-based TensorFlow
 compiler in the XLA on-demand compiler

This splits the compile_mlir_util library into two parts: one with TF dialect
passes, which include the TF constant-folding hook, and one without them. The
constant-folding hook depends on TF eager, so splitting the library into two
parts is required to avoid a circular dependency.
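For reference, a minimal sketch of how a user can opt in (mirroring what the new
unary_mlir_ops_test.py in this patch does; the flag is experimental and these
entry points may change):

    # Eager/test-style toggle, as used by unary_mlir_ops_test.py:
    from tensorflow.python.eager import context
    context.context().enable_mlir_bridge = True

    # ConfigProto-style toggle; this is the field that xla_compilation_cache.cc
    # checks before taking the MLIR path for single-op compilation:
    from tensorflow.core.protobuf import config_pb2
    config = config_pb2.ConfigProto()
    config.experimental.enable_mlir_bridge = True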
PiperOrigin-RevId: 302522554 Change-Id: I4f8f0a8e745a9becff3845cc59950f181e6f415a --- tensorflow/compiler/jit/BUILD | 1 + .../compiler/jit/xla_compilation_cache.cc | 29 ++++++- .../compiler/jit/xla_compilation_cache.h | 4 +- tensorflow/compiler/mlir/tensorflow/BUILD | 76 ++++++++++++------ .../tensorflow/utils/compile_mlir_util.cc | 59 +++++++++++--- .../mlir/tensorflow/utils/compile_mlir_util.h | 10 +++ .../utils/compile_mlir_util_test.cc | 39 +++++++++ tensorflow/compiler/tests/BUILD | 20 +++++ .../compiler/tests/unary_mlir_ops_test.py | 80 +++++++++++++++++++ tensorflow/tools/lib_package/BUILD | 2 + 10 files changed, 278 insertions(+), 42 deletions(-) create mode 100644 tensorflow/compiler/tests/unary_mlir_ops_test.py diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f71331af0df..f44a0253464 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -338,6 +338,7 @@ cc_library( deps = [ ":xla_activity_listener", ":xla_activity_proto_cc", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 5540fee7276..5081df28a08 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/metrics.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -40,6 +42,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" @@ -273,8 +276,30 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - return compiler->CompileGraph(compile_options, node_def.name(), - std::move(graph), args, result); + + bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) { + return arg.kind == XlaCompiler::Argument::kParameter; + }); + const ConfigProto* config = ctx->function_library()->config_proto(); + bool use_mlir = config && config->experimental().enable_mlir_bridge(); + // Use MLIR bridge if all the arguments are parameters. + // TODO(hinsu): Support other argument types instead of silently falling + // back to the XLA compiler. 
+ if (!are_params || !use_mlir) { + return compiler->CompileGraph(compile_options, node_def.name(), + std::move(graph), args, result); + } + + absl::InlinedVector arg_shapes; + arg_shapes.reserve(args.size()); + for (const XlaCompiler::Argument& arg : args) { + arg_shapes.push_back(absl::get(arg.shape)); + } + GraphDebugInfo debug_info; + return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()}, + compile_options.use_tuple_arg, + *options.flib_def, debug_info, + options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 83a0bda97d5..cd58cf31988 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -78,7 +78,9 @@ class XlaCompilationCache : public ResourceBase { xla::LocalExecutable** out_executable); // As above, but calls XlaCompiler::CompileSingleOp instead of - // XlaCompiler::CompileFunction. + // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto + // in OpKernelContext, then uses MLIR bridge for compilation instead of + // XlaCompiler, if possible. Status CompileSingleOp( const XlaCompiler::Options& options, absl::Span args, OpKernelContext* ctx, diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 8ac33c906bb..7b088cad715 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1052,36 +1052,58 @@ gentbl( ], ) +COMPILE_MLIR_UTIL_DEPS = [ + ":bridge_logger", + ":convert_graphdef", + ":convert_type", + ":dump_mlir_util", + ":error_util", + ":mlir_roundtrip_flags", + ":tensorflow", + ":tensorflow_dialect_registration", + ":tensorflow_passes", + ":translate_utils", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", + "//tensorflow/compiler/mlir/xla:type_to_shape", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:logging", + "//tensorflow/stream_executor/lib", +] + +# Prefer to link 'compile_mlir_util' library that also links necessary +# TensorFlow passes to the pipeline. This library without tf passes is useful +# if the constant folding is not required on the TensorFlow dialect. For +# example, this is used in XLA ondemand compilation which compiles a single op +# at a time and doesn't require constant folding. Doing so helps avoid a +# circular dependency between c_api and tf passes. +# TODO(hinsu): Split out the constant folding hook and only exclude that in +# this target. 
cc_library( - name = "compile_mlir_util", + name = "compile_mlir_util_no_tf_dialect_passes", srcs = ["utils/compile_mlir_util.cc"], hdrs = ["utils/compile_mlir_util.h"], - deps = [ - ":bridge_logger", - ":convert_type", - ":dump_mlir_util", - ":error_util", - ":tensorflow", - ":tensorflow_dialect_registration", - ":tensorflow_passes", + deps = COMPILE_MLIR_UTIL_DEPS, +) + +cc_library( + name = "compile_mlir_util", + hdrs = ["utils/compile_mlir_util.h"], + deps = COMPILE_MLIR_UTIL_DEPS + [ + "compile_mlir_util_no_tf_dialect_passes", ":tf_dialect_passes", - ":translate_utils", - "//tensorflow/compiler/mlir/xla:hlo", - "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", - "//tensorflow/compiler/mlir/xla:type_to_shape", - "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/core:framework", - "//tensorflow/core/platform:logging", - "//tensorflow/stream_executor/lib", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", ], ) @@ -1096,8 +1118,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/stream_executor/lib", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 5394dbfb21a..3fd711b9ef8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -31,6 +31,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" @@ -276,19 +278,11 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, return Status::OK(); } -Status CompileSerializedMlirToXlaHlo( - llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, +static Status CompileMlirToXlaHlo( + mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { - RegisterDialects(); - mlir::MLIRContext mlir_context; - mlir::OwningModuleRef mlir_module; - - TF_RETURN_IF_ERROR( - ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); - auto module_op = mlir_module.get(); - if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -309,9 +303,14 @@ Status CompileSerializedMlirToXlaHlo( GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping); auto shape_representation_fn_no_fast_memory = - [shape_representation_fn](const TensorShape& shape, DataType dtype) { - return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); - }; + [shape_representation_fn](const TensorShape& shape, + DataType dtype) -> StatusOr { + if (shape_representation_fn) + return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; // Compute all input shapes. 
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args, @@ -333,4 +332,38 @@ Status CompileSerializedMlirToXlaHlo( return Status::OK(); } +Status CompileSerializedMlirToXlaHlo( + llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, + bool use_tuple_args, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result) { + RegisterDialects(); + mlir::MLIRContext mlir_context; + mlir::OwningModuleRef mlir_module; + + TF_RETURN_IF_ERROR( + ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); + return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args, + shape_representation_fn, compilation_result); +} + +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef arg_shapes, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result) { + RegisterDialects(); + mlir::MLIRContext context; + GraphImportConfig config; + config.graph_as_function = true; + auto module_or = + ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); + if (!module_or.ok()) return module_or.status(); + + return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, + use_tuple_args, shape_representation_fn, + compilation_result); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 19423adfe17..0dd4b8c5efe 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -53,6 +54,15 @@ Status CompileSerializedMlirToXlaHlo( bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); + +// Same as the above but takes input as TensorFlow Graph. +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef arg_shapes, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 7db3d34a4ad..f65fcc1016d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -20,6 +20,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -285,5 +288,41 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { status_or_hlo_module.ValueOrDie()->ToString()); } +// Verify that conversion from Graph to MLIR and empty shape representation +// function is successful. +TEST(CompileGraphToXlaHlo, Basic) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + Graph graph(OpRegistry::Global()); + + Tensor dummy_tensor(DT_FLOAT, TensorShape({1})); + test::FillValues(&dummy_tensor, {-1.0}); + + Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT); + test::graph::Retval(&graph, 0, arg); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(CompileGraphToXlaHlo( + graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def, + GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); + + const xla::HloModuleConfig module_config( + result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + result.computation->proto(), module_config); + ASSERT_TRUE(status_or_hlo_module.ok()); + + string expected_hlo_module_string = R"(HloModule main.3 + +ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { + %Arg_0.1 = f32[] parameter(0) + ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1) +} + +)"; + + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 77cd3dc074c..d586b8178c5 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1354,6 +1354,26 @@ tf_xla_py_test( ], ) +# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it. +tf_xla_py_test( + name = "unary_mlir_ops_test", + size = "medium", + srcs = ["unary_mlir_ops_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fused_batchnorm_test", size = "medium", diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py new file mode 100644 index 00000000000..2b3dec3d5a7 --- /dev/null +++ b/tensorflow/compiler/tests/unary_mlir_ops_test.py @@ -0,0 +1,80 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for XLA JIT compiler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class UnaryOpsTest(xla_test.XLATestCase): + """Test cases for unary operators.""" + + def __init__(self, method_name='runTest'): + super(UnaryOpsTest, self).__init__(method_name) + context.context().enable_mlir_bridge = True + + def _assertOpOutputMatchesExpected(self, + op, + inp, + expected, + equality_test=None, + rtol=1e-3, + atol=1e-5): + """Verifies that 'op' produces 'expected' when fed input 'inp' . + + Args: + op: operator to test + inp: numpy input array to use as input to 'op'. + expected: numpy array representing the expected output of 'op'. + equality_test: either None, or a function that tests two numpy arrays for + equality. If None, self.assertAllClose is used. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. + """ + with self.session() as session: + with self.test_scope(): + pinp = array_ops.placeholder( + dtypes.as_dtype(inp.dtype), inp.shape, name='a') + output = op(pinp) + result = session.run(output, {pinp: inp}) + if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + else: + equality_test(result, expected, rtol=rtol, atol=atol) + + def testNumericOps(self): + # TODO(hinsu): Enable complex types after fixing the failure in export to + # HLOModule. + for dtype in self.numeric_types - {np.int8, np.uint8} - self.complex_types: + self._assertOpOutputMatchesExpected( + math_ops.abs, + np.array([[2, -1]], dtype=dtype), + expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 30ab95e370d..89ec2a0c7c3 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -154,6 +154,7 @@ genrule( "@icu//:icu4c/LICENSE", "@libjpeg_turbo//:LICENSE.md", "@llvm-project//llvm:LICENSE.TXT", + "@llvm-project//mlir:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@local_config_tensorrt//:LICENSE", @@ -234,6 +235,7 @@ genrule( "@icu//:icu4j/main/shared/licenses/LICENSE", "@libjpeg_turbo//:LICENSE.md", "@llvm-project//llvm:LICENSE.TXT", + "@llvm-project//mlir:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@local_config_tensorrt//:LICENSE", From 083fc9754e64b9ebb7970dc49a9d13d1a72413cf Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 23 Mar 2020 14:49:55 -0700 Subject: [PATCH 449/492] Move legacy TF RNN cells to Keras and export there. 
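As an illustrative consequence of the move, the cell implementations now live under
keras/layers/legacy_rnn while keeping their existing v1 exports, so both of the
following continue to resolve after this change:

    # New internal home of the implementation:
    from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl
    cell = rnn_cell_impl.BasicRNNCell(num_units=4)

    # Public v1 endpoint is unchanged:
    import tensorflow.compat.v1 as tf
    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=4)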
PiperOrigin-RevId: 302523110 Change-Id: Ifd9b6a2ee82bfaa9ca054dfda754257227bb5789 --- tensorflow/python/BUILD | 3 +- .../python/keras/layers/legacy_rnn/BUILD | 54 + .../keras/layers/legacy_rnn/rnn_cell_impl.py | 1355 +++++++++++++++++ .../legacy_rnn/rnn_cell_wrapper_impl.py | 516 +++++++ tensorflow/python/ops/rnn_cell_impl.py | 1353 +--------------- .../python/ops/rnn_cell_wrapper_impl.py | 499 +----- ...perimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt | 4 +- ....experimental.nn.-tf-lite-r-n-n-cell.pbtxt | 4 +- ...flow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt | 6 +- ...orflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt | 6 +- ...nsorflow.nn.rnn_cell.-device-wrapper.pbtxt | 8 +- ...sorflow.nn.rnn_cell.-dropout-wrapper.pbtxt | 8 +- .../tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt | 6 +- ...tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt | 6 +- ...low.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt | 4 +- ...orflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt | 4 +- .../tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt | 2 +- ...orflow.nn.rnn_cell.-residual-wrapper.pbtxt | 8 +- ...orflow.nn.-r-n-n-cell-device-wrapper.pbtxt | 2 +- ...rflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt | 2 +- ...flow.nn.-r-n-n-cell-residual-wrapper.pbtxt | 2 +- 21 files changed, 1991 insertions(+), 1861 deletions(-) create mode 100644 tensorflow/python/keras/layers/legacy_rnn/BUILD create mode 100644 tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py create mode 100644 tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 74df39049cc..1669508ac4f 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4043,7 +4043,6 @@ py_library( ":nn_ops", ":nn_ops_gen", ":platform_device_context", - ":rnn", ":sparse_ops", ":util", ":variables", @@ -4245,6 +4244,8 @@ py_library( ":util", ":variable_scope", ":variables", + "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_impl", + "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_wrapper_impl", ], ) diff --git a/tensorflow/python/keras/layers/legacy_rnn/BUILD b/tensorflow/python/keras/layers/legacy_rnn/BUILD new file mode 100644 index 00000000000..4d3b4a4c852 --- /dev/null +++ b/tensorflow/python/keras/layers/legacy_rnn/BUILD @@ -0,0 +1,54 @@ +# Description: +# Contains the legacy TF RNN APIs (internal TensorFlow version). 
+ +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "rnn_cell_impl", + srcs = ["rnn_cell_impl.py"], + deps = [ + ":rnn_cell_wrapper_impl", + "//tensorflow/python:array_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers_base", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:platform", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/keras:activations", + "//tensorflow/python/keras:initializers", + "//tensorflow/python/keras/engine:input_spec", + "//tensorflow/python/keras/utils:tf_utils", + "//tensorflow/python/training/tracking:base", + ], +) + +py_library( + name = "rnn_cell_wrapper_impl", + srcs = ["rnn_cell_wrapper_impl.py"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:util", + "//tensorflow/python/keras/utils:generic_utils", + ], +) diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py new file mode 100644 index 00000000000..1d03780d51b --- /dev/null +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py @@ -0,0 +1,1355 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module implementing RNN Cells. + +This module provides a number of basic commonly used RNN cells, such as LSTM +(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of +operators that allow adding dropouts, projections, or embeddings for inputs. +Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by +calling the `rnn` ops several times. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import activations +from tensorflow.python.keras import initializers +from tensorflow.python.keras.engine import input_spec +from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.layers import base as base_layer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.util import nest +from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.tf_export import tf_export + +_BIAS_VARIABLE_NAME = "bias" +_WEIGHTS_VARIABLE_NAME = "kernel" + +# This can be used with self.assertRaisesRegexp for assert_like_rnncell. +ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell" + + +def _hasattr(obj, attr_name): + try: + getattr(obj, attr_name) + except AttributeError: + return False + else: + return True + + +def assert_like_rnncell(cell_name, cell): + """Raises a TypeError if cell is not like an RNNCell. + + NOTE: Do not rely on the error message (in particular in tests) which can be + subject to change to increase readability. Use + ASSERT_LIKE_RNNCELL_ERROR_REGEXP. + + Args: + cell_name: A string to give a meaningful error referencing to the name of + the functionargument. + cell: The object which should behave like an RNNCell. + + Raises: + TypeError: A human-friendly exception. + """ + conditions = [ + _hasattr(cell, "output_size"), + _hasattr(cell, "state_size"), + _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"), + callable(cell), + ] + errors = [ + "'output_size' property is missing", "'state_size' property is missing", + "either 'zero_state' or 'get_initial_state' method is required", + "is not callable" + ] + + if not all(conditions): + + errors = [error for error, cond in zip(errors, conditions) if not cond] + raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format( + cell_name, cell, ", ".join(errors))) + + +def _concat(prefix, suffix, static=False): + """Concat that enables int, Tensor, or TensorShape values. + + This function takes a size specification, which can be an integer, a + TensorShape, or a Tensor, and converts it into a concatenated Tensor + (if static = False) or a list of integers (if static = True). + + Args: + prefix: The prefix; usually the batch size (and/or time step size). + (TensorShape, int, or Tensor.) + suffix: TensorShape, int, or Tensor. + static: If `True`, return a python list with possibly unknown dimensions. + Otherwise return a `Tensor`. + + Returns: + shape: the concatenation of prefix and suffix. 
+ + Raises: + ValueError: if `suffix` is not a scalar or vector (or TensorShape). + ValueError: if prefix or suffix was `None` and asked for dynamic + Tensors out. + """ + if isinstance(prefix, ops.Tensor): + p = prefix + p_static = tensor_util.constant_value(prefix) + if p.shape.ndims == 0: + p = array_ops.expand_dims(p, 0) + elif p.shape.ndims != 1: + raise ValueError("prefix tensor must be either a scalar or vector, " + "but saw tensor: %s" % p) + else: + p = tensor_shape.as_shape(prefix) + p_static = p.as_list() if p.ndims is not None else None + p = ( + constant_op.constant(p.as_list(), dtype=dtypes.int32) + if p.is_fully_defined() else None) + if isinstance(suffix, ops.Tensor): + s = suffix + s_static = tensor_util.constant_value(suffix) + if s.shape.ndims == 0: + s = array_ops.expand_dims(s, 0) + elif s.shape.ndims != 1: + raise ValueError("suffix tensor must be either a scalar or vector, " + "but saw tensor: %s" % s) + else: + s = tensor_shape.as_shape(suffix) + s_static = s.as_list() if s.ndims is not None else None + s = ( + constant_op.constant(s.as_list(), dtype=dtypes.int32) + if s.is_fully_defined() else None) + + if static: + shape = tensor_shape.as_shape(p_static).concatenate(s_static) + shape = shape.as_list() if shape.ndims is not None else None + else: + if p is None or s is None: + raise ValueError("Provided a prefix or suffix of None: %s and %s" % + (prefix, suffix)) + shape = array_ops.concat((p, s), 0) + return shape + + +def _zero_state_tensors(state_size, batch_size, dtype): + """Create tensors of zeros based on state_size, batch_size, and dtype.""" + + def get_state_shape(s): + """Combine s with batch_size to get a proper tensor shape.""" + c = _concat(batch_size, s) + size = array_ops.zeros(c, dtype=dtype) + if not context.executing_eagerly(): + c_static = _concat(batch_size, s, static=True) + size.set_shape(c_static) + return size + + return nest.map_structure(get_state_shape, state_size) + + +@tf_export(v1=["nn.rnn_cell.RNNCell"]) +class RNNCell(base_layer.Layer): + """Abstract object representing an RNN cell. + + Every `RNNCell` must have the properties below and implement `call` with + the signature `(output, next_state) = call(input, state)`. The optional + third input argument, `scope`, is allowed for backwards compatibility + purposes; but should be left off for new subclasses. + + This definition of cell differs from the definition used in the literature. + In the literature, 'cell' refers to an object with a single scalar output. + This definition refers to a horizontal array of such units. + + An RNN cell, in the most abstract setting, is anything that has + a state and performs some operation that takes a matrix of inputs. + This operation results in an output matrix with `self.output_size` columns. + If `self.state_size` is an integer, this operation also results in a new + state matrix with `self.state_size` columns. If `self.state_size` is a + (possibly nested tuple of) TensorShape object(s), then it should return a + matching structure of Tensors having shape `[batch_size].concatenate(s)` + for each `s` in `self.batch_size`. + """ + + def __init__(self, trainable=True, name=None, dtype=None, **kwargs): + super(RNNCell, self).__init__( + trainable=trainable, name=name, dtype=dtype, **kwargs) + # Attribute that indicates whether the cell is a TF RNN cell, due the slight + # difference between TF and Keras RNN cell. 
Notably the state is not wrapped + # in a list for TF cell where they are single tensor state, whereas keras + # cell will wrap the state into a list, and call() will have to unwrap them. + self._is_tf_rnn_cell = True + + def __call__(self, inputs, state, scope=None): + """Run this RNN cell on inputs, starting from the given state. + + Args: + inputs: `2-D` tensor with shape `[batch_size, input_size]`. + state: if `self.state_size` is an integer, this should be a `2-D Tensor` + with shape `[batch_size, self.state_size]`. Otherwise, if + `self.state_size` is a tuple of integers, this should be a tuple with + shapes `[batch_size, s] for s in self.state_size`. + scope: VariableScope for the created subgraph; defaults to class name. + + Returns: + A pair containing: + + - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. + - New state: Either a single `2-D` tensor, or a tuple of tensors matching + the arity and shapes of `state`. + """ + if scope is not None: + with vs.variable_scope( + scope, custom_getter=self._rnn_get_variable) as scope: + return super(RNNCell, self).__call__(inputs, state, scope=scope) + else: + scope_attrname = "rnncell_scope" + scope = getattr(self, scope_attrname, None) + if scope is None: + scope = vs.variable_scope( + vs.get_variable_scope(), custom_getter=self._rnn_get_variable) + setattr(self, scope_attrname, scope) + with scope: + return super(RNNCell, self).__call__(inputs, state) + + def _rnn_get_variable(self, getter, *args, **kwargs): + variable = getter(*args, **kwargs) + if context.executing_eagerly(): + trainable = variable._trainable # pylint: disable=protected-access + else: + trainable = ( + variable in tf_variables.trainable_variables() or + (isinstance(variable, tf_variables.PartitionedVariable) and + list(variable)[0] in tf_variables.trainable_variables())) + if trainable and all(variable is not v for v in self._trainable_weights): + self._trainable_weights.append(variable) + elif not trainable and all( + variable is not v for v in self._non_trainable_weights): + self._non_trainable_weights.append(variable) + return variable + + @property + def state_size(self): + """size(s) of state(s) used by this cell. + + It can be represented by an Integer, a TensorShape or a tuple of Integers + or TensorShapes. + """ + raise NotImplementedError("Abstract method") + + @property + def output_size(self): + """Integer or TensorShape: size of outputs produced by this cell.""" + raise NotImplementedError("Abstract method") + + def build(self, _): + # This tells the parent Layer object that it's OK to call + # self.add_variable() inside the call() method. + pass + + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + if inputs is not None: + # Validate the given batch_size and dtype against inputs if provided. + inputs = ops.convert_to_tensor(inputs, name="inputs") + if batch_size is not None: + if tensor_util.is_tensor(batch_size): + static_batch_size = tensor_util.constant_value( + batch_size, partial=True) + else: + static_batch_size = batch_size + if inputs.shape.dims[0].value != static_batch_size: + raise ValueError( + "batch size from input tensor is different from the " + "input param. Input tensor batch: {}, batch_size: {}".format( + inputs.shape.dims[0].value, batch_size)) + + if dtype is not None and inputs.dtype != dtype: + raise ValueError( + "dtype from input tensor is different from the " + "input param. 
Input tensor dtype: {}, dtype: {}".format( + inputs.dtype, dtype)) + + batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0] + dtype = inputs.dtype + if batch_size is None or dtype is None: + raise ValueError( + "batch_size and dtype cannot be None while constructing initial " + "state: batch_size={}, dtype={}".format(batch_size, dtype)) + return self.zero_state(batch_size, dtype) + + def zero_state(self, batch_size, dtype): + """Return zero-filled state tensor(s). + + Args: + batch_size: int, float, or unit Tensor representing the batch size. + dtype: the data type to use for the state. + + Returns: + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size, state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with + the shapes `[batch_size, s]` for each s in `state_size`. + """ + # Try to use the last cached zero_state. This is done to avoid recreating + # zeros, especially when eager execution is enabled. + state_size = self.state_size + is_eager = context.executing_eagerly() + if is_eager and _hasattr(self, "_last_zero_state"): + (last_state_size, last_batch_size, last_dtype, + last_output) = getattr(self, "_last_zero_state") + if (last_batch_size == batch_size and last_dtype == dtype and + last_state_size == state_size): + return last_output + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + output = _zero_state_tensors(state_size, batch_size, dtype) + if is_eager: + self._last_zero_state = (state_size, batch_size, dtype, output) + return output + + # TODO(b/134773139): Remove when contrib RNN cells implement `get_config` + def get_config(self): # pylint: disable=useless-super-delegation + return super(RNNCell, self).get_config() + + +class LayerRNNCell(RNNCell): + """Subclass of RNNCells that act like proper `tf.Layer` objects. + + For backwards compatibility purposes, most `RNNCell` instances allow their + `call` methods to instantiate variables via `tf.compat.v1.get_variable`. The + underlying + variable scope thus keeps track of any variables, and returning cached + versions. This is atypical of `tf.layer` objects, which separate this + part of layer building into a `build` method that is only called once. + + Here we provide a subclass for `RNNCell` objects that act exactly as + `Layer` objects do. They must provide a `build` method and their + `call` methods do not access Variables `tf.compat.v1.get_variable`. + """ + + def __call__(self, inputs, state, scope=None, *args, **kwargs): + """Run this RNN cell on inputs, starting from the given state. + + Args: + inputs: `2-D` tensor with shape `[batch_size, input_size]`. + state: if `self.state_size` is an integer, this should be a `2-D Tensor` + with shape `[batch_size, self.state_size]`. Otherwise, if + `self.state_size` is a tuple of integers, this should be a tuple with + shapes `[batch_size, s] for s in self.state_size`. + scope: optional cell scope. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + A pair containing: + + - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. + - New state: Either a single `2-D` tensor, or a tuple of tensors matching + the arity and shapes of `state`. + """ + # Bypass RNNCell's variable capturing semantics for LayerRNNCell. + # Instead, it is up to subclasses to provide a proper build + # method. 
See the class docstring for more details. + return base_layer.Layer.__call__( + self, inputs, state, scope=scope, *args, **kwargs) + + +@tf_export(v1=["nn.rnn_cell.BasicRNNCell"]) +class BasicRNNCell(LayerRNNCell): + """The most basic RNN cell. + + Note that this cell is not optimized for performance. Please use + `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU. + + Args: + num_units: int, The number of units in the RNN cell. + activation: Nonlinearity to use. Default: `tanh`. It could also be string + that is within Keras activation function names. + reuse: (optional) Python boolean describing whether to reuse variables in an + existing scope. If not `True`, and the existing scope already has the + given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). + """ + + @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell," + " and will be replaced by that in Tensorflow 2.0.") + def __init__(self, + num_units, + activation=None, + reuse=None, + name=None, + dtype=None, + **kwargs): + super(BasicRNNCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) + _check_supported_dtypes(self.dtype) + if context.executing_eagerly() and context.num_gpus() > 0: + logging.warn( + "%s: Note that this cell is not optimized for performance. " + "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better " + "performance on GPU.", self) + + # Inputs must be 2-dimensional. + self.input_spec = input_spec.InputSpec(ndim=2) + + self._num_units = num_units + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + @tf_utils.shape_type_conversion + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + str(inputs_shape)) + _check_supported_dtypes(self.dtype) + + input_depth = inputs_shape[-1] + self._kernel = self.add_variable( + _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, self._num_units]) + self._bias = self.add_variable( + _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=init_ops.zeros_initializer(dtype=self.dtype)) + + self.built = True + + def call(self, inputs, state): + """Most basic RNN: output = new_state = act(W * input + U * state + B).""" + _check_rnn_cell_input_dtypes([inputs, state]) + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, state], 1), self._kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + output = self._activation(gate_inputs) + return output, output + + def get_config(self): + config = { + "num_units": self._num_units, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(BasicRNNCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@tf_export(v1=["nn.rnn_cell.GRUCell"]) +class GRUCell(LayerRNNCell): + """Gated Recurrent Unit cell. + + Note that this cell is not optimized for performance. 
Please use + `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or + `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU. + + Args: + num_units: int, The number of units in the GRU cell. + activation: Nonlinearity to use. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables in an + existing scope. If not `True`, and the existing scope already has the + given variables, an error is raised. + kernel_initializer: (optional) The initializer to use for the weight and + projection matrices. + bias_initializer: (optional) The initializer to use for the bias. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). + + References: + Learning Phrase Representations using RNN Encoder Decoder for Statistical + Machine Translation: + [Cho et al., 2014] + (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179) + ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf)) + """ + + @deprecated(None, "This class is equivalent as tf.keras.layers.GRUCell," + " and will be replaced by that in Tensorflow 2.0.") + def __init__(self, + num_units, + activation=None, + reuse=None, + kernel_initializer=None, + bias_initializer=None, + name=None, + dtype=None, + **kwargs): + super(GRUCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) + _check_supported_dtypes(self.dtype) + + if context.executing_eagerly() and context.num_gpus() > 0: + logging.warn( + "%s: Note that this cell is not optimized for performance. " + "Please use tf.contrib.cudnn_rnn.CudnnGRU for better " + "performance on GPU.", self) + # Inputs must be 2-dimensional. 
+ self.input_spec = input_spec.InputSpec(ndim=2) + + self._num_units = num_units + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + self._kernel_initializer = initializers.get(kernel_initializer) + self._bias_initializer = initializers.get(bias_initializer) + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + @tf_utils.shape_type_conversion + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + str(inputs_shape)) + _check_supported_dtypes(self.dtype) + input_depth = inputs_shape[-1] + self._gate_kernel = self.add_variable( + "gates/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, 2 * self._num_units], + initializer=self._kernel_initializer) + self._gate_bias = self.add_variable( + "gates/%s" % _BIAS_VARIABLE_NAME, + shape=[2 * self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.constant_initializer(1.0, dtype=self.dtype))) + self._candidate_kernel = self.add_variable( + "candidate/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, self._num_units], + initializer=self._kernel_initializer) + self._candidate_bias = self.add_variable( + "candidate/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) + + self.built = True + + def call(self, inputs, state): + """Gated recurrent unit (GRU) with nunits cells.""" + _check_rnn_cell_input_dtypes([inputs, state]) + + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, state], 1), self._gate_kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) + + value = math_ops.sigmoid(gate_inputs) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + + r_state = r * state + + candidate = math_ops.matmul( + array_ops.concat([inputs, r_state], 1), self._candidate_kernel) + candidate = nn_ops.bias_add(candidate, self._candidate_bias) + + c = self._activation(candidate) + new_h = u * state + (1 - u) * c + return new_h, new_h + + def get_config(self): + config = { + "num_units": self._num_units, + "kernel_initializer": initializers.serialize(self._kernel_initializer), + "bias_initializer": initializers.serialize(self._bias_initializer), + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(GRUCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) + + +@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"]) +class LSTMStateTuple(_LSTMStateTuple): + """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. + + Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state + and `h` is the output. + + Only used when `state_is_tuple=True`. + """ + __slots__ = () + + @property + def dtype(self): + (c, h) = self + if c.dtype != h.dtype: + raise TypeError("Inconsistent internal state: %s vs %s" % + (str(c.dtype), str(h.dtype))) + return c.dtype + + +@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"]) +class BasicLSTMCell(LayerRNNCell): + """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead. + + Basic LSTM recurrent network cell. 
+ + The implementation is based on + + We add forget_bias (default: 1) to the biases of the forget gate in order to + reduce the scale of forgetting in the beginning of the training. + + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + + For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell` + that follows. + + Note that this cell is not optimized for performance. Please use + `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or + `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for + better performance on CPU. + """ + + @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," + " and will be replaced by that in Tensorflow 2.0.") + def __init__(self, + num_units, + forget_bias=1.0, + state_is_tuple=True, + activation=None, + reuse=None, + name=None, + dtype=None, + **kwargs): + """Initialize the basic LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell. + forget_bias: float, The bias added to forget gates (see above). Must set + to `0.0` manually when restoring from CudnnLSTM-trained checkpoints. + state_is_tuple: If True, accepted and returned states are 2-tuples of the + `c_state` and `m_state`. If False, they are concatenated along the + column axis. The latter behavior will soon be deprecated. + activation: Activation function of the inner states. Default: `tanh`. It + could also be string that is within Keras activation function names. + reuse: (optional) Python boolean describing whether to reuse variables in + an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). + When restoring from CudnnLSTM-trained checkpoints, must use + `CudnnCompatibleLSTMCell` instead. + """ + super(BasicLSTMCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) + _check_supported_dtypes(self.dtype) + if not state_is_tuple: + logging.warn( + "%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) + if context.executing_eagerly() and context.num_gpus() > 0: + logging.warn( + "%s: Note that this cell is not optimized for performance. " + "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " + "performance on GPU.", self) + + # Inputs must be 2-dimensional. 
+ self.input_spec = input_spec.InputSpec(ndim=2) + + self._num_units = num_units + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + + @property + def state_size(self): + return (LSTMStateTuple(self._num_units, self._num_units) + if self._state_is_tuple else 2 * self._num_units) + + @property + def output_size(self): + return self._num_units + + @tf_utils.shape_type_conversion + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + str(inputs_shape)) + _check_supported_dtypes(self.dtype) + input_depth = inputs_shape[-1] + h_depth = self._num_units + self._kernel = self.add_variable( + _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + h_depth, 4 * self._num_units]) + self._bias = self.add_variable( + _BIAS_VARIABLE_NAME, + shape=[4 * self._num_units], + initializer=init_ops.zeros_initializer(dtype=self.dtype)) + + self.built = True + + def call(self, inputs, state): + """Long short-term memory cell (LSTM). + + Args: + inputs: `2-D` tensor with shape `[batch_size, input_size]`. + state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size, + num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a + `Tensor` shaped `[batch_size, 2 * num_units]`. + + Returns: + A pair containing the new hidden state, and the new state (either a + `LSTMStateTuple` or a concatenated state, depending on + `state_is_tuple`). + """ + _check_rnn_cell_input_dtypes([inputs, state]) + + sigmoid = math_ops.sigmoid + one = constant_op.constant(1, dtype=dtypes.int32) + # Parameters of gates are concatenated into one multiply for efficiency. + if self._state_is_tuple: + c, h = state + else: + c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) + + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, h], 1), self._kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split( + value=gate_inputs, num_or_size_splits=4, axis=one) + + forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) + # Note that using `add` and `multiply` instead of `+` and `*` gives a + # performance improvement. So using those at the cost of readability. + add = math_ops.add + multiply = math_ops.multiply + new_c = add( + multiply(c, sigmoid(add(f, forget_bias_tensor))), + multiply(sigmoid(i), self._activation(j))) + new_h = multiply(self._activation(new_c), sigmoid(o)) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = array_ops.concat([new_c, new_h], 1) + return new_h, new_state + + def get_config(self): + config = { + "num_units": self._num_units, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(BasicLSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@tf_export(v1=["nn.rnn_cell.LSTMCell"]) +class LSTMCell(LayerRNNCell): + """Long short-term memory unit (LSTM) recurrent network cell. + + The default non-peephole implementation is based on (Gers et al., 1999). + The peephole implementation is based on (Sak et al., 2014). + + The class uses optional peep-hole connections, optional cell clipping, and + an optional projection layer. 
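The gate arithmetic in BasicLSTMCell.call() above corresponds to the following standalone NumPy sketch. The name lstm_step and the random placeholder weights are illustrative only; the single fused kernel is split four ways, as in the cell.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(inputs, c, h, kernel, bias, forget_bias=1.0):
    gates = np.concatenate([inputs, h], axis=1) @ kernel + bias
    # i = input gate, j = new input, f = forget gate, o = output gate.
    i, j, f, o = np.split(gates, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_c, new_h

batch, input_depth, num_units = 2, 3, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, input_depth))
c = np.zeros((batch, num_units))
h = np.zeros((batch, num_units))
kernel = rng.standard_normal((input_depth + num_units, 4 * num_units))
bias = np.zeros(4 * num_units)               # bias is zero-initialized in build()
new_c, new_h = lstm_step(x, c, h, kernel, bias)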
+ + Note that this cell is not optimized for performance. Please use + `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or + `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for + better performance on CPU. + References: + Long short-term memory recurrent neural network architectures for large + scale acoustic modeling: + [Sak et al., 2014] + (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html) + ([pdf] + (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf)) + Learning to forget: + [Gers et al., 1999] + (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218) + ([pdf](https://arxiv.org/pdf/1409.2329.pdf)) + Long Short-Term Memory: + [Hochreiter et al., 1997] + (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735) + ([pdf](http://ml.jku.at/publications/older/3504.pdf)) + """ + + @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," + " and will be replaced by that in Tensorflow 2.0.") + def __init__(self, + num_units, + use_peepholes=False, + cell_clip=None, + initializer=None, + num_proj=None, + proj_clip=None, + num_unit_shards=None, + num_proj_shards=None, + forget_bias=1.0, + state_is_tuple=True, + activation=None, + reuse=None, + name=None, + dtype=None, + **kwargs): + """Initialize the parameters for an LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell. + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a + variable_scope partitioner instead. + num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a + variable_scope partitioner instead. + forget_bias: Biases of the forget gate are initialized by default to 1 in + order to reduce the scale of forgetting at the beginning of the + training. Must set it manually to `0.0` when restoring from CudnnLSTM + trained checkpoints. + state_is_tuple: If True, accepted and returned states are 2-tuples of the + `c_state` and `m_state`. If False, they are concatenated along the + column axis. This latter behavior will soon be deprecated. + activation: Activation function of the inner states. Default: `tanh`. It + could also be string that is within Keras activation function names. + reuse: (optional) Python boolean describing whether to reuse variables in + an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). 
+ When restoring from CudnnLSTM-trained checkpoints, use + `CudnnCompatibleLSTMCell` instead. + """ + super(LSTMCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) + _check_supported_dtypes(self.dtype) + if not state_is_tuple: + logging.warn( + "%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) + if num_unit_shards is not None or num_proj_shards is not None: + logging.warn( + "%s: The num_unit_shards and proj_unit_shards parameters are " + "deprecated and will be removed in Jan 2017. " + "Use a variable scope with a partitioner instead.", self) + if context.executing_eagerly() and context.num_gpus() > 0: + logging.warn( + "%s: Note that this cell is not optimized for performance. " + "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " + "performance on GPU.", self) + + # Inputs must be 2-dimensional. + self.input_spec = input_spec.InputSpec(ndim=2) + + self._num_units = num_units + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._initializer = initializers.get(initializer) + self._num_proj = num_proj + self._proj_clip = proj_clip + self._num_unit_shards = num_unit_shards + self._num_proj_shards = num_proj_shards + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + + if num_proj: + self._state_size = ( + LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units + + num_proj) + self._output_size = num_proj + else: + self._state_size = ( + LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 * + num_units) + self._output_size = num_units + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + @tf_utils.shape_type_conversion + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + str(inputs_shape)) + _check_supported_dtypes(self.dtype) + input_depth = inputs_shape[-1] + h_depth = self._num_units if self._num_proj is None else self._num_proj + maybe_partitioner = ( + partitioned_variables.fixed_size_partitioner(self._num_unit_shards) + if self._num_unit_shards is not None else None) + self._kernel = self.add_variable( + _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + h_depth, 4 * self._num_units], + initializer=self._initializer, + partitioner=maybe_partitioner) + if self.dtype is None: + initializer = init_ops.zeros_initializer + else: + initializer = init_ops.zeros_initializer(dtype=self.dtype) + self._bias = self.add_variable( + _BIAS_VARIABLE_NAME, + shape=[4 * self._num_units], + initializer=initializer) + if self._use_peepholes: + self._w_f_diag = self.add_variable( + "w_f_diag", shape=[self._num_units], initializer=self._initializer) + self._w_i_diag = self.add_variable( + "w_i_diag", shape=[self._num_units], initializer=self._initializer) + self._w_o_diag = self.add_variable( + "w_o_diag", shape=[self._num_units], initializer=self._initializer) + + if self._num_proj is not None: + maybe_proj_partitioner = ( + partitioned_variables.fixed_size_partitioner(self._num_proj_shards) + if self._num_proj_shards is not None else None) + self._proj_kernel = self.add_variable( + "projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[self._num_units, self._num_proj], + initializer=self._initializer, + partitioner=maybe_proj_partitioner) + + self.built = True + + def call(self, inputs, 
state): + """Run one step of LSTM. + + Args: + inputs: input Tensor, must be 2-D, `[batch, input_size]`. + state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, + [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple + of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch, output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + _check_rnn_cell_input_dtypes([inputs, state]) + + num_proj = self._num_units if self._num_proj is None else self._num_proj + sigmoid = math_ops.sigmoid + + if self._state_is_tuple: + (c_prev, m_prev) = state + else: + c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) + m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) + + input_size = inputs.get_shape().with_rank(2).dims[1].value + if input_size is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + lstm_matrix = math_ops.matmul( + array_ops.concat([inputs, m_prev], 1), self._kernel) + lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias) + + i, j, f, o = array_ops.split( + value=lstm_matrix, num_or_size_splits=4, axis=1) + # Diagonal connections + if self._use_peepholes: + c = ( + sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + + sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) + else: + c = ( + sigmoid(f + self._forget_bias) * c_prev + + sigmoid(i) * self._activation(j)) + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + if self._use_peepholes: + m = sigmoid(o + self._w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + m = math_ops.matmul(m, self._proj_kernel) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + new_state = ( + LSTMStateTuple(c, m) + if self._state_is_tuple else array_ops.concat([c, m], 1)) + return m, new_state + + def get_config(self): + config = { + "num_units": self._num_units, + "use_peepholes": self._use_peepholes, + "cell_clip": self._cell_clip, + "initializer": initializers.serialize(self._initializer), + "num_proj": self._num_proj, + "proj_clip": self._proj_clip, + "num_unit_shards": self._num_unit_shards, + "num_proj_shards": self._num_proj_shards, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(LSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class _RNNCellWrapperV1(RNNCell): + """Base class for cells wrappers V1 compatibility. + + This class along with `_RNNCellWrapperV2` allows to define cells wrappers that + are compatible with V1 and V2, and defines helper methods for this purpose. 
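As a companion to LSTMCell.call() above, this standalone NumPy sketch shows where the optional peephole terms, cell clipping and projection enter the computation. The names peephole_lstm_step, w_*_diag and proj_kernel mirror the cell's variables; the random weights are placeholders, not the cell's initializers.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def peephole_lstm_step(x, c_prev, m_prev, kernel, bias, w_i_diag, w_f_diag,
                       w_o_diag, proj_kernel, forget_bias=1.0,
                       cell_clip=None, proj_clip=None):
    gates = np.concatenate([x, m_prev], axis=1) @ kernel + bias
    i, j, f, o = np.split(gates, 4, axis=1)
    # Peephole (diagonal) connections let the gates see the cell state directly.
    c = (sigmoid(f + forget_bias + w_f_diag * c_prev) * c_prev +
         sigmoid(i + w_i_diag * c_prev) * np.tanh(j))
    if cell_clip is not None:
        c = np.clip(c, -cell_clip, cell_clip)
    m = sigmoid(o + w_o_diag * c) * np.tanh(c)
    # Optional projection shrinks the emitted output/state to num_proj units.
    m = m @ proj_kernel
    if proj_clip is not None:
        m = np.clip(m, -proj_clip, proj_clip)
    return c, m

batch, input_depth, num_units, num_proj = 2, 3, 4, 2
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, input_depth))
c_prev = np.zeros((batch, num_units))
m_prev = np.zeros((batch, num_proj))         # with a projection, m has num_proj columns
kernel = rng.standard_normal((input_depth + num_proj, 4 * num_units))
bias = np.zeros(4 * num_units)
w_i_diag, w_f_diag, w_o_diag = (rng.standard_normal(num_units) for _ in range(3))
proj_kernel = rng.standard_normal((num_units, num_proj))
c, m = peephole_lstm_step(x, c_prev, m_prev, kernel, bias, w_i_diag, w_f_diag,
                          w_o_diag, proj_kernel, cell_clip=3.0, proj_clip=3.0)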
+ """ + + def __init__(self, cell, *args, **kwargs): + super(_RNNCellWrapperV1, self).__init__(*args, **kwargs) + assert_like_rnncell("cell", cell) + self.cell = cell + if isinstance(cell, trackable.Trackable): + self._track_trackable(self.cell, name="cell") + + def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): + """Calls the wrapped cell and performs the wrapping logic. + + This method is called from the wrapper's `call` or `__call__` methods. + + Args: + inputs: A tensor with wrapped cell's input. + state: A tensor or tuple of tensors with wrapped cell's state. + cell_call_fn: Wrapped cell's method to use for step computation (cell's + `__call__` or 'call' method). + **kwargs: Additional arguments. + + Returns: + A pair containing: + - Output: A tensor with cell's output. + - New state: A tensor or tuple of tensors with new wrapped cell's state. + """ + raise NotImplementedError + + def __call__(self, inputs, state, scope=None): + """Runs the RNN cell step computation. + + We assume that the wrapped RNNCell is being built within its `__call__` + method. We directly use the wrapped cell's `__call__` in the overridden + wrapper `__call__` method. + + This allows to use the wrapped cell and the non-wrapped cell equivalently + when using `__call__`. + + Args: + inputs: A tensor with wrapped cell's input. + state: A tensor or tuple of tensors with wrapped cell's state. + scope: VariableScope for the subgraph created in the wrapped cells' + `__call__`. + + Returns: + A pair containing: + + - Output: A tensor with cell's output. + - New state: A tensor or tuple of tensors with new wrapped cell's state. + """ + return self._call_wrapped_cell( + inputs, state, cell_call_fn=self.cell.__call__, scope=scope) + + def get_config(self): + config = { + "cell": { + "class_name": self.cell.__class__.__name__, + "config": self.cell.get_config() + }, + } + base_config = super(_RNNCellWrapperV1, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() + cell = config.pop("cell") + try: + assert_like_rnncell("cell", cell) + return cls(cell, **config) + except TypeError: + raise ValueError("RNNCellWrapper cannot reconstruct the wrapped cell. 
" + "Please overwrite the cell in the config with a RNNCell " + "instance.") + + +@tf_export(v1=["nn.rnn_cell.DropoutWrapper"]) +class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase, + _RNNCellWrapperV1): + """Operator adding dropout to inputs and outputs of the given cell.""" + + def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation + super(DropoutWrapper, self).__init__(*args, **kwargs) + + __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__ + + +@tf_export(v1=["nn.rnn_cell.ResidualWrapper"]) +class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase, + _RNNCellWrapperV1): + """RNNCell wrapper that ensures cell inputs are added to the outputs.""" + + def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation + super(ResidualWrapper, self).__init__(*args, **kwargs) + + __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__ + + +@tf_export(v1=["nn.rnn_cell.DeviceWrapper"]) +class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase, + _RNNCellWrapperV1): + + def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation + super(DeviceWrapper, self).__init__(*args, **kwargs) + + __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__ + + +@tf_export(v1=["nn.rnn_cell.MultiRNNCell"]) +class MultiRNNCell(RNNCell): + """RNN cell composed sequentially of multiple simple cells. + + Example: + + ```python + num_units = [128, 64] + cells = [BasicLSTMCell(num_units=n) for n in num_units] + stacked_rnn_cell = MultiRNNCell(cells) + ``` + """ + + @deprecated(None, "This class is equivalent as " + "tf.keras.layers.StackedRNNCells, and will be replaced by " + "that in Tensorflow 2.0.") + def __init__(self, cells, state_is_tuple=True): + """Create a RNN cell composed sequentially of a number of RNNCells. + + Args: + cells: list of RNNCells that will be composed in this order. + state_is_tuple: If True, accepted and returned states are n-tuples, where + `n = len(cells)`. If False, the states are all concatenated along the + column axis. This latter behavior will soon be deprecated. + + Raises: + ValueError: if cells is empty (not allowed), or at least one of the cells + returns a state tuple but the flag `state_is_tuple` is `False`. + """ + super(MultiRNNCell, self).__init__() + if not cells: + raise ValueError("Must specify at least one cell for MultiRNNCell.") + if not nest.is_sequence(cells): + raise TypeError("cells must be a list or tuple, but saw: %s." % cells) + + if len(set(id(cell) for cell in cells)) < len(cells): + logging.log_first_n( + logging.WARN, "At least two cells provided to MultiRNNCell " + "are the same object and will share weights.", 1) + + self._cells = cells + for cell_number, cell in enumerate(self._cells): + # Add Trackable dependencies on these cells so their variables get + # saved with this object when using object-based saving. + if isinstance(cell, trackable.Trackable): + # TODO(allenl): Track down non-Trackable callers. + self._track_trackable(cell, name="cell-%d" % (cell_number,)) + self._state_is_tuple = state_is_tuple + if not state_is_tuple: + if any(nest.is_sequence(c.state_size) for c in self._cells): + raise ValueError("Some cells return tuples of states, but the flag " + "state_is_tuple is not set. 
State sizes are: %s" % + str([c.state_size for c in self._cells])) + + @property + def state_size(self): + if self._state_is_tuple: + return tuple(cell.state_size for cell in self._cells) + else: + return sum(cell.state_size for cell in self._cells) + + @property + def output_size(self): + return self._cells[-1].output_size + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + if self._state_is_tuple: + return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells) + else: + # We know here that state_size of each cell is not a tuple and + # presumably does not contain TensorArrays or anything else fancy + return super(MultiRNNCell, self).zero_state(batch_size, dtype) + + @property + def trainable_weights(self): + if not self.trainable: + return [] + weights = [] + for cell in self._cells: + if isinstance(cell, base_layer.Layer): + weights += cell.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for cell in self._cells: + if isinstance(cell, base_layer.Layer): + weights += cell.non_trainable_weights + if not self.trainable: + trainable_weights = [] + for cell in self._cells: + if isinstance(cell, base_layer.Layer): + trainable_weights += cell.trainable_weights + return trainable_weights + weights + return weights + + def call(self, inputs, state): + """Run this multi-layer cell on inputs, starting from state.""" + cur_state_pos = 0 + cur_inp = inputs + new_states = [] + for i, cell in enumerate(self._cells): + with vs.variable_scope("cell_%d" % i): + if self._state_is_tuple: + if not nest.is_sequence(state): + raise ValueError( + "Expected state to be a tuple of length %d, but received: %s" % + (len(self.state_size), state)) + cur_state = state[i] + else: + cur_state = array_ops.slice(state, [0, cur_state_pos], + [-1, cell.state_size]) + cur_state_pos += cell.state_size + cur_inp, new_state = cell(cur_inp, cur_state) + new_states.append(new_state) + + new_states = ( + tuple(new_states) if self._state_is_tuple else array_ops.concat( + new_states, 1)) + + return cur_inp, new_states + + +def _check_rnn_cell_input_dtypes(inputs): + """Check whether the input tensors are with supported dtypes. + + Default RNN cells only support floats and complex as its dtypes since the + activation function (tanh and sigmoid) only allow those types. This function + will throw a proper error message if the inputs is not in a supported type. + + Args: + inputs: tensor or nested structure of tensors that are feed to RNN cell as + input or state. + + Raises: + ValueError: if any of the input tensor are not having dtypes of float or + complex. + """ + for t in nest.flatten(inputs): + _check_supported_dtypes(t.dtype) + + +def _check_supported_dtypes(dtype): + if dtype is None: + return + dtype = dtypes.as_dtype(dtype) + if not (dtype.is_floating or dtype.is_complex): + raise ValueError("RNN cell only supports floating point inputs, " + "but saw dtype: %s" % dtype) diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py new file mode 100644 index 00000000000..62a6baa5640 --- /dev/null +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py @@ -0,0 +1,516 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
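The per-layer loop in MultiRNNCell.call() above can be illustrated without TensorFlow: each cell's output becomes the next cell's input, and the per-layer states are collected into a tuple (the state_is_tuple=True case). ToyCell and stacked_step below are hypothetical stand-ins, not classes from this patch.

class ToyCell(object):
    """Stand-in cell whose output and new state are inputs + state (plain floats)."""
    state_size = 1
    output_size = 1

    def __call__(self, inputs, state):
        new_state = inputs + state
        return new_state, new_state

def stacked_step(cells, inputs, states):
    cur_inp = inputs
    new_states = []
    for cell, cur_state in zip(cells, states):
        # Output of layer i feeds layer i + 1; each layer's state is kept separately.
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)
    return cur_inp, tuple(new_states)

output, states = stacked_step([ToyCell(), ToyCell()], 1.0, (0.0, 0.0))
assert (output, states) == (1.0, (1.0, 1.0))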
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module contains the implementation of RNN cell wrappers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib +import numbers +import sys +import types as python_types +import warnings + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.util import nest + + +class DropoutWrapperBase(object): + """Operator adding dropout to inputs and outputs of the given cell.""" + + def __init__(self, + cell, + input_keep_prob=1.0, + output_keep_prob=1.0, + state_keep_prob=1.0, + variational_recurrent=False, + input_size=None, + dtype=None, + seed=None, + dropout_state_filter_visitor=None, + **kwargs): + """Create a cell with added input, state, and/or output dropout. + + If `variational_recurrent` is set to `True` (**NOT** the default behavior), + then the same dropout mask is applied at every step, as described in: + [A Theoretically Grounded Application of Dropout in Recurrent + Neural Networks. Y. Gal, Z. Ghahramani](https://arxiv.org/abs/1512.05287). + + Otherwise a different dropout mask is applied at every time step. + + Note, by default (unless a custom `dropout_state_filter` is provided), + the memory state (`c` component of any `LSTMStateTuple`) passing through + a `DropoutWrapper` is never modified. This behavior is described in the + above article. + + Args: + cell: an RNNCell, a projection to output_size is added to it. + input_keep_prob: unit Tensor or float between 0 and 1, input keep + probability; if it is constant and 1, no input dropout will be added. + output_keep_prob: unit Tensor or float between 0 and 1, output keep + probability; if it is constant and 1, no output dropout will be added. + state_keep_prob: unit Tensor or float between 0 and 1, output keep + probability; if it is constant and 1, no output dropout will be added. + State dropout is performed on the outgoing states of the cell. **Note** + the state components to which dropout is applied when `state_keep_prob` + is in `(0, 1)` are also determined by the argument + `dropout_state_filter_visitor` (e.g. by default dropout is never applied + to the `c` component of an `LSTMStateTuple`). + variational_recurrent: Python bool. If `True`, then the same dropout + pattern is applied across all time steps per run call. If this parameter + is set, `input_size` **must** be provided. + input_size: (optional) (possibly nested tuple of) `TensorShape` objects + containing the depth(s) of the input tensors expected to be passed in to + the `DropoutWrapper`. Required and used **iff** `variational_recurrent + = True` and `input_keep_prob < 1`. 
+ dtype: (optional) The `dtype` of the input, state, and output tensors. + Required and used **iff** `variational_recurrent = True`. + seed: (optional) integer, the randomness seed. + dropout_state_filter_visitor: (optional), default: (see below). Function + that takes any hierarchical level of the state and returns a scalar or + depth=1 structure of Python booleans describing which terms in the state + should be dropped out. In addition, if the function returns `True`, + dropout is applied across this sublevel. If the function returns + `False`, dropout is not applied across this entire sublevel. + Default behavior: perform dropout on all terms except the memory (`c`) + state of `LSTMCellState` objects, and don't try to apply dropout to + `TensorArray` objects: ``` + def dropout_state_filter_visitor(s): + if isinstance(s, LSTMCellState): # Never perform dropout on the c + state. return LSTMCellState(c=False, h=True) + elif isinstance(s, TensorArray): return False return True ``` + **kwargs: dict of keyword arguments for base layer. + + Raises: + TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is provided + but not `callable`. + ValueError: if any of the keep_probs are not between 0 and 1. + """ + super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs) + + if (dropout_state_filter_visitor is not None and + not callable(dropout_state_filter_visitor)): + raise TypeError("dropout_state_filter_visitor must be callable") + self._dropout_state_filter = ( + dropout_state_filter_visitor or _default_dropout_state_filter_visitor) + with ops.name_scope("DropoutWrapperInit"): + + def tensor_and_const_value(v): + tensor_value = ops.convert_to_tensor(v) + const_value = tensor_util.constant_value(tensor_value) + return (tensor_value, const_value) + + for prob, attr in [(input_keep_prob, "input_keep_prob"), + (state_keep_prob, "state_keep_prob"), + (output_keep_prob, "output_keep_prob")]: + tensor_prob, const_prob = tensor_and_const_value(prob) + if const_prob is not None: + if const_prob < 0 or const_prob > 1: + raise ValueError("Parameter %s must be between 0 and 1: %d" % + (attr, const_prob)) + setattr(self, "_%s" % attr, float(const_prob)) + else: + setattr(self, "_%s" % attr, tensor_prob) + + # Set variational_recurrent, seed before running the code below + self._variational_recurrent = variational_recurrent + self._input_size = input_size + self._seed = seed + + self._recurrent_input_noise = None + self._recurrent_state_noise = None + self._recurrent_output_noise = None + + if variational_recurrent: + if dtype is None: + raise ValueError( + "When variational_recurrent=True, dtype must be provided") + + def convert_to_batch_shape(s): + # Prepend a 1 for the batch dimension; for recurrent + # variational dropout we use the same dropout mask for all + # batch elements. 
+ return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()), 0) + + def batch_noise(s, inner_seed): + shape = convert_to_batch_shape(s) + return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype) + + if (not isinstance(self._input_keep_prob, numbers.Real) or + self._input_keep_prob < 1.0): + if input_size is None: + raise ValueError( + "When variational_recurrent=True and input_keep_prob < 1.0 or " + "is unknown, input_size must be provided") + self._recurrent_input_noise = _enumerated_map_structure_up_to( + input_size, + lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)), + input_size) + self._recurrent_state_noise = _enumerated_map_structure_up_to( + cell.state_size, + lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)), + cell.state_size) + self._recurrent_output_noise = _enumerated_map_structure_up_to( + cell.output_size, + lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)), + cell.output_size) + + def _gen_seed(self, salt_prefix, index): + if self._seed is None: + return None + salt = "%s_%d" % (salt_prefix, index) + string = (str(self._seed) + salt).encode("utf-8") + return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF + + @property + def wrapped_cell(self): + return self.cell + + @property + def state_size(self): + return self.cell.state_size + + @property + def output_size(self): + return self.cell.output_size + + def build(self, inputs_shape): + self.cell.build(inputs_shape) + self.built = True + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + return self.cell.zero_state(batch_size, dtype) + + def _variational_recurrent_dropout_value( + self, unused_index, value, noise, keep_prob): + """Performs dropout given the pre-calculated noise tensor.""" + # uniform [keep_prob, 1.0 + keep_prob) + random_tensor = keep_prob + noise + + # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) + binary_tensor = math_ops.floor(random_tensor) + ret = math_ops.divide(value, keep_prob) * binary_tensor + ret.set_shape(value.get_shape()) + return ret + + def _dropout(self, + values, + salt_prefix, + recurrent_noise, + keep_prob, + shallow_filtered_substructure=None): + """Decides whether to perform standard dropout or recurrent dropout.""" + + if shallow_filtered_substructure is None: + # Put something so we traverse the entire structure; inside the + # dropout function we check to see if leafs of this are bool or not. + shallow_filtered_substructure = values + + if not self._variational_recurrent: + + def dropout(i, do_dropout, v): + if not isinstance(do_dropout, bool) or do_dropout: + return nn_ops.dropout_v2( + v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i)) + else: + return v + + return _enumerated_map_structure_up_to( + shallow_filtered_substructure, dropout, + *[shallow_filtered_substructure, values]) + else: + + def dropout(i, do_dropout, v, n): + if not isinstance(do_dropout, bool) or do_dropout: + return self._variational_recurrent_dropout_value(i, v, n, keep_prob) + else: + return v + + return _enumerated_map_structure_up_to( + shallow_filtered_substructure, dropout, + *[shallow_filtered_substructure, values, recurrent_noise]) + + def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): + """Runs the wrapped cell and applies dropout. + + Args: + inputs: A tensor with wrapped cell's input. + state: A tensor or tuple of tensors with wrapped cell's state. 
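The mask arithmetic in _variational_recurrent_dropout_value() above comes down to a floor of shifted uniform noise; a standalone NumPy sketch follows (variational_dropout is an illustrative name). The key property is that the noise is sampled once per run, so the same mask is reused at every time step, and its batch dimension is 1 (as in batch_noise()) so one mask broadcasts over the whole batch.

import numpy as np

def variational_dropout(value, noise, keep_prob):
    random_tensor = keep_prob + noise         # uniform in [keep_prob, 1 + keep_prob)
    binary_mask = np.floor(random_tensor)     # 0 with prob (1 - keep_prob), else 1
    return value / keep_prob * binary_mask    # inverted-dropout rescaling

rng = np.random.default_rng(0)
noise = rng.uniform(size=(1, 4))              # drawn once; broadcast over the batch
h_t0 = rng.standard_normal((2, 4))
h_t1 = rng.standard_normal((2, 4))
# The same mask is applied at both "time steps".
d0 = variational_dropout(h_t0, noise, keep_prob=0.8)
d1 = variational_dropout(h_t1, noise, keep_prob=0.8)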
+ cell_call_fn: Wrapped cell's method to use for step computation (cell's + `__call__` or 'call' method). + **kwargs: Additional arguments. + + Returns: + A pair containing: + + - Output: A tensor with cell's output. + - New state: A tensor or tuple of tensors with new wrapped cell's state. + """ + + def _should_dropout(p): + return (not isinstance(p, float)) or p < 1 + + if _should_dropout(self._input_keep_prob): + inputs = self._dropout(inputs, "input", self._recurrent_input_noise, + self._input_keep_prob) + output, new_state = cell_call_fn(inputs, state, **kwargs) + if _should_dropout(self._state_keep_prob): + # Identify which subsets of the state to perform dropout on and + # which ones to keep. + shallow_filtered_substructure = nest.get_traverse_shallow_structure( + self._dropout_state_filter, new_state) + new_state = self._dropout(new_state, "state", self._recurrent_state_noise, + self._state_keep_prob, + shallow_filtered_substructure) + if _should_dropout(self._output_keep_prob): + output = self._dropout(output, "output", self._recurrent_output_noise, + self._output_keep_prob) + return output, new_state + + def get_config(self): + """Returns the config of the dropout wrapper.""" + config = { + "input_keep_prob": self._input_keep_prob, + "output_keep_prob": self._output_keep_prob, + "state_keep_prob": self._state_keep_prob, + "variational_recurrent": self._variational_recurrent, + "input_size": self._input_size, + "seed": self._seed, + } + if self._dropout_state_filter != _default_dropout_state_filter_visitor: + function, function_type, function_module = _serialize_function_to_config( + self._dropout_state_filter) + config.update({"dropout_fn": function, + "dropout_fn_type": function_type, + "dropout_fn_module": function_module}) + base_config = super(DropoutWrapperBase, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + if "dropout_fn" in config: + config = config.copy() + dropout_state_filter = _parse_config_to_function( + config, custom_objects, "dropout_fn", "dropout_fn_type", + "dropout_fn_module") + config.pop("dropout_fn") + config["dropout_state_filter_visitor"] = dropout_state_filter + return super(DropoutWrapperBase, cls).from_config( + config, custom_objects=custom_objects) + + +class ResidualWrapperBase(object): + """RNNCell wrapper that ensures cell inputs are added to the outputs.""" + + def __init__(self, cell, residual_fn=None, **kwargs): + """Constructs a `ResidualWrapper` for `cell`. + + Args: + cell: An instance of `RNNCell`. + residual_fn: (Optional) The function to map raw cell inputs and raw cell + outputs to the actual cell outputs of the residual network. + Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs + and outputs. + **kwargs: dict of keyword arguments for base layer. + """ + super(ResidualWrapperBase, self).__init__(cell, **kwargs) + self._residual_fn = residual_fn + + @property + def state_size(self): + return self.cell.state_size + + @property + def output_size(self): + return self.cell.output_size + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + return self.cell.zero_state(batch_size, dtype) + + def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): + """Run the cell and then apply the residual_fn on its inputs to its outputs. + + Args: + inputs: cell inputs. + state: cell state. 
+ cell_call_fn: Wrapped cell's method to use for step computation (cell's + `__call__` or 'call' method). + **kwargs: Additional arguments passed to the wrapped cell's `call`. + + Returns: + Tuple of cell outputs and new state. + + Raises: + TypeError: If cell inputs and outputs have different structure (type). + ValueError: If cell inputs and outputs have different structure (value). + """ + outputs, new_state = cell_call_fn(inputs, state, **kwargs) + + # Ensure shapes match + def assert_shape_match(inp, out): + inp.get_shape().assert_is_compatible_with(out.get_shape()) + + def default_residual_fn(inputs, outputs): + nest.assert_same_structure(inputs, outputs) + nest.map_structure(assert_shape_match, inputs, outputs) + return nest.map_structure(lambda inp, out: inp + out, inputs, outputs) + + res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs) + return (res_outputs, new_state) + + def get_config(self): + """Returns the config of the residual wrapper.""" + if self._residual_fn is not None: + function, function_type, function_module = _serialize_function_to_config( + self._residual_fn) + config = { + "residual_fn": function, + "residual_fn_type": function_type, + "residual_fn_module": function_module + } + else: + config = {} + base_config = super(ResidualWrapperBase, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + if "residual_fn" in config: + config = config.copy() + residual_function = _parse_config_to_function(config, custom_objects, + "residual_fn", + "residual_fn_type", + "residual_fn_module") + config["residual_fn"] = residual_function + return super(ResidualWrapperBase, cls).from_config( + config, custom_objects=custom_objects) + + +class DeviceWrapperBase(object): + """Operator that ensures an RNNCell runs on a particular device.""" + + def __init__(self, cell, device, **kwargs): + """Construct a `DeviceWrapper` for `cell` with device `device`. + + Ensures the wrapped `cell` is called with `tf.device(device)`. + + Args: + cell: An instance of `RNNCell`. + device: A device string or function, for passing to `tf.device`. + **kwargs: dict of keyword arguments for base layer. 
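The default residual combination in ResidualWrapperBase._call_wrapped_cell() above is an elementwise inputs + outputs over matching structures; here is a pure-Python sketch with flat tuples in place of nest.map_structure (residual_step and toy_cell are illustrative names, not part of the patch).

def residual_step(cell_fn, inputs, state, residual_fn=None):
    # Run the wrapped cell, then add the raw inputs back onto its outputs.
    outputs, new_state = cell_fn(inputs, state)
    combine = residual_fn or (lambda inp, out: tuple(i + o for i, o in zip(inp, out)))
    return combine(inputs, outputs), new_state

# A stand-in "cell" that doubles its inputs and leaves the state unchanged.
toy_cell = lambda inputs, state: (tuple(2 * i for i in inputs), state)
outputs, _ = residual_step(toy_cell, (1.0, 2.0), None)
assert outputs == (3.0, 6.0)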
+ """ + super(DeviceWrapperBase, self).__init__(cell, **kwargs) + self._device = device + + @property + def state_size(self): + return self.cell.state_size + + @property + def output_size(self): + return self.cell.output_size + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + with ops.device(self._device): + return self.cell.zero_state(batch_size, dtype) + + def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): + """Run the cell on specified device.""" + with ops.device(self._device): + return cell_call_fn(inputs, state, **kwargs) + + def get_config(self): + config = {"device": self._device} + base_config = super(DeviceWrapperBase, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +def _serialize_function_to_config(function): + """Serialize the function for get_config().""" + if isinstance(function, python_types.LambdaType): + output = generic_utils.func_dump(function) + output_type = "lambda" + module = function.__module__ + elif callable(function): + output = function.__name__ + output_type = "function" + module = function.__module__ + else: + raise ValueError("Unrecognized function type for input: {}".format( + type(function))) + + return output, output_type, module + + +def _parse_config_to_function(config, custom_objects, func_attr_name, + func_type_attr_name, module_attr_name): + """Reconstruct the function from the config.""" + globs = globals() + module = config.pop(module_attr_name, None) + if module in sys.modules: + globs.update(sys.modules[module].__dict__) + elif module is not None: + # Note: we don't know the name of the function if it's a lambda. + warnings.warn("{} is not loaded, but a layer uses it. " + "It may cause errors.".format(module), UserWarning) + if custom_objects: + globs.update(custom_objects) + function_type = config.pop(func_type_attr_name) + if function_type == "function": + # Simple lookup in custom objects + function = generic_utils.deserialize_keras_object( + config[func_attr_name], + custom_objects=custom_objects, + printable_module_name="function in wrapper") + elif function_type == "lambda": + # Unsafe deserialization from bytecode + function = generic_utils.func_load( + config[func_attr_name], globs=globs) + else: + raise TypeError("Unknown function type:", function_type) + return function + + +def _default_dropout_state_filter_visitor(substate): + from tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl import LSTMStateTuple # pylint: disable=g-import-not-at-top + if isinstance(substate, LSTMStateTuple): + # Do not perform dropout on the memory state. 
+ return LSTMStateTuple(c=False, h=True) + elif isinstance(substate, tensor_array_ops.TensorArray): + return False + return True + + +def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs): + ix = [0] + + def enumerated_fn(*inner_args, **inner_kwargs): + r = map_fn(ix[0], *inner_args, **inner_kwargs) + ix[0] += 1 + return r + + return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args, + **kwargs) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 719b06a41d5..b0b29de22ea 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -24,1332 +24,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.keras import activations -from tensorflow.python.keras import initializers -from tensorflow.python.keras.engine import input_spec -from tensorflow.python.keras.utils import tf_utils -from tensorflow.python.layers import base as base_layer -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import partitioned_variables -from tensorflow.python.ops import rnn_cell_wrapper_impl -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training.tracking import base as trackable -from tensorflow.python.util import nest -from tensorflow.python.util.deprecation import deprecated -from tensorflow.python.util.tf_export import tf_export - -_BIAS_VARIABLE_NAME = "bias" -_WEIGHTS_VARIABLE_NAME = "kernel" - -# This can be used with self.assertRaisesRegexp for assert_like_rnncell. -ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell" - - -def _hasattr(obj, attr_name): - try: - getattr(obj, attr_name) - except AttributeError: - return False - else: - return True - - -def assert_like_rnncell(cell_name, cell): - """Raises a TypeError if cell is not like an RNNCell. - - NOTE: Do not rely on the error message (in particular in tests) which can be - subject to change to increase readability. Use - ASSERT_LIKE_RNNCELL_ERROR_REGEXP. - - Args: - cell_name: A string to give a meaningful error referencing to the name of - the functionargument. - cell: The object which should behave like an RNNCell. - - Raises: - TypeError: A human-friendly exception. 
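The counter trick in _enumerated_map_structure_up_to() above, where a one-element list serves as a mutable index so the mapped function also receives the leaf position, can be shown without nest in a few lines (enumerated_map is an illustrative name that only handles flat lists).

def enumerated_map(map_fn, leaves):
    ix = [0]
    def enumerated_fn(value):
        result = map_fn(ix[0], value)    # map_fn sees (leaf_index, leaf_value)
        ix[0] += 1
        return result
    return [enumerated_fn(v) for v in leaves]

assert enumerated_map(lambda i, v: (i, v * 10), [5, 6, 7]) == [(0, 50), (1, 60), (2, 70)]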
- """ - conditions = [ - _hasattr(cell, "output_size"), - _hasattr(cell, "state_size"), - _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"), - callable(cell), - ] - errors = [ - "'output_size' property is missing", "'state_size' property is missing", - "either 'zero_state' or 'get_initial_state' method is required", - "is not callable" - ] - - if not all(conditions): - - errors = [error for error, cond in zip(errors, conditions) if not cond] - raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format( - cell_name, cell, ", ".join(errors))) - - -def _concat(prefix, suffix, static=False): - """Concat that enables int, Tensor, or TensorShape values. - - This function takes a size specification, which can be an integer, a - TensorShape, or a Tensor, and converts it into a concatenated Tensor - (if static = False) or a list of integers (if static = True). - - Args: - prefix: The prefix; usually the batch size (and/or time step size). - (TensorShape, int, or Tensor.) - suffix: TensorShape, int, or Tensor. - static: If `True`, return a python list with possibly unknown dimensions. - Otherwise return a `Tensor`. - - Returns: - shape: the concatenation of prefix and suffix. - - Raises: - ValueError: if `suffix` is not a scalar or vector (or TensorShape). - ValueError: if prefix or suffix was `None` and asked for dynamic - Tensors out. - """ - if isinstance(prefix, ops.Tensor): - p = prefix - p_static = tensor_util.constant_value(prefix) - if p.shape.ndims == 0: - p = array_ops.expand_dims(p, 0) - elif p.shape.ndims != 1: - raise ValueError("prefix tensor must be either a scalar or vector, " - "but saw tensor: %s" % p) - else: - p = tensor_shape.as_shape(prefix) - p_static = p.as_list() if p.ndims is not None else None - p = ( - constant_op.constant(p.as_list(), dtype=dtypes.int32) - if p.is_fully_defined() else None) - if isinstance(suffix, ops.Tensor): - s = suffix - s_static = tensor_util.constant_value(suffix) - if s.shape.ndims == 0: - s = array_ops.expand_dims(s, 0) - elif s.shape.ndims != 1: - raise ValueError("suffix tensor must be either a scalar or vector, " - "but saw tensor: %s" % s) - else: - s = tensor_shape.as_shape(suffix) - s_static = s.as_list() if s.ndims is not None else None - s = ( - constant_op.constant(s.as_list(), dtype=dtypes.int32) - if s.is_fully_defined() else None) - - if static: - shape = tensor_shape.as_shape(p_static).concatenate(s_static) - shape = shape.as_list() if shape.ndims is not None else None - else: - if p is None or s is None: - raise ValueError("Provided a prefix or suffix of None: %s and %s" % - (prefix, suffix)) - shape = array_ops.concat((p, s), 0) - return shape - - -def _zero_state_tensors(state_size, batch_size, dtype): - """Create tensors of zeros based on state_size, batch_size, and dtype.""" - - def get_state_shape(s): - """Combine s with batch_size to get a proper tensor shape.""" - c = _concat(batch_size, s) - size = array_ops.zeros(c, dtype=dtype) - if not context.executing_eagerly(): - c_static = _concat(batch_size, s, static=True) - size.set_shape(c_static) - return size - - return nest.map_structure(get_state_shape, state_size) - - -@tf_export(v1=["nn.rnn_cell.RNNCell"]) -class RNNCell(base_layer.Layer): - """Abstract object representing an RNN cell. - - Every `RNNCell` must have the properties below and implement `call` with - the signature `(output, next_state) = call(input, state)`. 
The optional - third input argument, `scope`, is allowed for backwards compatibility - purposes; but should be left off for new subclasses. - - This definition of cell differs from the definition used in the literature. - In the literature, 'cell' refers to an object with a single scalar output. - This definition refers to a horizontal array of such units. - - An RNN cell, in the most abstract setting, is anything that has - a state and performs some operation that takes a matrix of inputs. - This operation results in an output matrix with `self.output_size` columns. - If `self.state_size` is an integer, this operation also results in a new - state matrix with `self.state_size` columns. If `self.state_size` is a - (possibly nested tuple of) TensorShape object(s), then it should return a - matching structure of Tensors having shape `[batch_size].concatenate(s)` - for each `s` in `self.batch_size`. - """ - - def __init__(self, trainable=True, name=None, dtype=None, **kwargs): - super(RNNCell, self).__init__( - trainable=trainable, name=name, dtype=dtype, **kwargs) - # Attribute that indicates whether the cell is a TF RNN cell, due the slight - # difference between TF and Keras RNN cell. Notably the state is not wrapped - # in a list for TF cell where they are single tensor state, whereas keras - # cell will wrap the state into a list, and call() will have to unwrap them. - self._is_tf_rnn_cell = True - - def __call__(self, inputs, state, scope=None): - """Run this RNN cell on inputs, starting from the given state. - - Args: - inputs: `2-D` tensor with shape `[batch_size, input_size]`. - state: if `self.state_size` is an integer, this should be a `2-D Tensor` - with shape `[batch_size, self.state_size]`. Otherwise, if - `self.state_size` is a tuple of integers, this should be a tuple with - shapes `[batch_size, s] for s in self.state_size`. - scope: VariableScope for the created subgraph; defaults to class name. - - Returns: - A pair containing: - - - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. - - New state: Either a single `2-D` tensor, or a tuple of tensors matching - the arity and shapes of `state`. - """ - if scope is not None: - with vs.variable_scope( - scope, custom_getter=self._rnn_get_variable) as scope: - return super(RNNCell, self).__call__(inputs, state, scope=scope) - else: - scope_attrname = "rnncell_scope" - scope = getattr(self, scope_attrname, None) - if scope is None: - scope = vs.variable_scope( - vs.get_variable_scope(), custom_getter=self._rnn_get_variable) - setattr(self, scope_attrname, scope) - with scope: - return super(RNNCell, self).__call__(inputs, state) - - def _rnn_get_variable(self, getter, *args, **kwargs): - variable = getter(*args, **kwargs) - if context.executing_eagerly(): - trainable = variable._trainable # pylint: disable=protected-access - else: - trainable = ( - variable in tf_variables.trainable_variables() or - (isinstance(variable, tf_variables.PartitionedVariable) and - list(variable)[0] in tf_variables.trainable_variables())) - if trainable and all(variable is not v for v in self._trainable_weights): - self._trainable_weights.append(variable) - elif not trainable and all( - variable is not v for v in self._non_trainable_weights): - self._non_trainable_weights.append(variable) - return variable - - @property - def state_size(self): - """size(s) of state(s) used by this cell. - - It can be represented by an Integer, a TensorShape or a tuple of Integers - or TensorShapes. 
- """ - raise NotImplementedError("Abstract method") - - @property - def output_size(self): - """Integer or TensorShape: size of outputs produced by this cell.""" - raise NotImplementedError("Abstract method") - - def build(self, _): - # This tells the parent Layer object that it's OK to call - # self.add_variable() inside the call() method. - pass - - def get_initial_state(self, inputs=None, batch_size=None, dtype=None): - if inputs is not None: - # Validate the given batch_size and dtype against inputs if provided. - inputs = ops.convert_to_tensor(inputs, name="inputs") - if batch_size is not None: - if tensor_util.is_tensor(batch_size): - static_batch_size = tensor_util.constant_value( - batch_size, partial=True) - else: - static_batch_size = batch_size - if inputs.shape.dims[0].value != static_batch_size: - raise ValueError( - "batch size from input tensor is different from the " - "input param. Input tensor batch: {}, batch_size: {}".format( - inputs.shape.dims[0].value, batch_size)) - - if dtype is not None and inputs.dtype != dtype: - raise ValueError( - "dtype from input tensor is different from the " - "input param. Input tensor dtype: {}, dtype: {}".format( - inputs.dtype, dtype)) - - batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0] - dtype = inputs.dtype - if batch_size is None or dtype is None: - raise ValueError( - "batch_size and dtype cannot be None while constructing initial " - "state: batch_size={}, dtype={}".format(batch_size, dtype)) - return self.zero_state(batch_size, dtype) - - def zero_state(self, batch_size, dtype): - """Return zero-filled state tensor(s). - - Args: - batch_size: int, float, or unit Tensor representing the batch size. - dtype: the data type to use for the state. - - Returns: - If `state_size` is an int or TensorShape, then the return value is a - `N-D` tensor of shape `[batch_size, state_size]` filled with zeros. - - If `state_size` is a nested list or tuple, then the return value is - a nested list or tuple (of the same structure) of `2-D` tensors with - the shapes `[batch_size, s]` for each s in `state_size`. - """ - # Try to use the last cached zero_state. This is done to avoid recreating - # zeros, especially when eager execution is enabled. - state_size = self.state_size - is_eager = context.executing_eagerly() - if is_eager and _hasattr(self, "_last_zero_state"): - (last_state_size, last_batch_size, last_dtype, - last_output) = getattr(self, "_last_zero_state") - if (last_batch_size == batch_size and last_dtype == dtype and - last_state_size == state_size): - return last_output - with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): - output = _zero_state_tensors(state_size, batch_size, dtype) - if is_eager: - self._last_zero_state = (state_size, batch_size, dtype, output) - return output - - # TODO(b/134773139): Remove when contrib RNN cells implement `get_config` - def get_config(self): # pylint: disable=useless-super-delegation - return super(RNNCell, self).get_config() - - -class LayerRNNCell(RNNCell): - """Subclass of RNNCells that act like proper `tf.Layer` objects. - - For backwards compatibility purposes, most `RNNCell` instances allow their - `call` methods to instantiate variables via `tf.compat.v1.get_variable`. The - underlying - variable scope thus keeps track of any variables, and returning cached - versions. This is atypical of `tf.layer` objects, which separate this - part of layer building into a `build` method that is only called once. 
- - Here we provide a subclass for `RNNCell` objects that act exactly as - `Layer` objects do. They must provide a `build` method and their - `call` methods do not access Variables `tf.compat.v1.get_variable`. - """ - - def __call__(self, inputs, state, scope=None, *args, **kwargs): - """Run this RNN cell on inputs, starting from the given state. - - Args: - inputs: `2-D` tensor with shape `[batch_size, input_size]`. - state: if `self.state_size` is an integer, this should be a `2-D Tensor` - with shape `[batch_size, self.state_size]`. Otherwise, if - `self.state_size` is a tuple of integers, this should be a tuple with - shapes `[batch_size, s] for s in self.state_size`. - scope: optional cell scope. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - A pair containing: - - - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. - - New state: Either a single `2-D` tensor, or a tuple of tensors matching - the arity and shapes of `state`. - """ - # Bypass RNNCell's variable capturing semantics for LayerRNNCell. - # Instead, it is up to subclasses to provide a proper build - # method. See the class docstring for more details. - return base_layer.Layer.__call__( - self, inputs, state, scope=scope, *args, **kwargs) - - -@tf_export(v1=["nn.rnn_cell.BasicRNNCell"]) -class BasicRNNCell(LayerRNNCell): - """The most basic RNN cell. - - Note that this cell is not optimized for performance. Please use - `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU. - - Args: - num_units: int, The number of units in the RNN cell. - activation: Nonlinearity to use. Default: `tanh`. It could also be string - that is within Keras activation function names. - reuse: (optional) Python boolean describing whether to reuse variables in an - existing scope. If not `True`, and the existing scope already has the - given variables, an error is raised. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). - """ - - @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell," - " and will be replaced by that in Tensorflow 2.0.") - def __init__(self, - num_units, - activation=None, - reuse=None, - name=None, - dtype=None, - **kwargs): - super(BasicRNNCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - _check_supported_dtypes(self.dtype) - if context.executing_eagerly() and context.num_gpus() > 0: - logging.warn( - "%s: Note that this cell is not optimized for performance. " - "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better " - "performance on GPU.", self) - - # Inputs must be 2-dimensional. 
- self.input_spec = input_spec.InputSpec(ndim=2) - - self._num_units = num_units - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - - @property - def state_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - str(inputs_shape)) - _check_supported_dtypes(self.dtype) - - input_depth = inputs_shape[-1] - self._kernel = self.add_variable( - _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + self._num_units, self._num_units]) - self._bias = self.add_variable( - _BIAS_VARIABLE_NAME, - shape=[self._num_units], - initializer=init_ops.zeros_initializer(dtype=self.dtype)) - - self.built = True - - def call(self, inputs, state): - """Most basic RNN: output = new_state = act(W * input + U * state + B).""" - _check_rnn_cell_input_dtypes([inputs, state]) - gate_inputs = math_ops.matmul( - array_ops.concat([inputs, state], 1), self._kernel) - gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) - output = self._activation(gate_inputs) - return output, output - - def get_config(self): - config = { - "num_units": self._num_units, - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(BasicRNNCell, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -@tf_export(v1=["nn.rnn_cell.GRUCell"]) -class GRUCell(LayerRNNCell): - """Gated Recurrent Unit cell. - - Note that this cell is not optimized for performance. Please use - `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or - `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU. - - Args: - num_units: int, The number of units in the GRU cell. - activation: Nonlinearity to use. Default: `tanh`. - reuse: (optional) Python boolean describing whether to reuse variables in an - existing scope. If not `True`, and the existing scope already has the - given variables, an error is raised. - kernel_initializer: (optional) The initializer to use for the weight and - projection matrices. - bias_initializer: (optional) The initializer to use for the bias. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). 
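# [Editor's note] A short, hypothetical construction of the GRU cell using the
# arguments listed above; the initializers and sizes are examples, not defaults.
import tensorflow.compat.v1 as tf1

gru = tf1.nn.rnn_cell.GRUCell(
    num_units=128,
    kernel_initializer=tf1.glorot_uniform_initializer(),
    bias_initializer=tf1.zeros_initializer(),
    name="gru_cell")
# For this cell, gru.state_size == gru.output_size == 128.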
- - References: - Learning Phrase Representations using RNN Encoder Decoder for Statistical - Machine Translation: - [Cho et al., 2014] - (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179) - ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf)) - """ - - @deprecated(None, "This class is equivalent as tf.keras.layers.GRUCell," - " and will be replaced by that in Tensorflow 2.0.") - def __init__(self, - num_units, - activation=None, - reuse=None, - kernel_initializer=None, - bias_initializer=None, - name=None, - dtype=None, - **kwargs): - super(GRUCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - _check_supported_dtypes(self.dtype) - - if context.executing_eagerly() and context.num_gpus() > 0: - logging.warn( - "%s: Note that this cell is not optimized for performance. " - "Please use tf.contrib.cudnn_rnn.CudnnGRU for better " - "performance on GPU.", self) - # Inputs must be 2-dimensional. - self.input_spec = input_spec.InputSpec(ndim=2) - - self._num_units = num_units - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - self._kernel_initializer = initializers.get(kernel_initializer) - self._bias_initializer = initializers.get(bias_initializer) - - @property - def state_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - str(inputs_shape)) - _check_supported_dtypes(self.dtype) - input_depth = inputs_shape[-1] - self._gate_kernel = self.add_variable( - "gates/%s" % _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + self._num_units, 2 * self._num_units], - initializer=self._kernel_initializer) - self._gate_bias = self.add_variable( - "gates/%s" % _BIAS_VARIABLE_NAME, - shape=[2 * self._num_units], - initializer=(self._bias_initializer - if self._bias_initializer is not None else - init_ops.constant_initializer(1.0, dtype=self.dtype))) - self._candidate_kernel = self.add_variable( - "candidate/%s" % _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + self._num_units, self._num_units], - initializer=self._kernel_initializer) - self._candidate_bias = self.add_variable( - "candidate/%s" % _BIAS_VARIABLE_NAME, - shape=[self._num_units], - initializer=(self._bias_initializer - if self._bias_initializer is not None else - init_ops.zeros_initializer(dtype=self.dtype))) - - self.built = True - - def call(self, inputs, state): - """Gated recurrent unit (GRU) with nunits cells.""" - _check_rnn_cell_input_dtypes([inputs, state]) - - gate_inputs = math_ops.matmul( - array_ops.concat([inputs, state], 1), self._gate_kernel) - gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) - - value = math_ops.sigmoid(gate_inputs) - r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) - - r_state = r * state - - candidate = math_ops.matmul( - array_ops.concat([inputs, r_state], 1), self._candidate_kernel) - candidate = nn_ops.bias_add(candidate, self._candidate_bias) - - c = self._activation(candidate) - new_h = u * state + (1 - u) * c - return new_h, new_h - - def get_config(self): - config = { - "num_units": self._num_units, - "kernel_initializer": initializers.serialize(self._kernel_initializer), - "bias_initializer": initializers.serialize(self._bias_initializer), - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(GRUCell, 
self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) - - -@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"]) -class LSTMStateTuple(_LSTMStateTuple): - """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. - - Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state - and `h` is the output. - - Only used when `state_is_tuple=True`. - """ - __slots__ = () - - @property - def dtype(self): - (c, h) = self - if c.dtype != h.dtype: - raise TypeError("Inconsistent internal state: %s vs %s" % - (str(c.dtype), str(h.dtype))) - return c.dtype - - -@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"]) -class BasicLSTMCell(LayerRNNCell): - """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead. - - Basic LSTM recurrent network cell. - - The implementation is based on - - We add forget_bias (default: 1) to the biases of the forget gate in order to - reduce the scale of forgetting in the beginning of the training. - - It does not allow cell clipping, a projection layer, and does not - use peep-hole connections: it is the basic baseline. - - For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell` - that follows. - - Note that this cell is not optimized for performance. Please use - `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or - `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for - better performance on CPU. - """ - - @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," - " and will be replaced by that in Tensorflow 2.0.") - def __init__(self, - num_units, - forget_bias=1.0, - state_is_tuple=True, - activation=None, - reuse=None, - name=None, - dtype=None, - **kwargs): - """Initialize the basic LSTM cell. - - Args: - num_units: int, The number of units in the LSTM cell. - forget_bias: float, The bias added to forget gates (see above). Must set - to `0.0` manually when restoring from CudnnLSTM-trained checkpoints. - state_is_tuple: If True, accepted and returned states are 2-tuples of the - `c_state` and `m_state`. If False, they are concatenated along the - column axis. The latter behavior will soon be deprecated. - activation: Activation function of the inner states. Default: `tanh`. It - could also be string that is within Keras activation function names. - reuse: (optional) Python boolean describing whether to reuse variables in - an existing scope. If not `True`, and the existing scope already has - the given variables, an error is raised. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). - When restoring from CudnnLSTM-trained checkpoints, must use - `CudnnCompatibleLSTMCell` instead. - """ - super(BasicLSTMCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - _check_supported_dtypes(self.dtype) - if not state_is_tuple: - logging.warn( - "%s: Using a concatenated state is slower and will soon be " - "deprecated. 
Use state_is_tuple=True.", self) - if context.executing_eagerly() and context.num_gpus() > 0: - logging.warn( - "%s: Note that this cell is not optimized for performance. " - "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " - "performance on GPU.", self) - - # Inputs must be 2-dimensional. - self.input_spec = input_spec.InputSpec(ndim=2) - - self._num_units = num_units - self._forget_bias = forget_bias - self._state_is_tuple = state_is_tuple - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - - @property - def state_size(self): - return (LSTMStateTuple(self._num_units, self._num_units) - if self._state_is_tuple else 2 * self._num_units) - - @property - def output_size(self): - return self._num_units - - @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - str(inputs_shape)) - _check_supported_dtypes(self.dtype) - input_depth = inputs_shape[-1] - h_depth = self._num_units - self._kernel = self.add_variable( - _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + h_depth, 4 * self._num_units]) - self._bias = self.add_variable( - _BIAS_VARIABLE_NAME, - shape=[4 * self._num_units], - initializer=init_ops.zeros_initializer(dtype=self.dtype)) - - self.built = True - - def call(self, inputs, state): - """Long short-term memory cell (LSTM). - - Args: - inputs: `2-D` tensor with shape `[batch_size, input_size]`. - state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size, - num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a - `Tensor` shaped `[batch_size, 2 * num_units]`. - - Returns: - A pair containing the new hidden state, and the new state (either a - `LSTMStateTuple` or a concatenated state, depending on - `state_is_tuple`). - """ - _check_rnn_cell_input_dtypes([inputs, state]) - - sigmoid = math_ops.sigmoid - one = constant_op.constant(1, dtype=dtypes.int32) - # Parameters of gates are concatenated into one multiply for efficiency. - if self._state_is_tuple: - c, h = state - else: - c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) - - gate_inputs = math_ops.matmul( - array_ops.concat([inputs, h], 1), self._kernel) - gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = array_ops.split( - value=gate_inputs, num_or_size_splits=4, axis=one) - - forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) - # Note that using `add` and `multiply` instead of `+` and `*` gives a - # performance improvement. So using those at the cost of readability. 
- add = math_ops.add - multiply = math_ops.multiply - new_c = add( - multiply(c, sigmoid(add(f, forget_bias_tensor))), - multiply(sigmoid(i), self._activation(j))) - new_h = multiply(self._activation(new_c), sigmoid(o)) - - if self._state_is_tuple: - new_state = LSTMStateTuple(new_c, new_h) - else: - new_state = array_ops.concat([new_c, new_h], 1) - return new_h, new_state - - def get_config(self): - config = { - "num_units": self._num_units, - "forget_bias": self._forget_bias, - "state_is_tuple": self._state_is_tuple, - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(BasicLSTMCell, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -@tf_export(v1=["nn.rnn_cell.LSTMCell"]) -class LSTMCell(LayerRNNCell): - """Long short-term memory unit (LSTM) recurrent network cell. - - The default non-peephole implementation is based on (Gers et al., 1999). - The peephole implementation is based on (Sak et al., 2014). - - The class uses optional peep-hole connections, optional cell clipping, and - an optional projection layer. - - Note that this cell is not optimized for performance. Please use - `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or - `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for - better performance on CPU. - References: - Long short-term memory recurrent neural network architectures for large - scale acoustic modeling: - [Sak et al., 2014] - (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html) - ([pdf] - (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf)) - Learning to forget: - [Gers et al., 1999] - (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218) - ([pdf](https://arxiv.org/pdf/1409.2329.pdf)) - Long Short-Term Memory: - [Hochreiter et al., 1997] - (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735) - ([pdf](http://ml.jku.at/publications/older/3504.pdf)) - """ - - @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," - " and will be replaced by that in Tensorflow 2.0.") - def __init__(self, - num_units, - use_peepholes=False, - cell_clip=None, - initializer=None, - num_proj=None, - proj_clip=None, - num_unit_shards=None, - num_proj_shards=None, - forget_bias=1.0, - state_is_tuple=True, - activation=None, - reuse=None, - name=None, - dtype=None, - **kwargs): - """Initialize the parameters for an LSTM cell. - - Args: - num_units: int, The number of units in the LSTM cell. - use_peepholes: bool, set True to enable diagonal/peephole connections. - cell_clip: (optional) A float value, if provided the cell state is clipped - by this value prior to the cell output activation. - initializer: (optional) The initializer to use for the weight and - projection matrices. - num_proj: (optional) int, The output dimensionality for the projection - matrices. If None, no projection is performed. - proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is - provided, then the projected values are clipped elementwise to within - `[-proj_clip, proj_clip]`. - num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a - variable_scope partitioner instead. - num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a - variable_scope partitioner instead. - forget_bias: Biases of the forget gate are initialized by default to 1 in - order to reduce the scale of forgetting at the beginning of the - training. 
Must set it manually to `0.0` when restoring from CudnnLSTM - trained checkpoints. - state_is_tuple: If True, accepted and returned states are 2-tuples of the - `c_state` and `m_state`. If False, they are concatenated along the - column axis. This latter behavior will soon be deprecated. - activation: Activation function of the inner states. Default: `tanh`. It - could also be string that is within Keras activation function names. - reuse: (optional) Python boolean describing whether to reuse variables in - an existing scope. If not `True`, and the existing scope already has - the given variables, an error is raised. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). - When restoring from CudnnLSTM-trained checkpoints, use - `CudnnCompatibleLSTMCell` instead. - """ - super(LSTMCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - _check_supported_dtypes(self.dtype) - if not state_is_tuple: - logging.warn( - "%s: Using a concatenated state is slower and will soon be " - "deprecated. Use state_is_tuple=True.", self) - if num_unit_shards is not None or num_proj_shards is not None: - logging.warn( - "%s: The num_unit_shards and proj_unit_shards parameters are " - "deprecated and will be removed in Jan 2017. " - "Use a variable scope with a partitioner instead.", self) - if context.executing_eagerly() and context.num_gpus() > 0: - logging.warn( - "%s: Note that this cell is not optimized for performance. " - "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " - "performance on GPU.", self) - - # Inputs must be 2-dimensional. 
- self.input_spec = input_spec.InputSpec(ndim=2) - - self._num_units = num_units - self._use_peepholes = use_peepholes - self._cell_clip = cell_clip - self._initializer = initializers.get(initializer) - self._num_proj = num_proj - self._proj_clip = proj_clip - self._num_unit_shards = num_unit_shards - self._num_proj_shards = num_proj_shards - self._forget_bias = forget_bias - self._state_is_tuple = state_is_tuple - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - - if num_proj: - self._state_size = ( - LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units + - num_proj) - self._output_size = num_proj - else: - self._state_size = ( - LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 * - num_units) - self._output_size = num_units - - @property - def state_size(self): - return self._state_size - - @property - def output_size(self): - return self._output_size - - @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - str(inputs_shape)) - _check_supported_dtypes(self.dtype) - input_depth = inputs_shape[-1] - h_depth = self._num_units if self._num_proj is None else self._num_proj - maybe_partitioner = ( - partitioned_variables.fixed_size_partitioner(self._num_unit_shards) - if self._num_unit_shards is not None else None) - self._kernel = self.add_variable( - _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + h_depth, 4 * self._num_units], - initializer=self._initializer, - partitioner=maybe_partitioner) - if self.dtype is None: - initializer = init_ops.zeros_initializer - else: - initializer = init_ops.zeros_initializer(dtype=self.dtype) - self._bias = self.add_variable( - _BIAS_VARIABLE_NAME, - shape=[4 * self._num_units], - initializer=initializer) - if self._use_peepholes: - self._w_f_diag = self.add_variable( - "w_f_diag", shape=[self._num_units], initializer=self._initializer) - self._w_i_diag = self.add_variable( - "w_i_diag", shape=[self._num_units], initializer=self._initializer) - self._w_o_diag = self.add_variable( - "w_o_diag", shape=[self._num_units], initializer=self._initializer) - - if self._num_proj is not None: - maybe_proj_partitioner = ( - partitioned_variables.fixed_size_partitioner(self._num_proj_shards) - if self._num_proj_shards is not None else None) - self._proj_kernel = self.add_variable( - "projection/%s" % _WEIGHTS_VARIABLE_NAME, - shape=[self._num_units, self._num_proj], - initializer=self._initializer, - partitioner=maybe_proj_partitioner) - - self.built = True - - def call(self, inputs, state): - """Run one step of LSTM. - - Args: - inputs: input Tensor, must be 2-D, `[batch, input_size]`. - state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, - [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple - of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. - - Returns: - A tuple containing: - - - A `2-D, [batch, output_dim]`, Tensor representing the output of the - LSTM after reading `inputs` when previous state was `state`. - Here output_dim is: - num_proj if num_proj was set, - num_units otherwise. - - Tensor(s) representing the new state of LSTM after reading `inputs` when - the previous state was `state`. Same type and shape(s) as `state`. - - Raises: - ValueError: If input size cannot be inferred from inputs via - static shape inference. 
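# [Editor's note] An illustrative sketch of the peephole/clipping/projection
# options documented above, assuming TF 1.x graph mode; all sizes are made up.
import tensorflow.compat.v1 as tf1
tf1.disable_eager_execution()

inputs = tf1.placeholder(tf1.float32, [None, 20, 10])   # [batch, time, input_size]
cell = tf1.nn.rnn_cell.LSTMCell(
    num_units=256, use_peepholes=True, cell_clip=3.0, num_proj=128)
outputs, final_state = tf1.nn.dynamic_rnn(cell, inputs, dtype=tf1.float32)
# outputs: [batch, 20, 128] (num_proj); final_state.c: [batch, 256], final_state.h: [batch, 128]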
- """ - _check_rnn_cell_input_dtypes([inputs, state]) - - num_proj = self._num_units if self._num_proj is None else self._num_proj - sigmoid = math_ops.sigmoid - - if self._state_is_tuple: - (c_prev, m_prev) = state - else: - c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) - m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) - - input_size = inputs.get_shape().with_rank(2).dims[1].value - if input_size is None: - raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - lstm_matrix = math_ops.matmul( - array_ops.concat([inputs, m_prev], 1), self._kernel) - lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias) - - i, j, f, o = array_ops.split( - value=lstm_matrix, num_or_size_splits=4, axis=1) - # Diagonal connections - if self._use_peepholes: - c = ( - sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + - sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) - else: - c = ( - sigmoid(f + self._forget_bias) * c_prev + - sigmoid(i) * self._activation(j)) - - if self._cell_clip is not None: - # pylint: disable=invalid-unary-operand-type - c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) - # pylint: enable=invalid-unary-operand-type - if self._use_peepholes: - m = sigmoid(o + self._w_o_diag * c) * self._activation(c) - else: - m = sigmoid(o) * self._activation(c) - - if self._num_proj is not None: - m = math_ops.matmul(m, self._proj_kernel) - - if self._proj_clip is not None: - # pylint: disable=invalid-unary-operand-type - m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) - # pylint: enable=invalid-unary-operand-type - - new_state = ( - LSTMStateTuple(c, m) - if self._state_is_tuple else array_ops.concat([c, m], 1)) - return m, new_state - - def get_config(self): - config = { - "num_units": self._num_units, - "use_peepholes": self._use_peepholes, - "cell_clip": self._cell_clip, - "initializer": initializers.serialize(self._initializer), - "num_proj": self._num_proj, - "proj_clip": self._proj_clip, - "num_unit_shards": self._num_unit_shards, - "num_proj_shards": self._num_proj_shards, - "forget_bias": self._forget_bias, - "state_is_tuple": self._state_is_tuple, - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(LSTMCell, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -class _RNNCellWrapperV1(RNNCell): - """Base class for cells wrappers V1 compatibility. - - This class along with `_RNNCellWrapperV2` allows to define cells wrappers that - are compatible with V1 and V2, and defines helper methods for this purpose. - """ - - def __init__(self, cell, *args, **kwargs): - super(_RNNCellWrapperV1, self).__init__(*args, **kwargs) - assert_like_rnncell("cell", cell) - self.cell = cell - if isinstance(cell, trackable.Trackable): - self._track_trackable(self.cell, name="cell") - - def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): - """Calls the wrapped cell and performs the wrapping logic. - - This method is called from the wrapper's `call` or `__call__` methods. - - Args: - inputs: A tensor with wrapped cell's input. - state: A tensor or tuple of tensors with wrapped cell's state. - cell_call_fn: Wrapped cell's method to use for step computation (cell's - `__call__` or 'call' method). - **kwargs: Additional arguments. - - Returns: - A pair containing: - - Output: A tensor with cell's output. 
- - New state: A tensor or tuple of tensors with new wrapped cell's state. - """ - raise NotImplementedError - - def __call__(self, inputs, state, scope=None): - """Runs the RNN cell step computation. - - We assume that the wrapped RNNCell is being built within its `__call__` - method. We directly use the wrapped cell's `__call__` in the overridden - wrapper `__call__` method. - - This allows to use the wrapped cell and the non-wrapped cell equivalently - when using `__call__`. - - Args: - inputs: A tensor with wrapped cell's input. - state: A tensor or tuple of tensors with wrapped cell's state. - scope: VariableScope for the subgraph created in the wrapped cells' - `__call__`. - - Returns: - A pair containing: - - - Output: A tensor with cell's output. - - New state: A tensor or tuple of tensors with new wrapped cell's state. - """ - return self._call_wrapped_cell( - inputs, state, cell_call_fn=self.cell.__call__, scope=scope) - - def get_config(self): - config = { - "cell": { - "class_name": self.cell.__class__.__name__, - "config": self.cell.get_config() - }, - } - base_config = super(_RNNCellWrapperV1, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - @classmethod - def from_config(cls, config, custom_objects=None): - config = config.copy() - cell = config.pop("cell") - try: - assert_like_rnncell("cell", cell) - return cls(cell, **config) - except TypeError: - raise ValueError("RNNCellWrapper cannot reconstruct the wrapped cell. " - "Please overwrite the cell in the config with a RNNCell " - "instance.") - - -@tf_export(v1=["nn.rnn_cell.DropoutWrapper"]) -class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase, - _RNNCellWrapperV1): - """Operator adding dropout to inputs and outputs of the given cell.""" - - def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation - super(DropoutWrapper, self).__init__(*args, **kwargs) - - __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__ - - -@tf_export(v1=["nn.rnn_cell.ResidualWrapper"]) -class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase, - _RNNCellWrapperV1): - """RNNCell wrapper that ensures cell inputs are added to the outputs.""" - - def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation - super(ResidualWrapper, self).__init__(*args, **kwargs) - - __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__ - - -@tf_export(v1=["nn.rnn_cell.DeviceWrapper"]) -class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase, - _RNNCellWrapperV1): - - def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation - super(DeviceWrapper, self).__init__(*args, **kwargs) - - __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__ - - -@tf_export(v1=["nn.rnn_cell.MultiRNNCell"]) -class MultiRNNCell(RNNCell): - """RNN cell composed sequentially of multiple simple cells. - - Example: - - ```python - num_units = [128, 64] - cells = [BasicLSTMCell(num_units=n) for n in num_units] - stacked_rnn_cell = MultiRNNCell(cells) - ``` - """ - - @deprecated(None, "This class is equivalent as " - "tf.keras.layers.StackedRNNCells, and will be replaced by " - "that in Tensorflow 2.0.") - def __init__(self, cells, state_is_tuple=True): - """Create a RNN cell composed sequentially of a number of RNNCells. - - Args: - cells: list of RNNCells that will be composed in this order. - state_is_tuple: If True, accepted and returned states are n-tuples, where - `n = len(cells)`. 
If False, the states are all concatenated along the - column axis. This latter behavior will soon be deprecated. - - Raises: - ValueError: if cells is empty (not allowed), or at least one of the cells - returns a state tuple but the flag `state_is_tuple` is `False`. - """ - super(MultiRNNCell, self).__init__() - if not cells: - raise ValueError("Must specify at least one cell for MultiRNNCell.") - if not nest.is_sequence(cells): - raise TypeError("cells must be a list or tuple, but saw: %s." % cells) - - if len(set(id(cell) for cell in cells)) < len(cells): - logging.log_first_n( - logging.WARN, "At least two cells provided to MultiRNNCell " - "are the same object and will share weights.", 1) - - self._cells = cells - for cell_number, cell in enumerate(self._cells): - # Add Trackable dependencies on these cells so their variables get - # saved with this object when using object-based saving. - if isinstance(cell, trackable.Trackable): - # TODO(allenl): Track down non-Trackable callers. - self._track_trackable(cell, name="cell-%d" % (cell_number,)) - self._state_is_tuple = state_is_tuple - if not state_is_tuple: - if any(nest.is_sequence(c.state_size) for c in self._cells): - raise ValueError("Some cells return tuples of states, but the flag " - "state_is_tuple is not set. State sizes are: %s" % - str([c.state_size for c in self._cells])) - - @property - def state_size(self): - if self._state_is_tuple: - return tuple(cell.state_size for cell in self._cells) - else: - return sum(cell.state_size for cell in self._cells) - - @property - def output_size(self): - return self._cells[-1].output_size - - def zero_state(self, batch_size, dtype): - with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): - if self._state_is_tuple: - return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells) - else: - # We know here that state_size of each cell is not a tuple and - # presumably does not contain TensorArrays or anything else fancy - return super(MultiRNNCell, self).zero_state(batch_size, dtype) - - @property - def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for cell in self._cells: - if isinstance(cell, base_layer.Layer): - weights += cell.trainable_weights - return weights - - @property - def non_trainable_weights(self): - weights = [] - for cell in self._cells: - if isinstance(cell, base_layer.Layer): - weights += cell.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for cell in self._cells: - if isinstance(cell, base_layer.Layer): - trainable_weights += cell.trainable_weights - return trainable_weights + weights - return weights - - def call(self, inputs, state): - """Run this multi-layer cell on inputs, starting from state.""" - cur_state_pos = 0 - cur_inp = inputs - new_states = [] - for i, cell in enumerate(self._cells): - with vs.variable_scope("cell_%d" % i): - if self._state_is_tuple: - if not nest.is_sequence(state): - raise ValueError( - "Expected state to be a tuple of length %d, but received: %s" % - (len(self.state_size), state)) - cur_state = state[i] - else: - cur_state = array_ops.slice(state, [0, cur_state_pos], - [-1, cell.state_size]) - cur_state_pos += cell.state_size - cur_inp, new_state = cell(cur_inp, cur_state) - new_states.append(new_state) - - new_states = ( - tuple(new_states) if self._state_is_tuple else array_ops.concat( - new_states, 1)) - - return cur_inp, new_states - - -def _check_rnn_cell_input_dtypes(inputs): - """Check whether the input tensors are with supported dtypes. 
- - Default RNN cells only support floats and complex as its dtypes since the - activation function (tanh and sigmoid) only allow those types. This function - will throw a proper error message if the inputs is not in a supported type. - - Args: - inputs: tensor or nested structure of tensors that are feed to RNN cell as - input or state. - - Raises: - ValueError: if any of the input tensor are not having dtypes of float or - complex. - """ - for t in nest.flatten(inputs): - _check_supported_dtypes(t.dtype) - - -def _check_supported_dtypes(dtype): - if dtype is None: - return - dtype = dtypes.as_dtype(dtype) - if not (dtype.is_floating or dtype.is_complex): - raise ValueError("RNN cell only supports floating point inputs, " - "but saw dtype: %s" % dtype) +from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl + +# Remove caller that rely on private symbol in future. +# pylint: disable=protected-access +_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME +_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME +_concat = rnn_cell_impl._concat +_zero_state_tensors = rnn_cell_impl._zero_state_tensors +# pylint: disable=protected-access + + +assert_like_rnncell = rnn_cell_impl.assert_like_rnncell +ASSERT_LIKE_RNNCELL_ERROR_REGEXP = rnn_cell_impl.ASSERT_LIKE_RNNCELL_ERROR_REGEXP # pylint: disable=line-too-long +BasicLSTMCell = rnn_cell_impl.BasicLSTMCell +BasicRNNCell = rnn_cell_impl.BasicRNNCell +DeviceWrapper = rnn_cell_impl.DeviceWrapper +DropoutWrapper = rnn_cell_impl.DropoutWrapper +GRUCell = rnn_cell_impl.GRUCell +LayerRNNCell = rnn_cell_impl.LayerRNNCell +LSTMCell = rnn_cell_impl.LSTMCell +LSTMStateTuple = rnn_cell_impl.LSTMStateTuple +MultiRNNCell = rnn_cell_impl.MultiRNNCell +ResidualWrapper = rnn_cell_impl.ResidualWrapper +RNNCell = rnn_cell_impl.RNNCell diff --git a/tensorflow/python/ops/rnn_cell_wrapper_impl.py b/tensorflow/python/ops/rnn_cell_wrapper_impl.py index 9c42caea63e..a8b1fea2a5f 100644 --- a/tensorflow/python/ops/rnn_cell_wrapper_impl.py +++ b/tensorflow/python/ops/rnn_cell_wrapper_impl.py @@ -17,500 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import hashlib -import numbers -import sys -import types as python_types -import warnings +from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.keras.utils import generic_utils -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.util import nest - -class DropoutWrapperBase(object): - """Operator adding dropout to inputs and outputs of the given cell.""" - - def __init__(self, - cell, - input_keep_prob=1.0, - output_keep_prob=1.0, - state_keep_prob=1.0, - variational_recurrent=False, - input_size=None, - dtype=None, - seed=None, - dropout_state_filter_visitor=None, - **kwargs): - """Create a cell with added input, state, and/or output dropout. - - If `variational_recurrent` is set to `True` (**NOT** the default behavior), - then the same dropout mask is applied at every step, as described in: - [A Theoretically Grounded Application of Dropout in Recurrent - Neural Networks. Y. Gal, Z. Ghahramani](https://arxiv.org/abs/1512.05287). 
- - Otherwise a different dropout mask is applied at every time step. - - Note, by default (unless a custom `dropout_state_filter` is provided), - the memory state (`c` component of any `LSTMStateTuple`) passing through - a `DropoutWrapper` is never modified. This behavior is described in the - above article. - - Args: - cell: an RNNCell, a projection to output_size is added to it. - input_keep_prob: unit Tensor or float between 0 and 1, input keep - probability; if it is constant and 1, no input dropout will be added. - output_keep_prob: unit Tensor or float between 0 and 1, output keep - probability; if it is constant and 1, no output dropout will be added. - state_keep_prob: unit Tensor or float between 0 and 1, output keep - probability; if it is constant and 1, no output dropout will be added. - State dropout is performed on the outgoing states of the cell. **Note** - the state components to which dropout is applied when `state_keep_prob` - is in `(0, 1)` are also determined by the argument - `dropout_state_filter_visitor` (e.g. by default dropout is never applied - to the `c` component of an `LSTMStateTuple`). - variational_recurrent: Python bool. If `True`, then the same dropout - pattern is applied across all time steps per run call. If this parameter - is set, `input_size` **must** be provided. - input_size: (optional) (possibly nested tuple of) `TensorShape` objects - containing the depth(s) of the input tensors expected to be passed in to - the `DropoutWrapper`. Required and used **iff** `variational_recurrent - = True` and `input_keep_prob < 1`. - dtype: (optional) The `dtype` of the input, state, and output tensors. - Required and used **iff** `variational_recurrent = True`. - seed: (optional) integer, the randomness seed. - dropout_state_filter_visitor: (optional), default: (see below). Function - that takes any hierarchical level of the state and returns a scalar or - depth=1 structure of Python booleans describing which terms in the state - should be dropped out. In addition, if the function returns `True`, - dropout is applied across this sublevel. If the function returns - `False`, dropout is not applied across this entire sublevel. - Default behavior: perform dropout on all terms except the memory (`c`) - state of `LSTMCellState` objects, and don't try to apply dropout to - `TensorArray` objects: ``` - def dropout_state_filter_visitor(s): - if isinstance(s, LSTMCellState): # Never perform dropout on the c - state. return LSTMCellState(c=False, h=True) - elif isinstance(s, TensorArray): return False return True ``` - **kwargs: dict of keyword arguments for base layer. - - Raises: - TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is provided - but not `callable`. - ValueError: if any of the keep_probs are not between 0 and 1. 
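# [Editor's note] A hedged sketch of wrapping a cell with the dropout options
# described above; when variational_recurrent=True, both dtype and the per-step
# input depth (input_size) must be supplied. Probabilities here are arbitrary.
import tensorflow.compat.v1 as tf1

base_cell = tf1.nn.rnn_cell.LSTMCell(num_units=64)
dropout_cell = tf1.nn.rnn_cell.DropoutWrapper(
    base_cell,
    input_keep_prob=0.9,
    output_keep_prob=0.9,
    state_keep_prob=0.9,           # by default the LSTM `c` state is never dropped
    variational_recurrent=True,    # one dropout mask reused across all time steps
    input_size=tf1.TensorShape([10]),
    dtype=tf1.float32)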
- """ - super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs) - - if (dropout_state_filter_visitor is not None and - not callable(dropout_state_filter_visitor)): - raise TypeError("dropout_state_filter_visitor must be callable") - self._dropout_state_filter = ( - dropout_state_filter_visitor or _default_dropout_state_filter_visitor) - with ops.name_scope("DropoutWrapperInit"): - - def tensor_and_const_value(v): - tensor_value = ops.convert_to_tensor(v) - const_value = tensor_util.constant_value(tensor_value) - return (tensor_value, const_value) - - for prob, attr in [(input_keep_prob, "input_keep_prob"), - (state_keep_prob, "state_keep_prob"), - (output_keep_prob, "output_keep_prob")]: - tensor_prob, const_prob = tensor_and_const_value(prob) - if const_prob is not None: - if const_prob < 0 or const_prob > 1: - raise ValueError("Parameter %s must be between 0 and 1: %d" % - (attr, const_prob)) - setattr(self, "_%s" % attr, float(const_prob)) - else: - setattr(self, "_%s" % attr, tensor_prob) - - # Set variational_recurrent, seed before running the code below - self._variational_recurrent = variational_recurrent - self._input_size = input_size - self._seed = seed - - self._recurrent_input_noise = None - self._recurrent_state_noise = None - self._recurrent_output_noise = None - - if variational_recurrent: - if dtype is None: - raise ValueError( - "When variational_recurrent=True, dtype must be provided") - - def convert_to_batch_shape(s): - # Prepend a 1 for the batch dimension; for recurrent - # variational dropout we use the same dropout mask for all - # batch elements. - return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()), 0) - - def batch_noise(s, inner_seed): - shape = convert_to_batch_shape(s) - return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype) - - if (not isinstance(self._input_keep_prob, numbers.Real) or - self._input_keep_prob < 1.0): - if input_size is None: - raise ValueError( - "When variational_recurrent=True and input_keep_prob < 1.0 or " - "is unknown, input_size must be provided") - self._recurrent_input_noise = _enumerated_map_structure_up_to( - input_size, - lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)), - input_size) - self._recurrent_state_noise = _enumerated_map_structure_up_to( - cell.state_size, - lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)), - cell.state_size) - self._recurrent_output_noise = _enumerated_map_structure_up_to( - cell.output_size, - lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)), - cell.output_size) - - def _gen_seed(self, salt_prefix, index): - if self._seed is None: - return None - salt = "%s_%d" % (salt_prefix, index) - string = (str(self._seed) + salt).encode("utf-8") - return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF - - @property - def wrapped_cell(self): - return self.cell - - @property - def state_size(self): - return self.cell.state_size - - @property - def output_size(self): - return self.cell.output_size - - def build(self, inputs_shape): - self.cell.build(inputs_shape) - self.built = True - - def zero_state(self, batch_size, dtype): - with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): - return self.cell.zero_state(batch_size, dtype) - - def _variational_recurrent_dropout_value( - self, unused_index, value, noise, keep_prob): - """Performs dropout given the pre-calculated noise tensor.""" - # uniform [keep_prob, 1.0 + keep_prob) - random_tensor = keep_prob + noise - - # 0. 
if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) - binary_tensor = math_ops.floor(random_tensor) - ret = math_ops.divide(value, keep_prob) * binary_tensor - ret.set_shape(value.get_shape()) - return ret - - def _dropout(self, - values, - salt_prefix, - recurrent_noise, - keep_prob, - shallow_filtered_substructure=None): - """Decides whether to perform standard dropout or recurrent dropout.""" - - if shallow_filtered_substructure is None: - # Put something so we traverse the entire structure; inside the - # dropout function we check to see if leafs of this are bool or not. - shallow_filtered_substructure = values - - if not self._variational_recurrent: - - def dropout(i, do_dropout, v): - if not isinstance(do_dropout, bool) or do_dropout: - return nn_ops.dropout_v2( - v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i)) - else: - return v - - return _enumerated_map_structure_up_to( - shallow_filtered_substructure, dropout, - *[shallow_filtered_substructure, values]) - else: - - def dropout(i, do_dropout, v, n): - if not isinstance(do_dropout, bool) or do_dropout: - return self._variational_recurrent_dropout_value(i, v, n, keep_prob) - else: - return v - - return _enumerated_map_structure_up_to( - shallow_filtered_substructure, dropout, - *[shallow_filtered_substructure, values, recurrent_noise]) - - def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): - """Runs the wrapped cell and applies dropout. - - Args: - inputs: A tensor with wrapped cell's input. - state: A tensor or tuple of tensors with wrapped cell's state. - cell_call_fn: Wrapped cell's method to use for step computation (cell's - `__call__` or 'call' method). - **kwargs: Additional arguments. - - Returns: - A pair containing: - - - Output: A tensor with cell's output. - - New state: A tensor or tuple of tensors with new wrapped cell's state. - """ - - def _should_dropout(p): - return (not isinstance(p, float)) or p < 1 - - if _should_dropout(self._input_keep_prob): - inputs = self._dropout(inputs, "input", self._recurrent_input_noise, - self._input_keep_prob) - output, new_state = cell_call_fn(inputs, state, **kwargs) - if _should_dropout(self._state_keep_prob): - # Identify which subsets of the state to perform dropout on and - # which ones to keep. 
- shallow_filtered_substructure = nest.get_traverse_shallow_structure( - self._dropout_state_filter, new_state) - new_state = self._dropout(new_state, "state", self._recurrent_state_noise, - self._state_keep_prob, - shallow_filtered_substructure) - if _should_dropout(self._output_keep_prob): - output = self._dropout(output, "output", self._recurrent_output_noise, - self._output_keep_prob) - return output, new_state - - def get_config(self): - """Returns the config of the dropout wrapper.""" - config = { - "input_keep_prob": self._input_keep_prob, - "output_keep_prob": self._output_keep_prob, - "state_keep_prob": self._state_keep_prob, - "variational_recurrent": self._variational_recurrent, - "input_size": self._input_size, - "seed": self._seed, - } - if self._dropout_state_filter != _default_dropout_state_filter_visitor: - function, function_type, function_module = _serialize_function_to_config( - self._dropout_state_filter) - config.update({"dropout_fn": function, - "dropout_fn_type": function_type, - "dropout_fn_module": function_module}) - base_config = super(DropoutWrapperBase, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - @classmethod - def from_config(cls, config, custom_objects=None): - if "dropout_fn" in config: - config = config.copy() - dropout_state_filter = _parse_config_to_function( - config, custom_objects, "dropout_fn", "dropout_fn_type", - "dropout_fn_module") - config.pop("dropout_fn") - config["dropout_state_filter_visitor"] = dropout_state_filter - return super(DropoutWrapperBase, cls).from_config( - config, custom_objects=custom_objects) - - -class ResidualWrapperBase(object): - """RNNCell wrapper that ensures cell inputs are added to the outputs.""" - - def __init__(self, cell, residual_fn=None, **kwargs): - """Constructs a `ResidualWrapper` for `cell`. - - Args: - cell: An instance of `RNNCell`. - residual_fn: (Optional) The function to map raw cell inputs and raw cell - outputs to the actual cell outputs of the residual network. - Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs - and outputs. - **kwargs: dict of keyword arguments for base layer. - """ - super(ResidualWrapperBase, self).__init__(cell, **kwargs) - self._residual_fn = residual_fn - - @property - def state_size(self): - return self.cell.state_size - - @property - def output_size(self): - return self.cell.output_size - - def zero_state(self, batch_size, dtype): - with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): - return self.cell.zero_state(batch_size, dtype) - - def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): - """Run the cell and then apply the residual_fn on its inputs to its outputs. - - Args: - inputs: cell inputs. - state: cell state. - cell_call_fn: Wrapped cell's method to use for step computation (cell's - `__call__` or 'call' method). - **kwargs: Additional arguments passed to the wrapped cell's `call`. - - Returns: - Tuple of cell outputs and new state. - - Raises: - TypeError: If cell inputs and outputs have different structure (type). - ValueError: If cell inputs and outputs have different structure (value). 
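# [Editor's note] A small sketch of the residual behaviour described above; the
# custom residual_fn is optional and purely illustrative (the default adds i + o).
import tensorflow.compat.v1 as tf1

# The wrapped cell's output depth must match the input depth for the addition.
cell = tf1.nn.rnn_cell.GRUCell(num_units=16)
res_cell = tf1.nn.rnn_cell.ResidualWrapper(
    cell, residual_fn=lambda inp, out: inp + out)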
- """ - outputs, new_state = cell_call_fn(inputs, state, **kwargs) - - # Ensure shapes match - def assert_shape_match(inp, out): - inp.get_shape().assert_is_compatible_with(out.get_shape()) - - def default_residual_fn(inputs, outputs): - nest.assert_same_structure(inputs, outputs) - nest.map_structure(assert_shape_match, inputs, outputs) - return nest.map_structure(lambda inp, out: inp + out, inputs, outputs) - - res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs) - return (res_outputs, new_state) - - def get_config(self): - """Returns the config of the residual wrapper.""" - if self._residual_fn is not None: - function, function_type, function_module = _serialize_function_to_config( - self._residual_fn) - config = { - "residual_fn": function, - "residual_fn_type": function_type, - "residual_fn_module": function_module - } - else: - config = {} - base_config = super(ResidualWrapperBase, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - @classmethod - def from_config(cls, config, custom_objects=None): - if "residual_fn" in config: - config = config.copy() - residual_function = _parse_config_to_function(config, custom_objects, - "residual_fn", - "residual_fn_type", - "residual_fn_module") - config["residual_fn"] = residual_function - return super(ResidualWrapperBase, cls).from_config( - config, custom_objects=custom_objects) - - -class DeviceWrapperBase(object): - """Operator that ensures an RNNCell runs on a particular device.""" - - def __init__(self, cell, device, **kwargs): - """Construct a `DeviceWrapper` for `cell` with device `device`. - - Ensures the wrapped `cell` is called with `tf.device(device)`. - - Args: - cell: An instance of `RNNCell`. - device: A device string or function, for passing to `tf.device`. - **kwargs: dict of keyword arguments for base layer. 
- """ - super(DeviceWrapperBase, self).__init__(cell, **kwargs) - self._device = device - - @property - def state_size(self): - return self.cell.state_size - - @property - def output_size(self): - return self.cell.output_size - - def zero_state(self, batch_size, dtype): - with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): - with ops.device(self._device): - return self.cell.zero_state(batch_size, dtype) - - def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): - """Run the cell on specified device.""" - with ops.device(self._device): - return cell_call_fn(inputs, state, **kwargs) - - def get_config(self): - config = {"device": self._device} - base_config = super(DeviceWrapperBase, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -def _serialize_function_to_config(function): - """Serialize the function for get_config().""" - if isinstance(function, python_types.LambdaType): - output = generic_utils.func_dump(function) - output_type = "lambda" - module = function.__module__ - elif callable(function): - output = function.__name__ - output_type = "function" - module = function.__module__ - else: - raise ValueError("Unrecognized function type for input: {}".format( - type(function))) - - return output, output_type, module - - -def _parse_config_to_function(config, custom_objects, func_attr_name, - func_type_attr_name, module_attr_name): - """Reconstruct the function from the config.""" - globs = globals() - module = config.pop(module_attr_name, None) - if module in sys.modules: - globs.update(sys.modules[module].__dict__) - elif module is not None: - # Note: we don't know the name of the function if it's a lambda. - warnings.warn("{} is not loaded, but a layer uses it. " - "It may cause errors.".format(module), UserWarning) - if custom_objects: - globs.update(custom_objects) - function_type = config.pop(func_type_attr_name) - if function_type == "function": - # Simple lookup in custom objects - function = generic_utils.deserialize_keras_object( - config[func_attr_name], - custom_objects=custom_objects, - printable_module_name="function in wrapper") - elif function_type == "lambda": - # Unsafe deserialization from bytecode - function = generic_utils.func_load( - config[func_attr_name], globs=globs) - else: - raise TypeError("Unknown function type:", function_type) - return function - - -def _default_dropout_state_filter_visitor(substate): - from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple # pylint: disable=g-import-not-at-top - if isinstance(substate, LSTMStateTuple): - # Do not perform dropout on the memory state. 
- return LSTMStateTuple(c=False, h=True) - elif isinstance(substate, tensor_array_ops.TensorArray): - return False - return True - - -def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs): - ix = [0] - - def enumerated_fn(*inner_args, **inner_kwargs): - r = map_fn(ix[0], *inner_args, **inner_kwargs) - ix[0] += 1 - return r - - return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args, - **kwargs) +DeviceWrapperBase = rnn_cell_wrapper_impl.DeviceWrapperBase +DropoutWrapperBase = rnn_cell_wrapper_impl.DropoutWrapperBase +ResidualWrapperBase = rnn_cell_wrapper_impl.ResidualWrapperBase diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt index a143468c615..34cc5a20beb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.lite.experimental.nn.TFLiteLSTMCell" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt index fd240a31637..f17c071254e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.lite.experimental.nn.TfLiteRNNCell" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index 02ec119a24a..d337b185c46 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.nn.rnn_cell.BasicLSTMCell" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index 185bfa99489..7269795bde5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.nn.rnn_cell.BasicRNNCell" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 102a2266f5a..1368c1cb603 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -1,9 +1,9 @@ path: 
"tensorflow.nn.rnn_cell.DeviceWrapper" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index bb6bde99e53..6d490621aa9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.nn.rnn_cell.DropoutWrapper" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 832ec6f6be6..a9669ff59a2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.nn.rnn_cell.GRUCell" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index 6f471d3f811..54e517ac974 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.nn.rnn_cell.LSTMCell" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt index 1de8a55dcca..274624a6203 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.LSTMStateTuple" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "c" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index 48d17d35fbe..db31de9d754 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.MultiRNNCell" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index 5c428f658c9..2286a66efd8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -1,6 +1,6 @@ path: 
"tensorflow.nn.rnn_cell.RNNCell" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index 629d73640f3..5570bf7af98 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.nn.rnn_cell.ResidualWrapper" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt index d047b62497c..e94252a2f6d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.RNNCellDeviceWrapper" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt index 7c6eee87429..9682fd0a29a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.RNNCellDropoutWrapper" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt index e4d9d2e9737..0cd0cebfa1d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.RNNCellResidualWrapper" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" From 4441a60ec159987eb624455cd7395532080d461a Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Mon, 23 Mar 2020 14:59:08 -0700 Subject: [PATCH 450/492] [XLA] Disallow send/recv buffers from memory space assignment. Today, the current mechanism is causing two issues: 1- If the send buffer is prefetched to alternate memory, SendDone will have a tuple operand consisting of the token and context gte'd from Send, and the source buffer from copy-done. SendDone expects only an operand of Send. 2- The memory allocation doesn't know the side effecting nature of send and recv where the send/recv buffer needs a stable allocation throughout the sending or receiving. This can cause memory corruption. I will revisit this later to allow send/recv buffers to be placed in alternate memory. 
PiperOrigin-RevId: 302525059 Change-Id: Id8bcc5a9f615e1cd3018e95ae4e4bc1e4723ceb9 --- .../xla/service/memory_space_assignment.cc | 25 ++++++--- .../service/memory_space_assignment_test.cc | 55 +++++++++++++++++++ 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index d30c24616ff..4dc1c5782b6 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -315,14 +315,23 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( // Send and Recv HLOs return a request identifier. These should not be // allocated in the alternate memory. - const HloPosition& defining_position = interval.buffer->defining_position(); - if ((defining_position.instruction->opcode() == HloOpcode::kSend || - defining_position.instruction->opcode() == HloOpcode::kRecv) && - defining_position.index == ShapeIndex({1})) { - VLOG(4) - << "Keeping value " << interval.buffer->ToShortString() - << " in default mem because it is a request identifier for send/recv."; - return false; + for (const HloPosition& position : interval.buffer->positions()) { + if ((position.instruction->opcode() == HloOpcode::kSend || + position.instruction->opcode() == HloOpcode::kRecv)) { + // TODO(berkin): Send/recv buffers need a stable buffer allocation + // throughout sending/receiving. Disable memory space allocation for these + // for now. + if (position.index == ShapeIndex({0})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a send/recv buffer."; + return false; + } else if (position.index == ShapeIndex({1})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a request identifier for " + "send/recv."; + return false; + } + } } return true; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index b4797751162..31967e94c46 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -1321,6 +1321,61 @@ TEST_P(MemorySpaceAssignmentTest, } } +TEST_P(MemorySpaceAssignmentTest, SendDoneShouldHaveSendOperand) { + // Ensure that SendDone has only a Send operand. + absl::string_view hlo_string = R"( + HloModule SendRecv, is_scheduled=true + + ENTRY %AddDependency (p: f32[3]) -> f32[3] { + %p0 = f32[3]{0} parameter(0) + %p1 = f32[3]{0} parameter(1) + %neg0 = f32[3]{0} negate(f32[3]{0} %p1) + %neg1 = f32[3]{0} negate(f32[3]{0} %neg0) + %neg2 = f32[3]{0} negate(f32[3]{0} %neg1) + %neg3 = f32[3]{0} negate(f32[3]{0} %neg2) + %neg4 = f32[3]{0} negate(f32[3]{0} %neg3) + %neg5 = f32[3]{0} negate(f32[3]{0} %neg4) + %neg6 = f32[3]{0} negate(f32[3]{0} %neg5) + %after-all = token[] after-all() + %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2 + %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2 + ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); +} + +TEST_P(MemorySpaceAssignmentTest, SendAndSendDoneShouldGetSameAllocation) { + // Ensure that Send and SendDone have the same allocation. 
+ absl::string_view hlo_string = R"( + HloModule SendRecv, is_scheduled=true + + ENTRY %AddDependency (p: f32[3]) -> f32[3] { + %p0 = f32[3]{0} parameter(0) + %p1 = f32[3]{0} parameter(1) + %after-all = token[] after-all() + %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2 + %neg0 = f32[3]{0} negate(f32[3]{0} %p1) + %neg1 = f32[3]{0} negate(f32[3]{0} %neg0) + %neg2 = f32[3]{0} negate(f32[3]{0} %neg1) + %neg3 = f32[3]{0} negate(f32[3]{0} %neg2) + %neg4 = f32[3]{0} negate(f32[3]{0} %neg3) + %neg5 = f32[3]{0} negate(f32[3]{0} %neg4) + %neg6 = f32[3]{0} negate(f32[3]{0} %neg5) + %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2 + ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/4); +} + TEST_P(MemorySpaceAssignmentTest, LastUseOpt) { // Test that checks the last use optimization. It uses two buffers that should // be placed in alternate memory. From b3dab55405d0b7aa27aeadeb35777f190f1bcc16 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 23 Mar 2020 15:07:45 -0700 Subject: [PATCH 451/492] [tstring] In `SerializeToTString()` use `SerializeWithCachedSizesToArray()`. This avoids the need to calculate the serialized size of the message twice, and matches the behavior of `MessageLite::SerializeToString()`. PiperOrigin-RevId: 302527078 Change-Id: Iacbe663ee0271ba04b98eb544719ff684f82b85f --- tensorflow/core/platform/protobuf.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/platform/protobuf.h b/tensorflow/core/platform/protobuf.h index 2422aacd5f6..d05095dcf55 100644 --- a/tensorflow/core/platform/protobuf.h +++ b/tensorflow/core/platform/protobuf.h @@ -91,7 +91,8 @@ inline bool SerializeToTString(const protobuf::MessageLite& proto, tstring* output) { size_t size = proto.ByteSizeLong(); output->resize_uninitialized(size); - return proto.SerializeToArray(output->data(), static_cast(size)); + return proto.SerializeWithCachedSizesToArray( + reinterpret_cast(output->data())); } inline bool ParseFromTString(const tstring& input, From c0d13627b36a68b6e3622e75da433df35ef40292 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 23 Mar 2020 15:09:42 -0700 Subject: [PATCH 452/492] Cleanup unused/deprecated functions in tf.layers.utils. 
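If any out-of-tree code still relied on the removed get_reachable_from_inputs helper, the same walk can be reproduced in a few lines; a minimal sketch that mirrors the deleted implementation (graph-mode symbolic tensors only, not part of this change):

  def reachable_from_inputs(inputs, targets=None):
    """Breadth-first walk over tensor consumers (sketch of the removed helper)."""
    reachable = set(inputs)
    targets = set(targets) if targets else None
    queue = list(inputs)
    while queue:
      x = queue.pop()
      # Variables do not expose consumers(); fall back to their producing op.
      consumers = x.consumers() if hasattr(x, "consumers") else [x.op]
      for op in consumers:
        for y in op.outputs:
          if y not in reachable:
            reachable.add(y)
            queue.insert(0, y)
      if targets and targets.issubset(reachable):
        break
    return reachable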
PiperOrigin-RevId: 302527461 Change-Id: I25146c03710baa61c96c666564e3e57f8809bd1f --- tensorflow/python/BUILD | 1 - tensorflow/python/layers/utils.py | 63 +------------------------- tensorflow/python/layers/utils_test.py | 22 --------- 3 files changed, 2 insertions(+), 84 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 1669508ac4f..46a77053652 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -7089,7 +7089,6 @@ py_library( deps = [ ":control_flow_ops", ":smart_cond", - ":util", ":variables", ], ) diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index cf06180cd81..8fb5151fadf 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -18,10 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.ops import variables -from tensorflow.python.ops import control_flow_ops from tensorflow.python.framework import smart_cond as smart_module -from tensorflow.python.util import nest +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables def convert_data_format(data_format, ndim): @@ -225,61 +224,3 @@ def constant_value(pred): if isinstance(pred, variables.Variable): return None return smart_module.smart_constant_value(pred) - - -def object_list_uid(object_list): - """Creates a single string from object ids.""" - object_list = nest.flatten(object_list) - return ', '.join(str(abs(id(x))) for x in object_list) - - -def static_shape(x): - """Get the static shape of a Tensor, or None if it is unavailable.""" - if x is None: - return None - try: - return tuple(x.get_shape().as_list()) - except ValueError: - return None - - -def get_reachable_from_inputs(inputs, targets=None): - """Returns the set of tensors reachable from `inputs`. - - Stops if all targets have been found (target is optional). - - Only valid in Symbolic mode, not Eager mode. - - Args: - inputs: List of tensors. - targets: List of tensors. - - Returns: - A set of tensors reachable from the inputs (includes the inputs themselves). 
- """ - reachable = set(inputs) - if targets: - targets = set(targets) - queue = inputs[:] - - while queue: - x = queue.pop() - outputs = [] - try: - consumers = x.consumers() - except AttributeError: - # Case where x is a variable type - consumers = [x.op] - for z in consumers: - consumer_outputs = z.outputs - if consumer_outputs: # May be None - outputs += consumer_outputs - - for y in outputs: - if y not in reachable: - reachable.add(y) - queue.insert(0, y) - - if targets and targets.issubset(reachable): - return reachable - return reachable diff --git a/tensorflow/python/layers/utils_test.py b/tensorflow/python/layers/utils_test.py index a0cd66a1f05..15488c024ee 100644 --- a/tensorflow/python/layers/utils_test.py +++ b/tensorflow/python/layers/utils_test.py @@ -115,27 +115,5 @@ class ConstantValueTest(test.TestCase): utils.constant_value(5) -class GetReachableFromInputsTest(test.TestCase): - - @test_util.run_deprecated_v1 - def testGetReachableFromInputs(self): - - pl_1 = array_ops.placeholder(shape=None, dtype='float32') - pl_2 = array_ops.placeholder(shape=None, dtype='float32') - pl_3 = array_ops.placeholder(shape=None, dtype='float32') - x_1 = pl_1 + pl_2 - x_2 = pl_2 * 2 - x_3 = pl_3 + 1 - x_4 = x_1 + x_2 - x_5 = x_3 * pl_1 - - self.assertEqual({pl_1, x_1, x_4, x_5}, - utils.get_reachable_from_inputs([pl_1])) - self.assertEqual({pl_1, pl_2, x_1, x_2, x_4, x_5}, - utils.get_reachable_from_inputs([pl_1, pl_2])) - self.assertEqual({pl_3, x_3, x_5}, utils.get_reachable_from_inputs([pl_3])) - self.assertEqual({x_3, x_5}, utils.get_reachable_from_inputs([x_3])) - - if __name__ == '__main__': test.main() From 77527d0df17be03f76d2bdc70e0126fbba87caaa Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 23 Mar 2020 15:14:16 -0700 Subject: [PATCH 453/492] Auto-generate all unary and binary TensorFlow ops supported by tf2xla bridge PiperOrigin-RevId: 302528398 Change-Id: I1e1c7f4eafc6e08722a1a32126ea68b263110f69 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 865 ++++++++++++++++++ 1 file changed, 865 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 9feeee87374..10a2b4f9451 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -49,6 +49,47 @@ an output element, this operation computes \\(y = |x|\\). TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes acos of x element-wise."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_AcoshOp : TF_Op<"Acosh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes inverse hyperbolic cosine of x element-wise."; + + let description = [{ +Given an input tensor, the function computes inverse hyperbolic cosine of every element. +Input range is `[1, inf]`. It returns `nan` if the input lies outside the range. + +```python +x = tf.constant([-2, -0.5, 1, 1.2, 200, 10000, float("inf")]) +tf.math.acosh(x) ==> [nan nan 0. 
0.62236255 5.9914584 9.903487 inf] +``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -149,6 +190,41 @@ retained with length 1. let verifier = [{ return Verify(*this); }]; } +def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "Returns the argument of a complex number."; + + let description = [{ +Given a tensor `input` of complex numbers, this operation returns a tensor of +type `float` that is the argument of each element in `input`. All elements in +`input` must be complex numbers of the form \\(a + bj\\), where *a* +is the real part and *b* is the imaginary part. + +The argument returned by this operation is of the form \\(atan2(b, a)\\). + +For example: + +``` +# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +tf.angle(input) ==> [2.0132, 1.056] +``` + +@compatibility(numpy) +Equivalent to np.angle. +@end_compatibility + }]; + + let arguments = (ins + TensorOf<[TF_Complex128, TF_Complex64]>:$input + ); + + let results = (outs + TF_F32OrF64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; +} + def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> { let summary = [{ Computes the "logical or" of elements across dimensions of a tensor. @@ -278,6 +354,63 @@ array([b'3.14', b'2.72'], dtype=object) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_AsinOp : TF_Op<"Asin", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the trignometric inverse sine of x element-wise."; + + let description = [{ +The `tf.math.asin` operation returns the inverse of `tf.math.sin`, such that +if `y = tf.math.sin(x)` then, `x = tf.math.asin(y)`. + +**Note**: The output of `tf.math.asin` will lie within the invertible range +of sine, i.e [-pi/2, pi/2]. + +For example: + +```python +# Note: [1.047, 0.785] ~= [(pi/3), (pi/4)] +x = tf.constant([1.047, 0.785]) +y = tf.math.sin(x) # [0.8659266, 0.7068252] + +tf.math.asin(y) # [1.047, 0.785] = x +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_AsinhOp : TF_Op<"Asinh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes inverse hyperbolic sine of x element-wise."; + + let description = [{ +Given an input tensor, this function computes inverse hyperbolic sine + for every element in the tensor. Both input and output has a range of + `[-inf, inf]`. + + ```python + x = tf.constant([-float("inf"), -2, -0.5, 1, 1.2, 200, 10000, float("inf")]) + tf.math.asinh(x) ==> [-inf -1.4436355 -0.4812118 0.8813736 1.0159732 5.991471 9.903487 inf] + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AssertOp : TF_Op<"Assert", []> { let summary = "Asserts that the given condition is true."; @@ -354,6 +487,38 @@ this value or a subsequent newer value of the variable. 
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<1>; } +def TF_AtanOp : TF_Op<"Atan", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the trignometric inverse tangent of x element-wise."; + + let description = [{ +The `tf.math.atan` operation returns the inverse of `tf.math.tan`, such that +if `y = tf.math.tan(x)` then, `x = tf.math.atan(y)`. + +**Note**: The output of `tf.math.atan` will lie within the invertible range +of tan, i.e (-pi/2, pi/2). + +For example: + +```python +# Note: [1.047, 0.785] ~= [(pi/3), (pi/4)] +x = tf.constant([1.047, 0.785]) +y = tf.math.tan(x) # [1.731261, 0.99920404] + +tf.math.atan(y) # [1.047, 0.785] = x +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = [{ @@ -380,6 +545,33 @@ where \(r = \sqrt(x^2 + y^2) \). TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_AtanhOp : TF_Op<"Atanh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes inverse hyperbolic tangent of x element-wise."; + + let description = [{ +Given an input tensor, this function computes inverse hyperbolic tangent + for every element in the tensor. Input range is `[-1,1]` and output range is + `[-inf, inf]`. If input is `-1`, output will be `-inf` and if the + input is `1`, output will be `inf`. Values outside the range will have + `nan` as output. + + ```python + x = tf.constant([-float("inf"), -1, -0.5, 1, 0, 0.5, 10, float("inf")]) + tf.math.atanh(x) ==> [nan -inf -0.54930615 inf 0. 0.54930615 nan nan] + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AvgPoolOp : TF_Op<"AvgPool", [NoSideEffect]> { let summary = "Performs average pooling on the input."; @@ -546,6 +738,48 @@ reverse of SpaceToBatch. See below for a precise description. TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; } +def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Bessel i0e function of `x` element-wise."; + + let description = [{ +Exponentially scaled modified Bessel function of order 0 defined as +`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. + +This function is faster and numerically stabler than `bessel_i0(x)`. + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Bessel i1e function of `x` element-wise."; + + let description = [{ +Exponentially scaled modified Bessel function of order 0 defined as +`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. + +This function is faster and numerically stabler than `bessel_i1(x)`. 
+ }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> { let summary = "Adds `bias` to `value`."; @@ -748,6 +982,44 @@ for dtype in dtype_list: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_BitwiseXorOp : TF_Op<"BitwiseXor", [Commutative, NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = "Elementwise computes the bitwise XOR of `x` and `y`."; + + let description = [{ +The result will have those bits set, that are different in `x` and `y`. The +computation is performed on the underlying representations of `x` and `y`. + +For example: + +```python +import tensorflow as tf +from tensorflow.python.ops import bitwise_ops +dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64, + tf.uint8, tf.uint16, tf.uint32, tf.uint64] + +for dtype in dtype_list: + lhs = tf.constant([0, 5, 3, 14], dtype=dtype) + rhs = tf.constant([5, 0, 7, 11], dtype=dtype) + exp = tf.constant([5, 5, 4, 5], dtype=tf.float32) + + res = bitwise_ops.bitwise_xor(lhs, rhs) + tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE +``` + }]; + + let arguments = (ins + TF_IntTensor:$x, + TF_IntTensor:$y + ); + + let results = (outs + TF_IntTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> { let summary = [{ Return the reduction indices for computing gradients of s0 op s1 with broadcast. @@ -1235,6 +1507,31 @@ Given an input tensor, this function computes cosine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CoshOp : TF_Op<"Cosh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes hyperbolic cosine of x element-wise."; + + let description = [{ +Given an input tensor, this function computes hyperbolic cosine of every + element in the tensor. Input range is `[-inf, inf]` and output range + is `[1, inf]`. + + ```python + x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 2, 10, float("inf")]) + tf.math.cosh(x) ==> [inf 4.0515420e+03 1.1276259e+00 1.5430807e+00 1.8106556e+00 3.7621956e+00 1.1013233e+04 inf] + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> { let summary = "An Op to sum inputs across replicated TPU instances."; @@ -1461,6 +1758,26 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_DigammaOp : TF_Op<"Digamma", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes Psi, the derivative of Lgamma (the log of the absolute value of + }]; + + let description = [{ +`Gamma(x)`), element-wise. 
+ }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns x / y element-wise."; @@ -1755,6 +2072,59 @@ tf.math.equal(x, y) ==> array([True, True]) }]; } +def TF_ErfOp : TF_Op<"Erf", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Gauss error function of `x` element-wise."; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_ErfcOp : TF_Op<"Erfc", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes the complementary error function of `x` element-wise. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_ErfinvOp : TF_Op<"Erfinv", [NoSideEffect]> { + let summary = ""; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ExpOp : TF_Op<"Exp", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Computes exponential of x element-wise. \\(y = e^x\\). @@ -1854,6 +2224,36 @@ size 1. ]; } +def TF_Expm1Op : TF_Op<"Expm1", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes `exp(x) - 1` element-wise."; + + let description = [{ +i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor. + `e` denotes Euler's number and is approximately equal to 2.718281. + + ```python + x = tf.constant(2.0) + tf.math.expm1(x) ==> 6.389056 + + x = tf.constant([2.0, 8.0]) + tf.math.expm1(x) ==> array([6.389056, 2979.958], dtype=float32) + + x = tf.constant(1 + 1j) + tf.math.expm1(x) ==> (0.46869393991588515+2.2873552871788423j) + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. @@ -2613,6 +3013,92 @@ def ApplyG(op, dy, _): TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; } +def TF_IgammaOp : TF_Op<"Igamma", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Compute the lower regularized incomplete Gamma function `P(a, x)`. + }]; + + let description = [{ +The lower regularized incomplete Gamma function is defined as: + + +\\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) + +where + +\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) + +is the lower incomplete Gamma function. + +Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete +Gamma function. 
+ }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = "Computes the gradient of `igamma(a, x)` wrt `a`."; + + let description = [{ + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_IgammacOp : TF_Op<"Igammac", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Compute the upper regularized incomplete Gamma function `Q(a, x)`. + }]; + + let description = [{ +The upper regularized incomplete Gamma function is defined as: + +\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) + +where + +\\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) + +is the upper incomplete Gama function. + +Note, above `P(a, x)` (`Igamma`) is the lower regularized complete +Gamma function. + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ImagOp : TF_Op<"Imag", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Returns the imaginary part of a complex number."; @@ -2799,6 +3285,60 @@ tf.math.is_finite(x) ==> [True, True, True, False, False] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_IsInfOp : TF_Op<"IsInf", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "Returns which elements of x are Inf."; + + let description = [{ +@compatibility(numpy) +Equivalent to np.isinf +@end_compatibility + +Example: + +```python +x = tf.constant([5.0, np.inf, 6.8, np.inf]) +tf.math.is_inf(x) ==> [False, True, False, True] +``` + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + I1Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_IsNanOp : TF_Op<"IsNan", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "Returns which elements of x are NaN."; + + let description = [{ +@compatibility(numpy) +Equivalent to np.isnan +@end_compatibility + +Example: + +```python +x = tf.constant([5.0, np.nan, 6.8, np.nan, np.inf]) +tf.math.is_nan(x) ==> [False, True, False, True, False] +``` + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + I1Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> { let summary = "Gets the next output from the given iterator ."; @@ -3006,6 +3546,34 @@ tf.math.less_equal(x, y) ==> [True, True, True] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_LgammaOp : TF_Op<"Lgamma", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes the log of the absolute value of `Gamma(x)` element-wise. + }]; + + let description = [{ +For positive numbers, this function computes log((input - 1)!) for every element in the tensor. + `lgamma(5) = log((5-1)!) = log(4!) 
= log(24) = 3.1780539` + +Example: + +```python +x = tf.constant([0, 0.5, 1, 4.5, -4, -5.6]) +tf.math.lgamma(x) ==> [inf, 0.5723649, 0., 2.4537368, inf, -4.6477685] +``` + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_LinSpaceOp : TF_Op<"LinSpace", [NoSideEffect]> { let summary = "Generates values in an interval."; @@ -4135,6 +4703,32 @@ graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.Tensor TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; } +def TF_ModOp : TF_Op<"Mod", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Returns element-wise remainder of division. This emulates C semantics in that + }]; + + let description = [{ +the result here is consistent with a truncating divide. E.g. +`tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`. + +*NOTE*: `Mod` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TF_FpOrI32OrI64Tensor:$x, + TF_FpOrI32OrI64Tensor:$y + ); + + let results = (outs + TF_FpOrI32OrI64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns x * y element-wise."; @@ -4179,6 +4773,23 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_NdtriOp : TF_Op<"Ndtri", [NoSideEffect]> { + let summary = ""; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_NegOp : TF_Op<"Neg", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes numerical negative value element-wise."; @@ -4862,6 +5473,27 @@ the dimension is padded with zeros. TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>; } +def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Computes the derivative of a Gamma random sample w.r.t. `alpha`. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$alpha, + TF_F32OrF64Tensor:$sample + ); + + let results = (outs + TF_F32OrF64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_RandomShuffleOp : TF_Op<"RandomShuffle", [SameOperandsAndResultType]> { let summary = "Randomly shuffles a tensor along its first dimension."; @@ -5108,6 +5740,26 @@ I.e., \\(y = 1 / x\\). let hasCanonicalizer = 1; } +def TF_ReciprocalGradOp : TF_Op<"ReciprocalGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the gradient for the inverse of `x` wrt its input."; + + let description = [{ +Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +is the corresponding input gradient. 
+ }]; + + let arguments = (ins + TF_FpOrComplexTensor:$y, + TF_FpOrComplexTensor:$dy + ); + + let results = (outs + TF_FpOrComplexTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; @@ -5632,6 +6284,32 @@ bitwise_ops.right_shift(lhs, rhs) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Returns element-wise integer closest to x."; + + let description = [{ +If the result is midway between two representable values, +the even representable is chosen. +For example: + +``` +rint(-1.5) ==> -2.0 +rint(0.5000001) ==> 1.0 +rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] +``` + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Rounds the values of a tensor to the nearest integer, element-wise. @@ -6057,6 +6735,26 @@ Specifically, `y = 1 / (1 + exp(-x))`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SigmoidGradOp : TF_Op<"SigmoidGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the gradient of the sigmoid of `x` wrt its input."; + + let description = [{ +Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and +`dy` is the corresponding input gradient. + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$y, + TF_FpOrComplexTensor:$dy + ); + + let results = (outs + TF_FpOrComplexTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns an element-wise indication of the sign of a number."; @@ -6106,6 +6804,31 @@ Given an input tensor, this function computes sine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SinhOp : TF_Op<"Sinh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes hyperbolic sine of x element-wise."; + + let description = [{ +Given an input tensor, this function computes hyperbolic sine of every + element in the tensor. Input range is `[-inf,inf]` and output range + is `[-inf,inf]`. 
+ + ```python + x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 2, 10, float("inf")]) + tf.math.sinh(x) ==> [-inf -4.0515420e+03 -5.2109528e-01 1.1752012e+00 1.5094614e+00 3.6268604e+00 1.1013232e+04 inf] + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SizeOp : TF_Op<"Size", [NoSideEffect]> { let summary = "Returns the size of a tensor."; @@ -6251,6 +6974,59 @@ def TF_SoftplusOp : TF_Op<"Softplus", [NoSideEffect, SameOperandsAndResultType]> TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SoftplusGradOp : TF_Op<"SoftplusGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes softplus gradients for a softplus operation."; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$features + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SoftsignOp : TF_Op<"Softsign", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes softsign: `features / (abs(features) + 1)`."; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$features + ); + + let results = (outs + TF_FpTensor:$activations + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SoftsignGradOp : TF_Op<"SoftsignGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes softsign gradients for a softsign operation."; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$features + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [NoSideEffect]> { let summary = "SpaceToBatch for N-D tensors of type T."; @@ -7137,6 +7913,32 @@ variables. TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } +def TF_TanOp : TF_Op<"Tan", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes tan of x element-wise."; + + let description = [{ +Given an input tensor, this function computes tangent of every + element in the tensor. Input range is `(-inf, inf)` and + output range is `(-inf, inf)`. If input lies outside the boundary, `nan` + is returned. + + ```python + x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10000, float("inf")]) + tf.math.tan(x) ==> [nan 0.45231566 -0.5463025 1.5574077 2.572152 -1.7925274 0.32097113 nan] + ``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { let summary = "Computes hyperbolic tangent of `x` element-wise."; @@ -7969,6 +8771,32 @@ Python Semantics. let hasCanonicalizer = 1; } +def TF_TruncateModOp : TF_Op<"TruncateMod", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Returns element-wise remainder of division. This emulates C semantics in that + }]; + + let description = [{ +the result here is consistent with a truncating divide. E.g. `truncate(x / y) * +y + truncate_mod(x, y) = x`. + +*NOTE*: `TruncateMod` supports broadcasting. 
More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TF_FpOrI32OrI64Tensor:$x, + TF_FpOrI32OrI64Tensor:$y + ); + + let results = (outs + TF_FpOrI32OrI64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> { let summary = "Finds unique elements in a 1-D tensor."; @@ -8421,6 +9249,43 @@ An op which shards the input based on the given sharding attribute. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect]> { + let summary = "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a tensor of zeros with the same shape and type as x."; From c08ebde007b0c57780d8d63c1f925c1d6a8bfc7c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 15:28:47 -0700 Subject: [PATCH 454/492] Append 'grpc' to the C++ service namespaces for google APIs PiperOrigin-RevId: 302531346 Change-Id: I6ed8a2840679d3664cfb491a7560c3e8286aa622 --- third_party/googleapis/build_rules.bzl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/googleapis/build_rules.bzl b/third_party/googleapis/build_rules.bzl index d92ed1c5a13..377d74be1ad 100644 --- a/third_party/googleapis/build_rules.bzl +++ b/third_party/googleapis/build_rules.bzl @@ -56,7 +56,7 @@ def cc_proto_library(name, deps): visibility = ["//visibility:public"], ) -def cc_grpc_library(name, srcs, deps, **kwargs): +def cc_grpc_library(name, srcs, deps, service_namespace = "grpc", **kwargs): """Generates a cc library with grpc implementation and cc proto headers Args: @@ -72,6 +72,9 @@ def cc_grpc_library(name, srcs, deps, **kwargs): generate_cc( name = codegen_grpc_target, srcs = srcs, + flags = [ + "services_namespace=" + service_namespace, + ], plugin = "@com_github_grpc_grpc//src/compiler:grpc_cpp_plugin", well_known_protos = True, generate_mocks = True, From 8abb0d29927c78c6f9ccbad2a6324cf2d07ebc85 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Mon, 23 Mar 2020 15:30:38 -0700 Subject: [PATCH 455/492] [tf.data] Completing migration to new internal APIs that make it possible to overriding policy for handling external state during iterator checkpointing. 
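Concretely, the two-argument SaveInternal(SerializationContext*, IteratorStateWriter*) is now pure virtual and the legacy single-argument overload is removed, so every iterator override receives the serialization context and can forward it when checkpointing its input. A minimal sketch of what an affected override looks like after the migration (field and key names are illustrative, not from any one kernel):

  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override {
    mutex_lock l(mu_);
    // Persist this iterator's own cursor state.
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_));
    // Forward the context so the wrapped input iterator applies the same
    // external-state handling policy.
    return SaveInput(ctx, writer, input_impl_);
  }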
PiperOrigin-RevId: 302531704 Change-Id: I6a235f9558c948d42acbb771e831ce9adb9d6c8e --- tensorflow/core/framework/dataset.h | 19 +------------------ .../core/kernels/data/cache_dataset_ops.cc | 3 ++- .../experimental/matching_files_dataset_op.cc | 3 ++- 3 files changed, 5 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 25cc8fd759e..9cabcb08490 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -628,14 +628,6 @@ class IteratorBase { return input->SaveInternal(ctx, writer); } - // TODO(jsimsa): Remove this override when all callers are migrated to the - // override that uses SerializationContext. - Status SaveInput(IteratorStateWriter* writer, - const std::unique_ptr& input) { - SerializationContext ctx(/*params=*/{}); - return input->SaveInternal(&ctx, writer); - } - // This is needed so that sub-classes of IteratorBase can call // `RestoreInternal` on their input iterators. Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader, @@ -648,16 +640,7 @@ class IteratorBase { // This method is used to store the state of the iterator in a checkpoint. // implementations have an override. virtual Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) { - return SaveInternal(writer); - } - - // TODO(jsimsa): Remove this override when all subclasses are migrated to the - // override that accepts SerializationContext and make that override pure - // virtual. - virtual Status SaveInternal(IteratorStateWriter* writer) { - return errors::Unimplemented("checkpointing is not supported"); - } + IteratorStateWriter* writer) = 0; // Restores the state of this iterator. // diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index f99ac114dc2..707800bc896 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -944,7 +944,8 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index 9ba44aaf909..90a61d72597 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -192,7 +192,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(IteratorStateWriter* writer) override { + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name("current_pattern_index"), current_pattern_index_)); From f446a69e5c340c3698ee57ad1f1902885058dd5c Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Mon, 23 Mar 2020 16:01:12 -0700 Subject: [PATCH 456/492] Fix repeat, and add tests PiperOrigin-RevId: 302538059 Change-Id: Ie8a17b4fe5818d04260cc5a3f8868178b92e0014 --- .../python/kernel_tests/array_ops_test.py | 64 ++++++++----------- .../python/kernel_tests/parsing_ops_test.py | 4 +- tensorflow/python/ops/array_ops.py | 29 ++++++--- 3 files changed, 47 
insertions(+), 50 deletions(-) diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index b81ec5f36a8..ec3ed932996 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -21,6 +21,7 @@ import re import time import unittest +from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import config_pb2 @@ -1890,48 +1891,33 @@ class BatchGatherNdTest(test_util.TensorFlowTestCase): self.assertEqual(None, tensor_shape.dimension_value(shape[0])) -class RepeatTest(test_util.TensorFlowTestCase): +@test_util.run_all_in_graph_and_eager_modes +class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase): - @test_util.run_deprecated_v1 - def testRepeatScalar(self): - with self.test_session(): - v_tf = array_ops.repeat(constant_op.constant(3), 4) - v_np = np.repeat(3, 4) - self.assertAllEqual(v_tf.eval(), v_np) + @parameterized.parameters( + (3, 4, None), + ([[1, 2], [3, 4]], 2, None), + ([[1, 2], [3, 4]], [1, 2], 0), + ([[1, 2], [3, 4]], [1, 2], 1), + ([[1, 2], [3, 4]], 3, 1), + ([[1, 2], [3, 4]], [1, 2, 3, 4], None), + (np.ones([0, 4]), 0, 1), + (np.ones([1, 2]), [2], None), + ) + def testRepeat(self, array, repeats, axis): + array = np.array(array) - @test_util.run_deprecated_v1 - def testRepeatMatrix(self): - with self.test_session(): - x = np.array([[1, 2], [3, 4]], dtype=np.int32) - v_tf = array_ops.repeat(constant_op.constant(x), 2) - v_np = np.repeat(x, 2) - self.assertAllEqual(v_tf.eval(), v_np) + @def_function.function( + input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)] * 2) + def repeat_fn(array, repeats): + return array_ops.repeat(array, repeats, axis) - @test_util.run_deprecated_v1 - def testRepeatMatrixAxis0(self): - with self.test_session(): - x = np.array([[1, 2], [3, 4]], dtype=np.int32) - v_tf = array_ops.repeat( - constant_op.constant(x), constant_op.constant([1, 2]), axis=0) - v_np = np.repeat(x, [1, 2], axis=0) - self.assertAllEqual(v_tf.eval(), v_np) - - @test_util.run_deprecated_v1 - def testRepeatMatrixAxis1(self): - with self.test_session(): - x = np.array([[1, 2], [3, 4]], dtype=np.int32) - v_tf = array_ops.repeat( - constant_op.constant(x), constant_op.constant(3), axis=1) - v_np = np.repeat(x, 3, axis=1) - self.assertAllEqual(v_tf.eval(), v_np) - - @test_util.run_deprecated_v1 - def testRepeatMatrixRepeatArray(self): - with self.test_session(): - x = np.array([[1, 2], [3, 4]], dtype=np.int32) - v_tf = array_ops.repeat(constant_op.constant(x), [1, 2, 3, 4]) - v_np = np.repeat(x, [1, 2, 3, 4]) - self.assertAllEqual(v_tf.eval(), v_np) + v_tf = array_ops.repeat(constant_op.constant(array), repeats, axis) + v_tf_fn = repeat_fn( + constant_op.constant(array, dtype=dtypes.int32), repeats) + v_np = np.repeat(array, repeats, axis) + self.assertAllEqual(v_tf, v_np) + self.assertAllEqual(v_tf_fn, v_np) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py index 0aaead2fa2b..c94fd0fde49 100644 --- a/tensorflow/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/parsing_ops_test.py @@ -2278,13 +2278,13 @@ class ParseSequenceExampleTest(test.TestCase): serialized=ops.convert_to_tensor(original.SerializeToString()), sequence_features=sequence_features), expected_err=( - (errors_impl.OpError, ValueError), + (errors_impl.InvalidArgumentError, ValueError), # Message for batch=true: "Feature b: values and 
partitions are not aligned" # Message for batch=false in graph mode: "|.* do not form a valid RaggedTensor" # Message for batch=false in eager mode: - "|Dimensions 2 and 1 are not compatible")) + "|Incompatible shapes")) @test_util.run_all_in_graph_and_eager_modes diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index cbb5db77801..d286c96ec4e 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -5511,17 +5511,24 @@ def repeat_with_axis(data, repeats, axis, name=None): # If `axis` is negative, then convert it to a positive value. axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)") + # If we know that `repeats` is a scalar, then we can just tile & reshape. + if repeats.shape.num_elements() == 1: + repeats = reshape(repeats, []) + expanded = expand_dims(data, axis + 1) + tiled = tile_one_dimension(expanded, axis + 1, repeats) + result_shape = concat([ + data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:] + ], + axis=0) + return reshape(tiled, result_shape) + + # Check data Tensor shapes. if repeats.shape.ndims == 1: data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0]) - # If we know that `repeats` is a scalar, then we can just tile & reshape. - if repeats.shape.ndims == 0: - expanded = expand_dims(data, axis + 1) - tiled = tile_one_dimension(expanded, axis + 1, repeats) - result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]], - axis=0) - return reshape(tiled, result_shape) + repeats = broadcast_to(repeats, [data_shape[axis]]) + repeats_original = repeats # Broadcast the `repeats` tensor so rank(repeats) == axis + 1. if repeats.shape.ndims != axis + 1: @@ -5552,8 +5559,12 @@ def repeat_with_axis(data, repeats, axis, name=None): if axis == 0: result = masked else: - result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]], - axis=0) + repeated_dim_size = gen_math_ops._sum( + repeats_original, + axis=gen_math_ops._range(0, rank(repeats_original), 1)) + result_shape = concat( + [data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]], + axis=0) result = reshape(masked, result_shape) # Preserve shape information. From ebe2d864dd047b84d3ef807a783f77b0e094a719 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Mon, 23 Mar 2020 16:04:40 -0700 Subject: [PATCH 457/492] Disable tests that are failing on Cloud TPU. PiperOrigin-RevId: 302539019 Change-Id: I33c43937200f68ca029efb9eb0f578c48f10a40e --- tensorflow/python/distribute/tpu_strategy_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index f0429ab07ef..598793fa227 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -310,7 +310,8 @@ class TPUStrategyTest(test.TestCase): bar(1) - def test_using_external_variable_inside_tf_function(self): + # TODO(b/152251070): Re-enable once modified to work on Cloud TPU. 
+ def disable_test_using_external_variable_inside_tf_function(self): strategy = get_tpu_strategy() dataset = dataset_ops.Dataset.range(10, output_type=dtypes.float32).batch(2) input_iterator = iter(strategy.experimental_distribute_dataset(dataset)) @@ -329,7 +330,8 @@ class TPUStrategyTest(test.TestCase): expected_result, strategy.experimental_local_results(train_step(next(input_iterator)))) - def test_keras_metric_outside_strategy_scope_per_replica(self): + # TODO(b/152251070): Re-enable once modified to work on Cloud TPU. + def disable_test_keras_metric_outside_strategy_scope_per_replica(self): strategy = get_tpu_strategy() metric = keras.metrics.Mean("test_metric", dtype=dtypes.float32) From a58b1026bd971e63dddf72b931a4b1575b277b3b Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Mon, 23 Mar 2020 16:04:41 -0700 Subject: [PATCH 458/492] Disable flaky eager:remote_test PiperOrigin-RevId: 302539025 Change-Id: Ie8a95813a2c2eb2bc15bdcdef4c89c7c4d22fe06 --- tensorflow/python/eager/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 9df6113b95f..8832f043457 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -852,7 +852,9 @@ cuda_py_test( python_version = "PY3", shard_count = 2, tags = [ + "manual", "no_oss", # This test launches local server. + "notap", # TODO(b/152224115) "optonly", # times out ], deps = [ From 3e2b03121175aebd965174f2ee55efd5aa903039 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 16:15:37 -0700 Subject: [PATCH 459/492] Cache the out-of-range status and avoid repeating reads in BufferedInputStream. PiperOrigin-RevId: 302541219 Change-Id: I0f624d6605ddb9aa21a24b809fff1d5de7c6d76e --- .../core/lib/io/buffered_inputstream.cc | 5 +- .../core/lib/io/buffered_inputstream_test.cc | 57 +++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/lib/io/buffered_inputstream.cc b/tensorflow/core/lib/io/buffered_inputstream.cc index 94479a1149f..6f268de8cac 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.cc +++ b/tensorflow/core/lib/io/buffered_inputstream.cc @@ -49,8 +49,7 @@ Status BufferedInputStream::FillBuffer() { Status s = input_stream_->ReadNBytes(size_, &buf_); pos_ = 0; limit_ = buf_.size(); - if (buf_.empty()) { - DCHECK(!s.ok()); + if (!s.ok()) { file_status_ = s; } return s; @@ -93,7 +92,7 @@ Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) { bytes_to_read); } result->clear(); - if (!file_status_.ok() && bytes_to_read > 0) { + if (pos_ == limit_ && !file_status_.ok() && bytes_to_read > 0) { return file_status_; } result->reserve(bytes_to_read); diff --git a/tensorflow/core/lib/io/buffered_inputstream_test.cc b/tensorflow/core/lib/io/buffered_inputstream_test.cc index c4af1e707b4..d6c07344ba3 100644 --- a/tensorflow/core/lib/io/buffered_inputstream_test.cc +++ b/tensorflow/core/lib/io/buffered_inputstream_test.cc @@ -30,6 +30,37 @@ static std::vector BufferSizes() { 12, 13, 14, 15, 16, 17, 18, 19, 20, 65536}; } +// This class will only return OutOfRange error once to make sure that +// BufferedInputStream is able to cache the error. 
+class ReadOnceInputStream : public InputStreamInterface { + public: + ReadOnceInputStream() : start_(true) {} + + virtual Status ReadNBytes(int64 bytes_to_read, tstring* result) { + if (bytes_to_read < 11) { + return errors::InvalidArgument("Not reading all bytes: ", bytes_to_read); + } + if (start_) { + *result = "0123456789"; + start_ = false; + return errors::OutOfRange("Out of range."); + } + return errors::InvalidArgument( + "Redudant call to ReadNBytes after an OutOfRange error."); + } + + int64 Tell() const override { return start_ ? 0 : 10; } + + // Resets the stream to the beginning. + Status Reset() override { + start_ = true; + return Status::OK(); + } + + private: + bool start_; +}; + TEST(BufferedInputStream, ReadLine_Empty) { Env* env = Env::Default(); string fname; @@ -196,6 +227,32 @@ TEST(BufferedInputStream, ReadNBytes) { } } +TEST(BufferedInputStream, OutOfRangeCache) { + for (auto buf_size : BufferSizes()) { + if (buf_size < 11) { + continue; + } + ReadOnceInputStream input_stream; + tstring read; + BufferedInputStream in(&input_stream, buf_size); + EXPECT_EQ(0, in.Tell()); + TF_ASSERT_OK(in.ReadNBytes(3, &read)); + EXPECT_EQ(read, "012"); + EXPECT_EQ(3, in.Tell()); + TF_ASSERT_OK((in.ReadNBytes(7, &read))); + EXPECT_EQ(read, "3456789"); + EXPECT_EQ(10, in.Tell()); + Status s = in.ReadNBytes(5, &read); + // Make sure the read is failing with OUT_OF_RANGE error. If it is failing + // with other errors, it is not caching the OUT_OF_RANGE properly. + EXPECT_EQ(error::OUT_OF_RANGE, s.code()) << s; + EXPECT_EQ(read, ""); + // Empty read shouldn't cause an error even at the end of the file. + TF_ASSERT_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + } +} + TEST(BufferedInputStream, SkipNBytes) { Env* env = Env::Default(); string fname; From 1bf7b2f0cf2802c741480530a05937a2b56bb591 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 16:33:29 -0700 Subject: [PATCH 460/492] Remove the unnecessary type check from legacy RNN code. PiperOrigin-RevId: 302544931 Change-Id: I591db053c5835529a7cf9a4e1180e0d77299c145 --- tensorflow/python/ops/rnn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 031e807e8b0..adda1f5e564 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -160,6 +161,7 @@ def _is_keras_rnn_cell(rnn_cell): # Keras cells never had zero_state method, which was from the original # interface from TF RNN cell. return (not isinstance(rnn_cell, rnn_cell_impl.RNNCell) and + isinstance(rnn_cell, base_layer.Layer) and getattr(rnn_cell, "zero_state", None) is None) From ec5c9bebf8526ea96fba9c3d1459594ab6727ab7 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Mon, 23 Mar 2020 16:35:47 -0700 Subject: [PATCH 461/492] Reorganized structure of Elementwise operations. Support of linking for two input elementwise. Added broadcast parameters. Removed ApplyMask(mul + broadcast). 
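With broadcasting enabled, the linked elementwise function reads the second runtime tensor with each broadcast dimension clamped to zero. Roughly, the generated code follows this pattern (illustrative sketch only, not the exact generated Metal source; variable names such as broadcast_width are hypothetical stand-ins for the ElementwiseBroadcastSettings flags):

    // Coordinates into the second input; broadcast dims collapse to 0.
    int x = broadcast_width    ? 0 : int(gid.x);
    int y = broadcast_height   ? 0 : int(gid.y);
    int s = broadcast_channels ? 0 : int(gid.z);
    int second_index = (s * second_size.y + y) * second_size.x + x;
    FLT4 src_1 = second_tensor[second_index];
    // The op-specific functor is substituted here, e.g. for MUL:
    return value * src_1;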
PiperOrigin-RevId: 302545362 Change-Id: Icb9cb94aaad448a205dc7160f4b44820081d69ca --- tensorflow/lite/delegates/gpu/metal/api.cc | 65 ++++- .../delegates/gpu/metal/compiled_model.cc | 14 +- .../gpu/metal/compute_task_descriptor.h | 4 + .../lite/delegates/gpu/metal/kernels/add.cc | 1 + .../gpu/metal/kernels/elementwise.cc | 236 +++++++++--------- .../delegates/gpu/metal/kernels/elementwise.h | 15 +- .../gpu/metal/kernels/elementwise_test.mm | 32 ++- .../lite/delegates/gpu/metal/kernels/mul.cc | 96 ------- .../lite/delegates/gpu/metal/kernels/mul.h | 5 - .../delegates/gpu/metal/kernels/mul_test.mm | 51 ---- 10 files changed, 240 insertions(+), 279 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index b2887e523a5..dedb2aa8df1 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -51,6 +51,25 @@ namespace tflite { namespace gpu { namespace metal { namespace { +bool IsWidthBroadcastedForSecondInput( + const std::vector>*>& inputs) { + return inputs.size() == 2 && + inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w && + inputs[1]->tensor.shape.w == 1; +} +bool IsHeightBroadcastedForSecondInput( + const std::vector>*>& inputs) { + return inputs.size() == 2 && + inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h && + inputs[1]->tensor.shape.h == 1; +} +bool IsChannelsBroadcastedForSecondInput( + const std::vector>*>& inputs) { + return inputs.size() == 2 && + inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c && + inputs[1]->tensor.shape.c == 1; +} + std::vector SelectDepthWiseConv( int id, ValueId input_id, ValueId output_id, const DepthwiseConvolution2DAttributes& attr, @@ -134,11 +153,22 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, int node_id = static_cast(node->id); auto op_type = OperationTypeFromString(node->operation.type); switch (op_type) { - case OperationType::ADD: - *tasks = Add(node_id, inputs, outputs[0], - absl::any_cast(node->operation.attributes), - options); + case OperationType::ADD: { + const auto srcs = graph.FindInputs(node_id); + ElementwiseBroadcastSettings broadcast; + broadcast.width = IsWidthBroadcastedForSecondInput(srcs); + broadcast.height = IsHeightBroadcastedForSecondInput(srcs); + broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs); + if (broadcast.width || broadcast.height || broadcast.channels) { + *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, + broadcast); + } else { + *tasks = Add(node_id, inputs, outputs[0], + absl::any_cast(node->operation.attributes), + options); + } break; + } case OperationType::CONCAT: { std::vector input_shapes; for (auto& input : graph.FindInputs(node->id)) { @@ -194,7 +224,18 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, absl::any_cast(node->operation.attributes), options); } else { - *tasks = ApplyMask(node_id, inputs[0], inputs[1], outputs[0], options); + if (inputs.size() == 2) { + const auto srcs = graph.FindInputs(node_id); + ElementwiseBroadcastSettings broadcast; + broadcast.width = IsWidthBroadcastedForSecondInput(srcs); + broadcast.height = IsHeightBroadcastedForSecondInput(srcs); + broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs); + *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], + op_type, broadcast); + } else { + return absl::UnimplementedError( + "No support of multiply with more than 2 inputs"); + } } break; case OperationType::PAD: { @@ -269,8 +310,18 @@ 
absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, case OperationType::SUB: { const ElementwiseAttributes* attr = absl::any_cast(&node->operation.attributes); - *tasks = - ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, attr); + if (attr) { + *tasks = ElementwiseWithOneInputAndConstantArguent( + node_id, inputs[0], outputs[0], options, op_type, *attr); + } else { + const auto srcs = graph.FindInputs(node_id); + ElementwiseBroadcastSettings broadcast; + broadcast.width = IsWidthBroadcastedForSecondInput(srcs); + broadcast.height = IsHeightBroadcastedForSecondInput(srcs); + broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs); + *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, + broadcast); + } } break; case OperationType::BATCH_NORMALIZATION: case OperationType::BATCH_TO_SPACE: diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc index 711ed9fed88..06cc10a0520 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc @@ -180,10 +180,16 @@ void BuildFusableChains(const std::vector& input_ids, bool fused = false; for (auto& chain : *chains) { // We can fuse only single output for now. - if (Contains(task_descriptor->input_buffers, - chain.back()->output_buffer.id) && - CanFuseOperations(chain.back(), task_descriptor, output_ids, - *descriptors, chains)) { + bool can_link = false; + if (task_descriptor->is_associative_op) { + can_link = Contains(task_descriptor->input_buffers, + chain.back()->output_buffer.id); + } else { + can_link = task_descriptor->input_buffers[0].id == + chain.back()->output_buffer.id; + } + if (can_link && CanFuseOperations(chain.back(), task_descriptor, + output_ids, *descriptors, chains)) { chain.push_back(task_descriptor); fused = true; break; diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h index 35bad273c50..923f4dcc245 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h +++ b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h @@ -99,6 +99,10 @@ struct ComputeTaskDescriptor { // $2 // output_buffer[linear_index] = value; // } + + // when operation associative, we can rearrange input tensors + // for example add is associative + bool is_associative_op = false; std::string shader_source; std::vector input_buffers; // A single per-operation output is supported now. diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add.cc b/tensorflow/lite/delegates/gpu/metal/kernels/add.cc index c857a092a53..b4a8e781c72 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/add.cc @@ -86,6 +86,7 @@ std::vector Add(int id, } desc->is_linkable = true; + desc->is_associative_op = true; desc->shader_source = GetAddTableCodeFused(input_ids.size() - 1); for (int i = 0; i < input_ids.size(); ++i) { diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc index 7fdfd3257ea..9d9e054f40a 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" @@ -29,115 +30,93 @@ namespace metal { namespace { -std::string GetElementwiseWithTwoInputsCode(int src_count, - OperationType op_type, - const float* scalar) { - std::string code = R"( - #include - using namespace metal; +std::string OneInputFunctor(OperationType op_type, const std::string& value) { + const std::unordered_map functors{ + {OperationType::ABS, "abs($0)"}, + {OperationType::SIN, "sin($0)"}, + {OperationType::HARD_SWISH, + "$0 * clamp($0 / 6.0f + FLT4(0.5f), FLT4(0.0f), FLT4(1.0f))"}, + {OperationType::COS, "cos($0)"}, + {OperationType::EXP, "exp($0)"}, + {OperationType::LOG, "log($0)"}, + {OperationType::SQRT, "sqrt($0)"}, + {OperationType::RSQRT, "1.0 / sqrt($0)"}, + {OperationType::SQUARE, "$0 * $0"}, + {OperationType::SIGMOID, "1.0 / (1.0 + exp(-1.0 * $0))"}, + {OperationType::TANH, "tanh($0)"}, + }; - struct uniforms { - int4 src_size; - }; - - $0 - kernel void ComputeFunction( - $1 - uint3 gid[[thread_position_in_grid]]) { - if (static_cast(gid.x) >= params.src_size.x || - static_cast(gid.y) >= params.src_size.y) { - return; - } - - int linear_index = (int(gid.z) * params.src_size.y + int(gid.y)) * - params.src_size.x + int(gid.x); - FLT4 src_0 = src_buffer0[linear_index]; - )"; - - if (scalar == nullptr) { - code += " FLT4 src_1 = src_buffer1[linear_index];"; - } else { - code += " FLT4 src_1 = FLT4(" + std::to_string(*scalar) + ");"; + if (functors.find(op_type) == functors.end()) { + return "Error, unknown op"; } - switch (op_type) { - case OperationType::DIV: { - code += " FLT4 value = src_0 / src_1;"; - break; - } - case OperationType::MAXIMUM: { - code += " FLT4 value = max(src_0, src_1);"; - break; - } - case OperationType::MINIMUM: { - code += " FLT4 value = min(src_0, src_1);"; - break; - } - case OperationType::POW: { - code += " FLT4 value = pow(src_0, src_1);"; - break; - } - case OperationType::SQUARED_DIFF: { - code += " FLT4 value = (src_0 - src_1) * (src_0 - src_1);"; - break; - } - case OperationType::SUB: { - code += " FLT4 value = src_0 - src_1;"; - break; - } - default: { - return ""; - } - } - code += R"( - $2 - dst_buffer[linear_index] = value; - })"; - return code; + + return absl::Substitute(functors.at(op_type), value); } + +std::string TwoInputFunctor(OperationType op_type, const std::string& value0, + const std::string& value1) { + const std::unordered_map functors{ + {OperationType::ADD, "$0 + $1"}, + {OperationType::DIV, "$0 / $1"}, + {OperationType::MAXIMUM, "max($0, $1)"}, + {OperationType::MINIMUM, "min($0, $1)"}, + {OperationType::MUL, "$0 * $1"}, + {OperationType::POW, "pow($0, $1)"}, + {OperationType::SQUARED_DIFF, "($0 - $1) * ($0 - $1)"}, + {OperationType::SUB, "$0 - $1"}, + }; + + if (functors.find(op_type) == functors.end()) { + return "Error, unknown op"; + } + + return absl::Substitute(functors.at(op_type), value0, value1); +} + } // namespace std::vector ElementwiseWithTwoInputs( int id, std::vector input_ids, ValueId output_id, - OperationType op_type, const ElementwiseAttributes* attr) { - const float* scalar = nullptr; - if (attr) { - scalar = absl::get_if(&attr->param); - } + OperationType op_type, const ElementwiseBroadcastSettings& settings) { auto desc = std::make_shared(); desc->id = id; - desc->is_linkable = false; - desc->shader_source = - 
GetElementwiseWithTwoInputsCode(input_ids.size(), op_type, scalar); - - for (int i = 0; i < input_ids.size(); ++i) { - const std::string buffer_name = - "device FLT4* const src_buffer" + std::to_string(i); - desc->input_buffers.push_back({input_ids[i], buffer_name}); + desc->is_linkable = true; + const std::string x_coord = settings.width ? "0" : "int(gid.x)"; + const std::string y_coord = settings.height ? "0" : "int(gid.y)"; + const std::string s_coord = settings.channels ? "0" : "int(gid.z)"; + std::string code = + "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* " + "const second_tensor, int2 second_size) {\n"; + code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord + + ") * second_size.x + " + x_coord + ";\n"; + code += " FLT4 src_1 = second_tensor[second_index];\n"; + if (settings.channels) { + code += " src_1.y = src_1.x;\n"; + code += " src_1.z = src_1.x;\n"; + code += " src_1.w = src_1.x;\n"; } + code += " return " + TwoInputFunctor(op_type, "value", "src_1") + ";\n"; + code += "}\n"; - desc->output_buffer = {output_id, "device FLT4* dst_buffer", - [input_ids](const std::map& buffers) { - return buffers.find(input_ids[0])->second; - }}; + desc->shader_source = code; + + desc->input_buffers = { + {input_ids[0], "device FLT4* const"}, + {input_ids[1], "device FLT4* const"}, + }; + desc->output_buffer = {output_id}; desc->uniform_buffers = { - {"constant uniforms& params", - [input_ids](const std::map& buffers) { - const auto& dimension = buffers.find(input_ids[0])->second; - std::vector uniform_params = {dimension.w, dimension.h, 0, 0}; + {"constant int2&", + [input_ids, output_id](const std::map& buffers) { + const auto& input_dim_1 = buffers.find(input_ids[1])->second; + std::vector uniform_params{ + input_dim_1.w, + input_dim_1.h, + }; return GetByteBuffer(uniform_params); }}, }; - - desc->resize_function = [input_ids](const std::map& buffers) { - const auto& src_dim = buffers.find(input_ids[0])->second; - const uint3 groups_size{16, 16, 1}; - int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x); - int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y); - const int dst_layers = IntegralDivideRoundUp(src_dim.c, 4); - int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z); - return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); - }; return {desc}; } @@ -146,29 +125,10 @@ std::vector ElementwiseWithOneInput( auto desc = std::make_shared(); desc->id = id; desc->is_linkable = true; - - const std::unordered_map functors{ - {OperationType::ABS, "abs(value)"}, - {OperationType::SIN, "sin(value)"}, - {OperationType::HARD_SWISH, - "value * clamp(value / 6.0f + FLT4(0.5f), FLT4(0.0f), FLT4(1.0f))"}, - {OperationType::COS, "cos(value)"}, - {OperationType::EXP, "exp(value)"}, - {OperationType::LOG, "log(value)"}, - {OperationType::SQRT, "sqrt(value)"}, - {OperationType::RSQRT, "1.0 / sqrt(value)"}, - {OperationType::SQUARE, "value * value"}, - {OperationType::SIGMOID, "1.0 / (1.0 + exp(-1.0 * value))"}, - {OperationType::TANH, "tanh(value)"}, - }; - - if (functors.count(op_type) == 0) { - return {}; - } - desc->shader_source = "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {\n"; - desc->shader_source += " return " + functors.at(op_type) + ";\n"; + desc->shader_source += + " return " + OneInputFunctor(op_type, "value") + ";\n"; desc->shader_source += " }"; desc->input_buffers = {{input_id}}; @@ -176,6 +136,54 @@ std::vector ElementwiseWithOneInput( return {desc}; } +std::vector 
ElementwiseWithOneInputAndConstantArguent( + int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options, + OperationType op_type, const ElementwiseAttributes& attr) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = true; + auto scalar = absl::get_if(&attr.param); + auto linear_buf = + absl::get_if>(&attr.param); + std::string param_desc; + if (scalar) { + param_desc += ", float scalar_val"; + } + if (linear_buf) { + param_desc += ", device FLT4* const linear_buf"; + } + desc->shader_source = + "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc + + ") {\n"; + if (scalar) { + desc->shader_source += " FLT4 second_arg = FLT4(scalar_val);\n"; + } else if (linear_buf) { + desc->shader_source += " FLT4 second_arg = linear_buf[gid.z];\n"; + } + desc->shader_source += + " return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n"; + desc->shader_source += " }"; + + desc->input_buffers = {{input_id}}; + desc->output_buffer = {output_id}; + if (scalar) { + std::vector scalar_bits = + GetByteBuffer(std::vector{*scalar}); + desc->uniform_buffers = { + {"constant float&", + [scalar_bits](const std::map& buffers) { + return scalar_bits; + }}, + }; + } else if (linear_buf) { + desc->immutable_buffers = { + {"device FLT4* const", + GetByteBufferConverted(linear_buf->data, options.storage_precision)}, + }; + } + return {desc}; +} + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h index af70e433e79..2520c2f2df4 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h @@ -25,13 +25,26 @@ namespace tflite { namespace gpu { namespace metal { +struct ElementwiseBroadcastSettings { + bool width = false; + bool height = false; + bool channels = false; +}; + +// Two inputs are two runtime tensors std::vector ElementwiseWithTwoInputs( int id, std::vector input_ids, ValueId output_id, - OperationType op_type, const ElementwiseAttributes* attr); + OperationType op_type, const ElementwiseBroadcastSettings& settings); +// One input is one runtime tensor std::vector ElementwiseWithOneInput( int id, ValueId input_id, ValueId output_id, OperationType op_type); +// First input is one runtime tensor and second input is constant argument +std::vector ElementwiseWithOneInputAndConstantArguent( + int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options, + OperationType op_type, const ElementwiseAttributes& attr); + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm index d8521ba76b1..6b30bc5c703 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm @@ -94,7 +94,7 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { - (void)testExp { OperationType op_type = OperationType::EXP; - const BHWC shape(1, 1, 1, 5); + const BHWC shape(1, 1, 1, 7); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, /*inputs=*/{GetTensorRef(0, shape)}, /*outputs=*/{GetTensorRef(1, shape)}); @@ -312,4 +312,34 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } +- (void)testMulBroadcastChannels { + OperationType 
op_type = OperationType::MUL; + const BHWC shape(1, 1, 2, 2); + const BHWC shape_2(1, 1, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape_2)}, + /*outputs=*/{GetTensorRef(2, shape)}); + XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); + XCTAssertTrue(model.PopulateTensor(1, {2.0, 3.0})); + auto status = model.Invoke(); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + status = CompareVectors({2.0, 4.0, 9.0, 12.0}, model.GetOutput(0), 1e-6f); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); +} + +- (void)testMulBroadcastWidthAndHeight { + OperationType op_type = OperationType::MUL; + const BHWC shape(1, 1, 2, 2); + const BHWC shape_2(1, 1, 1, 2); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape_2)}, + /*outputs=*/{GetTensorRef(2, shape)}); + XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); + XCTAssertTrue(model.PopulateTensor(1, {2.0, 3.0})); + auto status = model.Invoke(); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + status = CompareVectors({2.0, 6.0, 6.0, 12.0}, model.GetOutput(0), 1e-6f); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); +} + @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc b/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc index 21a04f2fc35..e90ab6b4f12 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc @@ -35,102 +35,6 @@ limitations under the License. namespace tflite { namespace gpu { namespace metal { -namespace { - -std::string GetApplyMaskCode() { - std::string shader_source = R"( - #include - using namespace metal; - struct uniforms { - int4 src_0_size; - int4 src_1_size; - int4 dst_size; - }; - - $0 - kernel void ComputeFunction( - $1 - uint3 gid[[thread_position_in_grid]]) { - int X = static_cast(gid.x); - int Y = static_cast(gid.y); - if (X >= params.dst_size.x || Y >= params.dst_size.y) { - return; - } - int src_0_index = (gid.z * params.src_0_size.y + static_cast(gid.y)) * - params.src_0_size.x + static_cast(gid.x); - int src_1_index = 0; - if (params.dst_size.z == 1) { - // [H, W, C] x [H, W, 0][0] - src_1_index = static_cast(gid.y) * params.src_1_size.x + - static_cast(gid.x); - } else if (params.src_0_size.y == params.src_1_size.y && - params.src_0_size.x == params.src_1_size.x) { - // [H, W, C] x [H, W, C] - src_1_index = src_0_index; - } else { - // [H, W, C] x [0, 0, C] - src_1_index = gid.z * params.src_1_size.y * params.src_1_size.x ; - } - FLT4 value = src_buffer_0[src_0_index] * src_buffer_1[src_1_index]; - int linear_index = (gid.z * params.dst_size.y + static_cast(gid.y)) * - params.dst_size.x + static_cast(gid.x); - $2 - dst_buffer[linear_index] = value; - } - )"; - return shader_source; -} -} // namespace - -std::vector ApplyMask(int id, ValueId input_id_0, - ValueId input_id_1, - ValueId output_id, - const RuntimeOptions& options) { - auto desc = std::make_shared(); - desc->id = id; - desc->is_linkable = false; - desc->shader_source = GetApplyMaskCode(); - - desc->input_buffers = { - {input_id_0, "device FLT4* const src_buffer_0"}, // data - {input_id_1, "device FLT4* const src_buffer_1"}, // mask - }; - - desc->output_buffer = { - output_id, "device FLT4* dst_buffer", - [input_id_0, input_id_1](const std::map& buffers) { - return 
buffers.find(input_id_0)->second; - }}; - - desc->uniform_buffers = { - {"constant uniforms& params", - [input_id_0, input_id_1, - output_id](const std::map& buffers) { - const auto& input_dim_0 = buffers.find(input_id_0)->second; - const auto& input_dim_1 = buffers.find(input_id_1)->second; - const auto& output_dim = buffers.find(output_id)->second; - std::vector uniform_params{ - input_dim_0.w, input_dim_0.h, input_dim_0.c, 0, - input_dim_1.w, input_dim_1.h, input_dim_1.c, 0, - output_dim.w, output_dim.h, output_dim.c, 0, - }; - return GetByteBuffer(uniform_params); - }}, - }; - - desc->resize_function = [input_id_0, - input_id_1](const std::map& buffers) { - const auto& src_shape = buffers.find(input_id_0)->second; - const uint3 groups_size{16, 16, 1}; - int groups_x = IntegralDivideRoundUp(src_shape.w, groups_size.x); - int groups_y = IntegralDivideRoundUp(src_shape.h, groups_size.y); - int groups_z = IntegralDivideRoundUp(src_shape.c, 4); - return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); - }; - - return {desc}; -} - std::vector Multiply(int id, ValueId input_id, ValueId output_id, const MultiplyAttributes& attr, diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul.h b/tensorflow/lite/delegates/gpu/metal/kernels/mul.h index bc83b149e78..b5ff37cf560 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul.h @@ -30,11 +30,6 @@ std::vector Multiply(int id, ValueId input_id, ValueId output_id, const MultiplyAttributes& attr, const RuntimeOptions& options); - -std::vector ApplyMask(int id, ValueId input_id_0, - ValueId input_id_1, - ValueId output_id, - const RuntimeOptions& options); } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm index f69598bad5b..d881950c831 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm @@ -95,55 +95,4 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - -- (void)testApplyMaskChannel1 { - TensorRef input; - input.type = DataType::FLOAT32; - input.ref = 0; - input.shape = BHWC(1, 1, 2, 2); - - TensorRef mask; - mask.type = DataType::FLOAT32; - mask.ref = 1; - mask.shape = BHWC(1, 1, 2, 1); - - TensorRef output; - output.type = DataType::FLOAT32; - output.ref = 2; - output.shape = BHWC(1, 1, 2, 2); - - SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output}); - XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); - XCTAssertTrue(model.PopulateTensor(1, {2, 3})); - auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); - status = CompareVectors({2, 4, 9, 12}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); -} - -- (void)testApplyMaskEqualsToInputChannel { - TensorRef input; - input.type = DataType::FLOAT32; - input.ref = 0; - input.shape = BHWC(1, 1, 2, 2); - - TensorRef mask; - mask.type = DataType::FLOAT32; - mask.ref = 1; - mask.shape = BHWC(1, 1, 2, 2); - - TensorRef output; - output.type = DataType::FLOAT32; - output.ref = 2; - output.shape = BHWC(1, 1, 2, 2); - - SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output}); - XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); - XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4})); - 
auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); - status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); -} - @end From b5c725611842a2abb3d504947a7895341c231d76 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Mon, 23 Mar 2020 16:59:53 -0700 Subject: [PATCH 462/492] Internal cleanup: retire support for converting whole classes, which is not supported in TF2 and is unlikely to have real uses. PiperOrigin-RevId: 302550027 Change-Id: I0127d56a3cc45ea38205df9a88480a75d25ea39c --- .../python/autograph/impl/conversion.py | 153 ++---------------- .../python/autograph/impl/conversion_test.py | 53 ------ 2 files changed, 13 insertions(+), 193 deletions(-) diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 3e062f2eba0..e14c8e2bfcf 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -52,12 +52,11 @@ from tensorflow.python.autograph.core import naming from tensorflow.python.autograph.core import unsupported_features_checker from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.pyct import ast_util -from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import inspect_utils +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import pretty_printer -from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.autograph.pyct import templates from tensorflow.python.autograph.pyct import transformer from tensorflow.python.autograph.utils import ag_logging as logging @@ -299,15 +298,8 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names): """Creates a converted instance and binds it to match original entity.""" factory = converted_entity_info.get_factory() - # `factory` is currently bound to the empty module it was loaded from. - # It must instead be bound to the globals and closure from the original - # entity. - if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity): - entity_globals = entity.__globals__ - entity_closure = entity.__closure__ or () - elif hasattr(entity, '__module__'): - entity_globals = sys.modules[entity.__module__].__dict__ - entity_closure = () + entity_globals = entity.__globals__ + entity_closure = entity.__closure__ or () assert len(entity_closure) == len(free_nonglobal_var_names) # Fit the original entity's cells to match the order of factory's cells. @@ -328,11 +320,10 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names): ag_internal, converted_entity_info.source_map, converted_entity_info.get_module()) - if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity): - # Attach the default argument to the converted function. - converted_entity.__defaults__ = entity.__defaults__ - if hasattr(entity, '__kwdefaults__'): - converted_entity.__kwdefaults__ = entity.__kwdefaults__ + # Attach the default argument to the converted function. 
+ converted_entity.__defaults__ = entity.__defaults__ + if hasattr(entity, '__kwdefaults__'): + converted_entity.__kwdefaults__ = entity.__kwdefaults__ return converted_entity @@ -340,14 +331,11 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names): def convert(entity, program_ctx): """Converts an entity into an equivalent entity.""" - if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity): - if not hasattr(entity, '__code__'): - raise ValueError('Cannot apply autograph to a function that doesn\'t ' - 'expose a __code__ object. If this is a @tf.function,' - ' try passing f.python_function instead.') - free_nonglobal_var_names = entity.__code__.co_freevars - else: - free_nonglobal_var_names = () + if not hasattr(entity, '__code__'): + raise ValueError('Cannot apply autograph to a function that doesn\'t ' + 'expose a __code__ object. If this is a @tf.function,' + ' try passing f.python_function instead.') + free_nonglobal_var_names = entity.__code__.co_freevars for i, name in enumerate(free_nonglobal_var_names): if (name == 'ag__' and @@ -505,22 +493,7 @@ def convert_entity_to_ast(o, program_ctx): """ logging.log(1, 'Converting %s', o) - if tf_inspect.isclass(o): - nodes, name, entity_info = convert_class_to_ast(o, program_ctx) - elif tf_inspect.isfunction(o): - nodes, name, entity_info = convert_func_to_ast(o, program_ctx) - elif tf_inspect.ismethod(o): - nodes, name, entity_info = convert_func_to_ast(o, program_ctx) - elif hasattr(o, '__class__'): - # Note: this should only be raised when attempting to convert the object - # directly. converted_call should still support it. - raise NotImplementedError( - 'cannot convert entity "{}": object conversion is not yet' - ' supported.'.format(o)) - else: - raise NotImplementedError( - 'Entity "%s" has unsupported type "%s". Only functions and classes are ' - 'supported for now.' % (o, type(o))) + nodes, name, entity_info = convert_func_to_ast(o, program_ctx) if logging.has_verbosity(2): logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes)) @@ -532,106 +505,6 @@ def convert_entity_to_ast(o, program_ctx): return nodes, name, entity_info -def convert_class_to_ast(c, program_ctx): - """Specialization of `convert_entity_to_ast` for classes.""" - # TODO(mdan): Revisit this altogether. Not sure we still need it. - converted_members = {} - method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) - members = tf_inspect.getmembers(c, predicate=method_filter) - if not members: - raise ValueError('cannot convert %s: no member methods' % c) - - # TODO(mdan): Don't clobber namespaces for each method in one class namespace. - # The assumption that one namespace suffices for all methods only holds if - # all methods were defined in the same module. - # If, instead, functions are imported from multiple modules and then spliced - # into the class, then each function has its own globals and __future__ - # imports that need to stay separate. - - # For example, C's methods could both have `global x` statements referring to - # mod1.x and mod2.x, but using one namespace for C would cause a conflict. - # from mod1 import f1 - # from mod2 import f2 - # class C(object): - # method1 = f1 - # method2 = f2 - - class_namespace = {} - future_features = None - for _, m in members: - # Only convert the members that are directly defined by the class. 
- if inspect_utils.getdefiningclass(m, c) is not c: - continue - (node,), _, entity_info = convert_func_to_ast( - m, program_ctx=program_ctx, do_rename=False) - class_namespace.update(entity_info.namespace) - converted_members[m] = node - - # TODO(mdan): Similarly check the globals. - if future_features is None: - future_features = entity_info.future_features - elif frozenset(future_features) ^ frozenset(entity_info.future_features): - # Note: we can support this case if ever needed. - raise ValueError( - 'cannot convert {}: if has methods built with mismatched future' - ' features: {} and {}'.format(c, future_features, - entity_info.future_features)) - namer = naming.Namer(class_namespace) - class_name = namer.class_name(c.__name__) - - # Process any base classes: if the superclass if of a whitelisted type, an - # absolute import line is generated. - output_nodes = [] - renames = {} - base_names = [] - for base in c.__bases__: - if isinstance(object, base): - base_names.append('object') - continue - if is_whitelisted(base): - alias = namer.new_symbol(base.__name__, ()) - output_nodes.append( - gast.ImportFrom( - module=base.__module__, - names=[gast.alias(name=base.__name__, asname=alias)], - level=0)) - else: - raise NotImplementedError( - 'Conversion of classes that do not directly extend classes from' - ' whitelisted modules is temporarily suspended. If this breaks' - ' existing code please notify the AutoGraph team immediately.') - base_names.append(alias) - renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) - - # Generate the definition of the converted class. - bases = [ - gast.Name(n, ctx=gast.Load(), annotation=None, type_comment=None) - for n in base_names] - class_def = gast.ClassDef( - class_name, - bases=bases, - keywords=[], - body=list(converted_members.values()), - decorator_list=[]) - # Make a final pass to replace references to the class or its base classes. - # Most commonly, this occurs when making super().__init__() calls. - # TODO(mdan): Making direct references to superclass' superclass will fail. - class_def = qual_names.resolve(class_def) - renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) - class_def = ast_util.rename_symbols(class_def, renames) - - output_nodes.append(class_def) - - # TODO(mdan): Find a way better than forging this object. - entity_info = transformer.EntityInfo( - source_code=None, - source_file=None, - future_features=future_features, - namespace=class_namespace) - - return output_nodes, class_name, entity_info - - def _add_reserved_symbol(namespace, name, entity): if name not in namespace: namespace[name] = entity diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py index 2453a51993c..b0c1e45cc45 100644 --- a/tensorflow/python/autograph/impl/conversion_test.py +++ b/tensorflow/python/autograph/impl/conversion_test.py @@ -36,7 +36,6 @@ from tensorflow.python.autograph.impl.testing import pybind_for_testing from tensorflow.python.autograph.pyct import parser from tensorflow.python.eager import function from tensorflow.python.framework import constant_op -from tensorflow.python.keras.engine import training from tensorflow.python.platform import test @@ -127,11 +126,6 @@ class ConversionTest(test.TestCase): # Note: currently, native bindings are whitelisted by a separate check. 
self.assertFalse(conversion.is_whitelisted(test_object.method)) - def test_convert_entity_to_ast_unsupported_types(self): - with self.assertRaises(NotImplementedError): - program_ctx = self._simple_program_ctx() - conversion.convert_entity_to_ast('dummy', program_ctx) - def test_convert_entity_to_ast_callable(self): b = 2 @@ -174,53 +168,6 @@ class ConversionTest(test.TestCase): f_node, = nodes self.assertEqual('tf__f', f_node.name) - def test_convert_entity_to_ast_class_hierarchy(self): - - class TestBase(object): - - def __init__(self, x='base'): - self.x = x - - def foo(self): - return self.x - - def bar(self): - return self.x - - class TestSubclass(TestBase): - - def __init__(self, y): - super(TestSubclass, self).__init__('sub') - self.y = y - - def foo(self): - return self.y - - def baz(self): - return self.y - - program_ctx = self._simple_program_ctx() - with self.assertRaisesRegex(NotImplementedError, 'classes.*whitelisted'): - conversion.convert_entity_to_ast(TestSubclass, program_ctx) - - def test_convert_entity_to_ast_class_hierarchy_whitelisted(self): - - class TestSubclass(training.Model): - - def __init__(self, y): - super(TestSubclass, self).__init__() - self.built = False - - def call(self, x): - return 3 * x - - program_ctx = self._simple_program_ctx() - (import_node, class_node), name, _ = conversion.convert_entity_to_ast( - TestSubclass, program_ctx) - self.assertEqual(import_node.names[0].name, 'Model') - self.assertEqual(name, 'TfTestSubclass') - self.assertEqual(class_node.name, 'TfTestSubclass') - def test_convert_entity_to_ast_lambda(self): b = 2 f = lambda x: b * x if x > 0 else -x From f18fa5b6b0953e3abaf2b537a4aa0774bf160783 Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Mon, 23 Mar 2020 17:05:30 -0700 Subject: [PATCH 463/492] TFLite GPU: Replace tflite::gpu::Status with absl::Status. 
PiperOrigin-RevId: 302551400 Change-Id: Ib36038b364fda986c12543576471c07eef87db14 --- tensorflow/lite/delegates/gpu/BUILD | 6 + tensorflow/lite/delegates/gpu/api.h | 27 +- tensorflow/lite/delegates/gpu/cl/api.cc | 198 ++- tensorflow/lite/delegates/gpu/cl/api.h | 4 +- tensorflow/lite/delegates/gpu/cl/buffer.cc | 22 +- tensorflow/lite/delegates/gpu/cl/buffer.h | 28 +- .../lite/delegates/gpu/cl/cl_command_queue.cc | 101 +- .../lite/delegates/gpu/cl/cl_command_queue.h | 47 +- .../lite/delegates/gpu/cl/cl_context.cc | 23 +- tensorflow/lite/delegates/gpu/cl/cl_context.h | 9 +- tensorflow/lite/delegates/gpu/cl/cl_device.cc | 8 +- tensorflow/lite/delegates/gpu/cl/cl_device.h | 8 +- tensorflow/lite/delegates/gpu/cl/cl_errors.h | 7 +- tensorflow/lite/delegates/gpu/cl/cl_kernel.cc | 68 +- tensorflow/lite/delegates/gpu/cl/cl_kernel.h | 30 +- .../lite/delegates/gpu/cl/cl_program.cc | 54 +- tensorflow/lite/delegates/gpu/cl/cl_program.h | 18 +- tensorflow/lite/delegates/gpu/cl/egl_sync.cc | 17 +- tensorflow/lite/delegates/gpu/cl/egl_sync.h | 6 +- .../lite/delegates/gpu/cl/environment.cc | 24 +- .../lite/delegates/gpu/cl/environment.h | 4 +- .../lite/delegates/gpu/cl/gl_interop.cc | 74 +- tensorflow/lite/delegates/gpu/cl/gl_interop.h | 38 +- .../lite/delegates/gpu/cl/gpu_api_delegate.cc | 18 +- .../delegates/gpu/cl/inference_context.cc | 71 +- .../lite/delegates/gpu/cl/inference_context.h | 40 +- .../lite/delegates/gpu/cl/kernels/add.cc | 6 +- .../lite/delegates/gpu/cl/kernels/add.h | 4 +- .../lite/delegates/gpu/cl/kernels/cl_test.cc | 32 +- .../lite/delegates/gpu/cl/kernels/cl_test.h | 26 +- .../delegates/gpu/cl/kernels/concat_xy.cc | 10 +- .../lite/delegates/gpu/cl/kernels/concat_xy.h | 8 +- .../lite/delegates/gpu/cl/kernels/concat_z.cc | 13 +- .../lite/delegates/gpu/cl/kernels/concat_z.h | 8 +- .../lite/delegates/gpu/cl/kernels/conv_3d.cc | 20 +- .../lite/delegates/gpu/cl/kernels/conv_3d.h | 46 +- .../gpu/cl/kernels/conv_buffer_1x1.cc | 39 +- .../gpu/cl/kernels/conv_buffer_1x1.h | 77 +- .../gpu/cl/kernels/conv_constants.cc | 23 +- .../delegates/gpu/cl/kernels/conv_constants.h | 29 +- .../delegates/gpu/cl/kernels/conv_powervr.cc | 39 +- .../delegates/gpu/cl/kernels/conv_powervr.h | 83 +- .../delegates/gpu/cl/kernels/conv_texture.cc | 34 +- .../delegates/gpu/cl/kernels/conv_texture.h | 75 +- .../delegates/gpu/cl/kernels/converter.cc | 79 +- .../gpu/cl/kernels/convolution_transposed.cc | 22 +- .../gpu/cl/kernels/convolution_transposed.h | 25 +- .../cl/kernels/convolution_transposed_3d.cc | 15 +- .../cl/kernels/convolution_transposed_3d.h | 20 +- .../cl/kernels/convolution_transposed_3x3.cc | 21 +- .../cl/kernels/convolution_transposed_3x3.h | 16 +- .../convolution_transposed_3x3_thin.cc | 19 +- .../kernels/convolution_transposed_3x3_thin.h | 18 +- .../cl/kernels/convolution_transposed_4x4.cc | 18 +- .../cl/kernels/convolution_transposed_4x4.h | 16 +- .../cl/kernels/convolution_transposed_thin.cc | 16 +- .../cl/kernels/convolution_transposed_thin.h | 18 +- .../gpu/cl/kernels/depth_wise_conv.cc | 21 +- .../gpu/cl/kernels/depth_wise_conv.h | 25 +- .../gpu/cl/kernels/depth_wise_conv_3d.cc | 14 +- .../gpu/cl/kernels/depth_wise_conv_3d.h | 20 +- .../gpu/cl/kernels/depth_wise_conv_3x3.cc | 28 +- .../gpu/cl/kernels/depth_wise_conv_3x3.h | 27 +- .../delegates/gpu/cl/kernels/elementwise.cc | 4 +- .../delegates/gpu/cl/kernels/elementwise.h | 2 +- .../gpu/cl/kernels/fully_connected.cc | 17 +- .../gpu/cl/kernels/fully_connected.h | 25 +- .../delegates/gpu/cl/kernels/gpu_operation.cc | 17 +- 
.../delegates/gpu/cl/kernels/gpu_operation.h | 28 +- .../lite/delegates/gpu/cl/kernels/lstm.cc | 11 +- .../lite/delegates/gpu/cl/kernels/lstm.h | 8 +- .../delegates/gpu/cl/kernels/max_unpooling.cc | 22 +- .../delegates/gpu/cl/kernels/max_unpooling.h | 16 +- .../lite/delegates/gpu/cl/kernels/mean.cc | 8 +- .../lite/delegates/gpu/cl/kernels/mean.h | 6 +- .../delegates/gpu/cl/kernels/multiply_add.cc | 48 +- .../delegates/gpu/cl/kernels/multiply_add.h | 80 +- .../lite/delegates/gpu/cl/kernels/padding.cc | 10 +- .../lite/delegates/gpu/cl/kernels/padding.h | 8 +- .../lite/delegates/gpu/cl/kernels/pooling.cc | 24 +- .../lite/delegates/gpu/cl/kernels/pooling.h | 16 +- .../lite/delegates/gpu/cl/kernels/prelu.cc | 14 +- .../lite/delegates/gpu/cl/kernels/prelu.h | 22 +- .../gpu/cl/kernels/quantize_and_dequantize.cc | 14 +- .../gpu/cl/kernels/quantize_and_dequantize.h | 19 +- .../lite/delegates/gpu/cl/kernels/relu.cc | 4 +- .../lite/delegates/gpu/cl/kernels/relu.h | 2 +- .../lite/delegates/gpu/cl/kernels/reshape.cc | 11 +- .../lite/delegates/gpu/cl/kernels/reshape.h | 8 +- .../delegates/gpu/cl/kernels/reshapex4.cc | 11 +- .../lite/delegates/gpu/cl/kernels/reshapex4.h | 8 +- .../lite/delegates/gpu/cl/kernels/resize.cc | 20 +- .../lite/delegates/gpu/cl/kernels/resize.h | 16 +- .../lite/delegates/gpu/cl/kernels/softmax.cc | 10 +- .../lite/delegates/gpu/cl/kernels/softmax.h | 8 +- .../delegates/gpu/cl/kernels/softmax1x1.cc | 4 +- .../delegates/gpu/cl/kernels/softmax1x1.h | 4 +- .../gpu/cl/kernels/space_to_depth.cc | 8 +- .../delegates/gpu/cl/kernels/space_to_depth.h | 8 +- .../delegates/gpu/cl/kernels/strided_slice.cc | 10 +- .../delegates/gpu/cl/kernels/strided_slice.h | 8 +- .../delegates/gpu/cl/kernels/transpose.cc | 11 +- .../lite/delegates/gpu/cl/kernels/transpose.h | 8 +- .../lite/delegates/gpu/cl/kernels/winograd.cc | 45 +- .../lite/delegates/gpu/cl/kernels/winograd.h | 38 +- .../gpu/cl/kernels/work_group_picking.cc | 49 +- .../gpu/cl/kernels/work_group_picking.h | 24 +- .../lite/delegates/gpu/cl/linear_storage.cc | 22 +- .../lite/delegates/gpu/cl/linear_storage.h | 36 +- .../lite/delegates/gpu/cl/opencl_wrapper.cc | 8 +- .../lite/delegates/gpu/cl/opencl_wrapper.h | 2 +- .../lite/delegates/gpu/cl/program_cache.cc | 33 +- .../lite/delegates/gpu/cl/program_cache.h | 19 +- .../gpu/cl/selectors/convolution_selector.cc | 93 +- .../gpu/cl/selectors/convolution_selector.h | 20 +- .../convolution_transposed_selector.cc | 21 +- .../convolution_transposed_selector.h | 8 +- .../cl/selectors/default/default_selector.cc | 13 +- .../gpu/cl/selectors/default_selector.h | 11 +- .../cl/selectors/dw_convolution_selector.cc | 38 +- .../cl/selectors/dw_convolution_selector.h | 8 +- .../cl/selectors/fully_connected_selector.cc | 40 +- .../cl/selectors/fully_connected_selector.h | 8 +- .../gpu/cl/selectors/operation_selector.cc | 61 +- .../gpu/cl/selectors/operation_selector.h | 11 +- .../gpu/cl/selectors/simple_selectors.cc | 83 +- .../gpu/cl/selectors/simple_selectors.h | 59 +- .../delegates/gpu/cl/storage_type_util.cc | 1 - tensorflow/lite/delegates/gpu/cl/tensor.cc | 151 +-- tensorflow/lite/delegates/gpu/cl/tensor.h | 62 +- .../lite/delegates/gpu/cl/tensor_test.cc | 25 +- .../gpu/cl/testing/performance_profiling.cc | 19 +- tensorflow/lite/delegates/gpu/cl/texture2d.cc | 26 +- tensorflow/lite/delegates/gpu/cl/texture2d.h | 36 +- tensorflow/lite/delegates/gpu/common/BUILD | 5 +- .../lite/delegates/gpu/common/convert.cc | 90 +- .../lite/delegates/gpu/common/convert.h | 24 +- .../delegates/gpu/common/custom_parsers.cc | 
8 +- .../delegates/gpu/common/custom_parsers.h | 6 +- .../delegates/gpu/common/memory_management.cc | 35 +- .../delegates/gpu/common/memory_management.h | 21 +- .../memory_management/equality_assignment.h | 8 +- .../greedy_by_breadth_assignment.cc | 6 +- .../greedy_by_breadth_assignment.h | 2 +- .../greedy_by_size_assignment.cc | 14 +- .../greedy_by_size_assignment.h | 4 +- .../greedy_in_order_assignment.h | 10 +- .../min_cost_flow_assignment.cc | 4 +- .../min_cost_flow_assignment.h | 2 +- .../memory_management/naive_assignment.h | 4 +- tensorflow/lite/delegates/gpu/common/model.h | 117 +- .../delegates/gpu/common/model_builder.cc | 1096 ++++++++--------- .../lite/delegates/gpu/common/model_builder.h | 16 +- .../lite/delegates/gpu/common/operations.cc | 15 +- .../lite/delegates/gpu/common/operations.h | 5 +- tensorflow/lite/delegates/gpu/common/status.h | 108 +- .../gpu/common/testing/interpreter_utils.cc | 35 +- .../gpu/common/testing/interpreter_utils.h | 16 +- .../transformations/add_quant_adjustments.cc | 2 +- .../transformations/fuse_add_to_conv.cc | 8 +- .../transformations/fuse_mul_to_conv.cc | 8 +- .../common/transformations/make_padding.cc | 6 +- .../match_dilated_convolution.cc | 2 +- .../transformations/merge_padding_with.cc | 8 +- .../gpu/common/transformations/remove_noop.cc | 8 +- .../gpu/common/workgroup_selection.cc | 13 +- .../gpu/common/workgroup_selection.h | 7 +- tensorflow/lite/delegates/gpu/delegate.cc | 35 +- tensorflow/lite/delegates/gpu/gl/api.cc | 90 +- tensorflow/lite/delegates/gpu/gl/api.h | 24 +- tensorflow/lite/delegates/gpu/gl/api2.cc | 197 ++- tensorflow/lite/delegates/gpu/gl/api2.h | 4 +- .../lite/delegates/gpu/gl/command_queue.cc | 18 +- .../lite/delegates/gpu/gl/command_queue.h | 8 +- tensorflow/lite/delegates/gpu/gl/compiler.cc | 18 +- tensorflow/lite/delegates/gpu/gl/compiler.h | 8 +- .../gpu/gl/compiler/compiled_node.cc | 6 +- .../delegates/gpu/gl/compiler/compiled_node.h | 4 +- .../delegates/gpu/gl/compiler/preprocessor.cc | 16 +- .../delegates/gpu/gl/compiler/preprocessor.h | 2 +- .../lite/delegates/gpu/gl/compiler/rename.cc | 8 +- .../lite/delegates/gpu/gl/compiler/rename.h | 2 +- .../gpu/gl/compiler/shader_codegen.cc | 17 +- .../gpu/gl/compiler/shader_codegen.h | 3 +- .../gpu/gl/converters/bhwc_to_phwc4.cc | 18 +- .../gpu/gl/converters/bhwc_to_phwc4.h | 8 +- .../gpu/gl/converters/bhwc_to_phwc4_test.cc | 6 +- .../gpu/gl/converters/phwc4_to_bhwc.cc | 18 +- .../gpu/gl/converters/phwc4_to_bhwc.h | 8 +- .../gpu/gl/converters/phwc4_to_bhwc_test.cc | 6 +- .../lite/delegates/gpu/gl/egl_context.cc | 42 +- .../lite/delegates/gpu/gl/egl_context.h | 18 +- .../lite/delegates/gpu/gl/egl_environment.cc | 35 +- .../lite/delegates/gpu/gl/egl_environment.h | 10 +- .../lite/delegates/gpu/gl/egl_surface.cc | 11 +- .../lite/delegates/gpu/gl/egl_surface.h | 6 +- tensorflow/lite/delegates/gpu/gl/gl_buffer.cc | 24 +- tensorflow/lite/delegates/gpu/gl/gl_buffer.h | 68 +- .../lite/delegates/gpu/gl/gl_buffer_test.cc | 2 +- tensorflow/lite/delegates/gpu/gl/gl_call.h | 27 +- tensorflow/lite/delegates/gpu/gl/gl_errors.cc | 42 +- tensorflow/lite/delegates/gpu/gl/gl_errors.h | 4 +- .../lite/delegates/gpu/gl/gl_program.cc | 58 +- tensorflow/lite/delegates/gpu/gl/gl_program.h | 13 +- tensorflow/lite/delegates/gpu/gl/gl_shader.cc | 12 +- tensorflow/lite/delegates/gpu/gl/gl_shader.h | 6 +- tensorflow/lite/delegates/gpu/gl/gl_sync.cc | 22 +- tensorflow/lite/delegates/gpu/gl/gl_sync.h | 12 +- .../lite/delegates/gpu/gl/gl_texture.cc | 82 +- tensorflow/lite/delegates/gpu/gl/gl_texture.h | 
50 +- .../lite/delegates/gpu/gl/kernels/add.cc | 12 +- .../lite/delegates/gpu/gl/kernels/concat.cc | 35 +- .../lite/delegates/gpu/gl/kernels/conv.cc | 20 +- .../delegates/gpu/gl/kernels/converter.cc | 86 +- .../gpu/gl/kernels/converter_test.cc | 12 +- .../gpu/gl/kernels/depthwise_conv.cc | 6 +- .../delegates/gpu/gl/kernels/elementwise.cc | 21 +- .../gpu/gl/kernels/fully_connected.cc | 6 +- .../lite/delegates/gpu/gl/kernels/lstm.cc | 6 +- .../delegates/gpu/gl/kernels/max_unpooling.cc | 6 +- .../lite/delegates/gpu/gl/kernels/mean.cc | 8 +- .../lite/delegates/gpu/gl/kernels/mul.cc | 18 +- .../lite/delegates/gpu/gl/kernels/pad.cc | 12 +- .../lite/delegates/gpu/gl/kernels/pooling.cc | 24 +- .../lite/delegates/gpu/gl/kernels/prelu.cc | 25 +- .../gpu/gl/kernels/quantize_and_dequantize.cc | 6 +- .../lite/delegates/gpu/gl/kernels/registry.cc | 10 +- .../lite/delegates/gpu/gl/kernels/relu.cc | 6 +- .../lite/delegates/gpu/gl/kernels/reshape.cc | 10 +- .../lite/delegates/gpu/gl/kernels/resize.cc | 18 +- .../lite/delegates/gpu/gl/kernels/slice.cc | 6 +- .../lite/delegates/gpu/gl/kernels/softmax.cc | 22 +- .../gpu/gl/kernels/space_to_depth.cc | 6 +- .../delegates/gpu/gl/kernels/test_util.cc | 10 +- .../lite/delegates/gpu/gl/kernels/test_util.h | 8 +- .../gpu/gl/kernels/transpose_conv.cc | 12 +- .../lite/delegates/gpu/gl/node_shader.h | 4 +- .../lite/delegates/gpu/gl/object_manager.cc | 19 +- .../lite/delegates/gpu/gl/object_manager.h | 14 +- .../lite/delegates/gpu/gl/request_gpu_info.cc | 4 +- .../lite/delegates/gpu/gl/request_gpu_info.h | 2 +- tensorflow/lite/delegates/gpu/gl/runtime.cc | 128 +- tensorflow/lite/delegates/gpu/gl/runtime.h | 20 +- .../delegates/gpu/gl/runtime/shared_buffer.h | 4 +- .../lite/delegates/gpu/gl/serialization.cc | 54 +- .../lite/delegates/gpu/gl/serialization.h | 16 +- .../delegates/gpu/gl/serialization_test.cc | 15 +- tensorflow/lite/delegates/gpu/gl_delegate.cc | 42 +- tensorflow/lite/delegates/gpu/metal/api.cc | 39 +- tensorflow/lite/delegates/gpu/metal/api.h | 4 +- tensorflow/lite/delegates/gpu/metal/common.h | 7 +- tensorflow/lite/delegates/gpu/metal/common.mm | 14 +- .../lite/delegates/gpu/metal/common_test.mm | 8 +- .../delegates/gpu/metal/compiled_model.cc | 12 +- .../lite/delegates/gpu/metal/compiled_model.h | 7 +- .../gpu/metal/compiled_model_test.mm | 22 +- .../lite/delegates/gpu/metal/compute_task.h | 19 +- .../lite/delegates/gpu/metal/compute_task.mm | 39 +- .../delegates/gpu/metal/inference_context.h | 14 +- .../delegates/gpu/metal/inference_context.mm | 27 +- .../gpu/metal/inference_context_test.mm | 18 +- .../delegates/gpu/metal/kernels/add_test.mm | 13 +- .../gpu/metal/kernels/concat_test.mm | 17 +- .../delegates/gpu/metal/kernels/conv_test.mm | 21 +- .../gpu/metal/kernels/custom_registry.cc | 12 +- .../gpu/metal/kernels/custom_registry.h | 10 +- .../gpu/metal/kernels/depthwise_conv_test.mm | 13 +- .../gpu/metal/kernels/elementwise_test.mm | 77 +- .../gpu/metal/kernels/fully_connected_test.mm | 5 +- .../gpu/metal/kernels/max_unpooling_test.mm | 5 +- .../delegates/gpu/metal/kernels/mean_test.mm | 5 +- .../delegates/gpu/metal/kernels/mul_test.mm | 9 +- .../gpu/metal/kernels/padding_test.mm | 13 +- .../gpu/metal/kernels/pooling_test.mm | 15 +- .../delegates/gpu/metal/kernels/prelu_test.mm | 17 +- .../delegates/gpu/metal/kernels/relu_test.mm | 17 +- .../gpu/metal/kernels/reshape_test.mm | 16 +- .../gpu/metal/kernels/resize_test.mm | 25 +- .../delegates/gpu/metal/kernels/slice_test.mm | 25 +- .../gpu/metal/kernels/softmax_test.mm | 13 +- 
.../gpu/metal/kernels/space_to_depth_test.mm | 8 +- .../delegates/gpu/metal/kernels/test_util.h | 14 +- .../delegates/gpu/metal/kernels/test_util.mm | 22 +- .../gpu/metal/kernels/transpose_conv_test.mm | 25 +- .../lite/delegates/gpu/metal_delegate.mm | 38 +- tensorflow/lite/delegates/gpu/spi.h | 12 +- 286 files changed, 3877 insertions(+), 3922 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index b5fff1d84d5..72af2534988 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -12,6 +12,12 @@ exports_files([ "metal_delegate.h", ]) +# Primary purpose of this config is to replace ::util::Status with our custom +# light implementation ::tflite::gpu::StatusLite to reduce binary size. Besides +# that, certain features that were hard to communicate without full open source +# were hidden away too such as compiled models, serialization, and metadata. +# While the latter will be fully available with the open source release, the +# former will have to stay until absl::Status is released. config_setting( name = "tflite_gpu_binary_release", values = {"copt": "-DTFLITE_GPU_BINARY_RELEASE"}, diff --git a/tensorflow/lite/delegates/gpu/api.h b/tensorflow/lite/delegates/gpu/api.h index 921f2d54006..803983214e2 100644 --- a/tensorflow/lite/delegates/gpu/api.h +++ b/tensorflow/lite/delegates/gpu/api.h @@ -220,8 +220,7 @@ class InferenceBuilder { // Sets new shape for the input if underlying implementation and graph // structure allows dynamic tensors. - virtual absl::Status SetInputShape(int index, - const Dimensions& dimensions) = 0; + virtual Status SetInputShape(int index, const Dimensions& dimensions) = 0; // Updates object definitions for the given index. Implementation may allow // to use different layouts and/or data type conversions between objects @@ -230,21 +229,21 @@ class InferenceBuilder { // A user, however, has an input in DataType::FLOAT16, DataLayout::PHWC4. // An implementation may allow this transformation to happen automatically // under the hood. - virtual absl::Status SetInputObjectDef(int index, ObjectDef def) = 0; - virtual absl::Status SetOutputObjectDef(int index, ObjectDef def) = 0; - virtual absl::Status SetAllInputObjectDefsTo(ObjectDef def) { + virtual Status SetInputObjectDef(int index, ObjectDef def) = 0; + virtual Status SetOutputObjectDef(int index, ObjectDef def) = 0; + virtual Status SetAllInputObjectDefsTo(ObjectDef def) { auto input_defs = inputs(); for (int i = 0; i < input_defs.size(); ++i) { RETURN_IF_ERROR(SetInputObjectDef(i, def)); } - return absl::OkStatus(); + return OkStatus(); } - virtual absl::Status SetAllOutputObjectDefsTo(ObjectDef def) { + virtual Status SetAllOutputObjectDefsTo(ObjectDef def) { auto output_defs = outputs(); for (int i = 0; i < output_defs.size(); ++i) { RETURN_IF_ERROR(SetOutputObjectDef(i, def)); } - return absl::OkStatus(); + return OkStatus(); } // Creates new instance of the inference runner. InferenceBuilder stays valid @@ -252,7 +251,7 @@ class InferenceBuilder { // // This method may take significant time to prepare new inference runner. For // example, it may require to compile OpenGL shaders. - virtual absl::Status Build(std::unique_ptr* runner) = 0; + virtual Status Build(std::unique_ptr* runner) = 0; }; // Runs prepared inference. Every object marked as external needs to be set @@ -269,12 +268,12 @@ class InferenceRunner { // Setters allow to set or change external object for the given index. 
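
For orientation, a minimal usage sketch of the builder/runner pair declared above, written against the un-prefixed Status spelling this change uses; RunOnce is a hypothetical helper and tensor index 0 is assumed to exist in the model, neither is taken from the patch.

// Hypothetical helper, for illustration only: build a runner from a prepared
// InferenceBuilder, bind one input and one output object, and execute once.
Status RunOnce(InferenceBuilder* builder, TensorObject input,
               TensorObject output) {
  std::unique_ptr<InferenceRunner> runner;
  RETURN_IF_ERROR(builder->Build(&runner));
  RETURN_IF_ERROR(runner->SetInputObject(0, input));
  RETURN_IF_ERROR(runner->SetOutputObject(0, output));
  return runner->Run();
}
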
Note, // object need to match object definition set before in InferenceBuilder. - virtual absl::Status GetInputObject(int index, TensorObject* object) = 0; - virtual absl::Status GetOutputObject(int index, TensorObject* object) = 0; - virtual absl::Status SetInputObject(int index, TensorObject object) = 0; - virtual absl::Status SetOutputObject(int index, TensorObject object) = 0; + virtual Status GetInputObject(int index, TensorObject* object) = 0; + virtual Status GetOutputObject(int index, TensorObject* object) = 0; + virtual Status SetInputObject(int index, TensorObject object) = 0; + virtual Status SetOutputObject(int index, TensorObject object) = 0; - virtual absl::Status Run() = 0; + virtual Status Run() = 0; }; // Encapsulated compilation/runtime tradeoffs. diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index a6488c51ce4..4e85f92c6de 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -54,22 +54,22 @@ class NoopTensorTie : public TensorTie { return def.external_def == def.internal_def; } - absl::Status SetExternalObject(TensorObject obj) final { + Status SetExternalObject(TensorObject obj) final { if (!def().external_def.object_def.user_provided) { - return absl::InvalidArgumentError("Tensor object is readonly."); + return InvalidArgumentError("Tensor object is readonly."); } if (!IsValid(def().external_def, obj)) { - return absl::InvalidArgumentError("Given object is not valid"); + return InvalidArgumentError("Given object is not valid"); } obj_ = obj; - return absl::OkStatus(); + return OkStatus(); } TensorObject GetExternalObject() final { return obj_; } - absl::Status CopyToExternalObject() final { return absl::OkStatus(); } + Status CopyToExternalObject() final { return OkStatus(); } - absl::Status CopyFromExternalObject() final { return absl::OkStatus(); } + Status CopyFromExternalObject() final { return OkStatus(); } private: TensorObject obj_; @@ -93,45 +93,45 @@ class DefaultTensorTie : public TensorTie { converter_builder.IsSupported(def.external_def, def.internal_def); } - static absl::Status New(const TensorTieDef& def, TensorObject internal_object, - TensorObjectConverterBuilder* converter_builder, - Environment* env, std::unique_ptr* tie) { + static Status New(const TensorTieDef& def, TensorObject internal_object, + TensorObjectConverterBuilder* converter_builder, + Environment* env, std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def, internal_object); RETURN_IF_ERROR(tie_impl->Init(converter_builder, env)); *tie = std::move(tie_impl); - return absl::OkStatus(); + return OkStatus(); } - absl::Status CopyToExternalObject() final { + Status CopyToExternalObject() final { if (!converter_to_) { - return absl::UnavailableError("Conversion is not available"); + return UnavailableError("Conversion is not available"); } return converter_to_->Convert(internal_obj_, GetExternalObject()); } - absl::Status CopyFromExternalObject() final { + Status CopyFromExternalObject() final { if (!converter_from_) { - return absl::UnavailableError("Conversion is not available"); + return UnavailableError("Conversion is not available"); } return converter_from_->Convert(GetExternalObject(), internal_obj_); } - absl::Status SetExternalObject(TensorObject obj) final { + Status SetExternalObject(TensorObject obj) final { if (!def().external_def.object_def.user_provided) { - return absl::InvalidArgumentError("External object is read-only"); + return InvalidArgumentError("External object is 
read-only"); } if (!IsValid(def().external_def, obj)) { - return absl::InvalidArgumentError("Given object is not valid"); + return InvalidArgumentError("Given object is not valid"); } external_obj_ = obj; - return absl::OkStatus(); + return OkStatus(); } TensorObject GetExternalObject() final { return external_obj_; } private: - absl::Status Init(TensorObjectConverterBuilder* converter_builder, - Environment* env) { + Status Init(TensorObjectConverterBuilder* converter_builder, + Environment* env) { RETURN_IF_ERROR(converter_builder->MakeConverter( def().internal_def, def().external_def, &converter_to_)); RETURN_IF_ERROR(converter_builder->MakeConverter( @@ -139,10 +139,10 @@ class DefaultTensorTie : public TensorTie { return MaybeAllocateExternalObject(env); } - absl::Status MaybeAllocateExternalObject(Environment* env) { + Status MaybeAllocateExternalObject(Environment* env) { const TensorObjectDef& d = def().external_def; if (d.object_def.user_provided) { - return absl::OkStatus(); + return OkStatus(); } switch (d.object_def.object_type) { case ObjectType::CPU_MEMORY: { @@ -170,9 +170,9 @@ class DefaultTensorTie : public TensorTie { break; } default: - return absl::InternalError("Unexpected object type"); + return InternalError("Unexpected object type"); } - return absl::OkStatus(); + return OkStatus(); } const TensorObject internal_obj_; @@ -198,26 +198,26 @@ class TwoStepTensorTie : public TensorTie { DefaultTensorTie::IsSupported(defs.second, converter_builder); } - static absl::Status New(const TensorTieDef& def, TensorObject internal_object, - TensorObjectConverterBuilder* converter_builder, - Environment* env, std::unique_ptr* tie) { + static Status New(const TensorTieDef& def, TensorObject internal_object, + TensorObjectConverterBuilder* converter_builder, + Environment* env, std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def); RETURN_IF_ERROR(tie_impl->Init(internal_object, converter_builder, env)); *tie = std::move(tie_impl); - return absl::OkStatus(); + return OkStatus(); } - absl::Status CopyToExternalObject() final { + Status CopyToExternalObject() final { RETURN_IF_ERROR(inner_tie_->CopyToExternalObject()); return outer_tie_->CopyToExternalObject(); } - absl::Status CopyFromExternalObject() final { + Status CopyFromExternalObject() final { RETURN_IF_ERROR(outer_tie_->CopyFromExternalObject()); return inner_tie_->CopyFromExternalObject(); } - absl::Status SetExternalObject(TensorObject obj) final { + Status SetExternalObject(TensorObject obj) final { return outer_tie_->SetExternalObject(obj); } @@ -241,9 +241,9 @@ class TwoStepTensorTie : public TensorTie { return std::make_pair(outer_def, inner_def); } - absl::Status Init(TensorObject internal_object, - TensorObjectConverterBuilder* converter_builder, - Environment* env) { + Status Init(TensorObject internal_object, + TensorObjectConverterBuilder* converter_builder, + Environment* env) { auto defs = MakeOuterInnerDefs(def()); RETURN_IF_ERROR(DefaultTensorTie::New(defs.second, internal_object, converter_builder, env, &inner_tie_)); @@ -274,27 +274,27 @@ class GlBufferHolder : public TensorTie { return DefaultTensorTie::IsSupported(MakeClDef(def), converter_builder); } - static absl::Status New(const TensorTieDef& def, TensorObject internal_object, - TensorObjectConverterBuilder* converter_builder, - GlInteropFabric* gl_interop_fabric, Environment* env, - std::unique_ptr* tie) { + static Status New(const TensorTieDef& def, TensorObject internal_object, + TensorObjectConverterBuilder* converter_builder, + 
GlInteropFabric* gl_interop_fabric, Environment* env, + std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def, gl_interop_fabric, env); RETURN_IF_ERROR(DefaultTensorTie::New(MakeClDef(def), internal_object, converter_builder, env, &tie_impl->tie_)); *tie = std::move(tie_impl); - return absl::OkStatus(); + return OkStatus(); } - absl::Status SetExternalObject(TensorObject obj) final { + Status SetExternalObject(TensorObject obj) final { auto ssbo = absl::get_if(&obj); if (!ssbo) { - return absl::InvalidArgumentError("Missing OpenGL SSBO"); + return InvalidArgumentError("Missing OpenGL SSBO"); } auto old_ssbo = absl::get_if(&external_obj_); if (old_ssbo && ssbo->id == old_ssbo->id) { - return absl::OkStatus(); + return OkStatus(); } if (cl_object_.memory()) { gl_interop_fabric_->UnregisterMemory(cl_object_.memory()); @@ -304,18 +304,16 @@ class GlBufferHolder : public TensorTie { external_obj_ = obj; RETURN_IF_ERROR(tie_->SetExternalObject(OpenClBuffer{cl_object_.memory()})); gl_interop_fabric_->RegisterMemory(cl_object_.memory()); - return absl::OkStatus(); + return OkStatus(); } TensorObject GetExternalObject() final { return external_obj_; } - absl::Status CopyFromExternalObject() final { + Status CopyFromExternalObject() final { return tie_->CopyFromExternalObject(); } - absl::Status CopyToExternalObject() final { - return tie_->CopyToExternalObject(); - } + Status CopyToExternalObject() final { return tie_->CopyToExternalObject(); } private: static TensorTieDef MakeClDef(const TensorTieDef& def) { @@ -360,20 +358,20 @@ class TensorTieFactory { TwoStepTensorTie::IsSupported(def, *converter_builder_)); } - absl::Status NewTensorTie(const TensorTieDef& def, - std::unique_ptr* tie) { + Status NewTensorTie(const TensorTieDef& def, + std::unique_ptr* tie) { TensorObject internal_object = TensorToObj(*context_.GetTensor(def.id)); auto converter = converter_builder_.get(); if (NoopTensorTie::IsSupported(def)) { *tie = absl::make_unique(def, internal_object); - return absl::OkStatus(); + return OkStatus(); } if (DefaultTensorTie::IsSupported(def, *converter)) { return DefaultTensorTie::New(def, internal_object, converter, &env_, tie); } if (GlBufferHolder::IsSupported(def, *converter)) { if (!gl_interop_fabric_) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "GL object is used but InferenceEnvironmentOptions does not have " "EGL display and context set."); } @@ -383,7 +381,7 @@ class TensorTieFactory { if (TwoStepTensorTie::IsSupported(def, *converter)) { return TwoStepTensorTie::New(def, internal_object, converter, &env_, tie); } - return absl::UnimplementedError("Unsupported tensor tie definition."); + return UnimplementedError("Unsupported tensor tie definition."); } private: @@ -402,9 +400,9 @@ class InferenceRunnerImpl : public InferenceRunner { context_(std::move(context)), gl_interop_fabric_(std::move(gl_interop_fabric)) {} - absl::Status Initialize(const std::vector& inputs, - const std::vector& outputs, - TensorTieFactory* factory) { + Status Initialize(const std::vector& inputs, + const std::vector& outputs, + TensorTieFactory* factory) { RETURN_IF_ERROR(LinkTensors(inputs, factory, &inputs_)); return LinkTensors(outputs, factory, &outputs_); } @@ -417,37 +415,37 @@ class InferenceRunnerImpl : public InferenceRunner { return GetExternalDefinitions(outputs_); } - absl::Status GetInputObject(int index, TensorObject* object) override { + Status GetInputObject(int index, TensorObject* object) override { if (index < 0 || index >= inputs_.size()) { - return 
absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } *object = inputs_[index]->GetExternalObject(); - return absl::OkStatus(); + return OkStatus(); } - absl::Status GetOutputObject(int index, TensorObject* object) override { + Status GetOutputObject(int index, TensorObject* object) override { if (index < 0 || index >= outputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } *object = outputs_[index]->GetExternalObject(); - return absl::OkStatus(); + return OkStatus(); } - absl::Status SetInputObject(int index, TensorObject object) override { + Status SetInputObject(int index, TensorObject object) override { if (index < 0 || index >= inputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } return inputs_[index]->SetExternalObject(object); } - absl::Status SetOutputObject(int index, TensorObject object) override { + Status SetOutputObject(int index, TensorObject object) override { if (index < 0 || index >= outputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } return outputs_[index]->SetExternalObject(object); } - absl::Status Run() override { + Status Run() override { if (gl_interop_fabric_) { RETURN_IF_ERROR(gl_interop_fabric_->Start()); } @@ -462,20 +460,20 @@ class InferenceRunnerImpl : public InferenceRunner { if (gl_interop_fabric_) { RETURN_IF_ERROR(gl_interop_fabric_->Finish()); } - return absl::OkStatus(); + return OkStatus(); } private: - static absl::Status LinkTensors( - const std::vector& defs, TensorTieFactory* factory, - std::vector>* objects) { + static Status LinkTensors(const std::vector& defs, + TensorTieFactory* factory, + std::vector>* objects) { objects->reserve(defs.size()); for (auto& def : defs) { std::unique_ptr object; RETURN_IF_ERROR(factory->NewTensorTie(def, &object)); objects->push_back(std::move(object)); } - return absl::OkStatus(); + return OkStatus(); } static std::vector GetExternalDefinitions( @@ -513,9 +511,9 @@ class InferenceBuilderImpl : public InferenceBuilder { explicit InferenceBuilderImpl(Environment* environment) : environment_(environment) {} - absl::Status Initialize(const InferenceOptions& options, - const InferenceEnvironmentOptions& env_options, - const GraphFloat32& graph) { + Status Initialize(const InferenceOptions& options, + const InferenceEnvironmentOptions& env_options, + const GraphFloat32& graph) { context_ = absl::make_unique(); InferenceContext::CreateInferenceInfo create_info; create_info.precision = GetPrecision(options); @@ -535,7 +533,7 @@ class InferenceBuilderImpl : public InferenceBuilder { inputs_ = LinkTensors(graph, graph.inputs()); outputs_ = LinkTensors(graph, graph.outputs()); - return absl::OkStatus(); + return OkStatus(); } std::vector inputs() const override { @@ -546,42 +544,40 @@ class InferenceBuilderImpl : public InferenceBuilder { return GetExternalDefinitions(outputs_); } - absl::Status SetInputShape(int index, const Dimensions& dimensions) override { + Status SetInputShape(int index, const Dimensions& dimensions) override { if (index < 0 || index >= inputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } - return absl::UnimplementedError("Changing input shapes is not supported"); + return UnimplementedError("Changing input shapes is not supported"); } - absl::Status 
SetInputObjectDef(int index, ObjectDef new_def) override { + Status SetInputObjectDef(int index, ObjectDef new_def) override { if (index < 0 || index >= inputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } auto def = inputs_[index]; def.external_def.object_def = new_def; if (!tie_factory_->IsSupported(def)) { - return absl::InvalidArgumentError( - "New object definition is not supported."); + return InvalidArgumentError("New object definition is not supported."); } inputs_[index] = def; - return absl::OkStatus(); + return OkStatus(); } - absl::Status SetOutputObjectDef(int index, ObjectDef new_def) override { + Status SetOutputObjectDef(int index, ObjectDef new_def) override { if (index < 0 || index >= outputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } auto def = outputs_[index]; def.external_def.object_def = new_def; if (!tie_factory_->IsSupported(def)) { - return absl::InvalidArgumentError( - "New object definition is not supported."); + return InvalidArgumentError("New object definition is not supported."); } outputs_[index] = def; - return absl::OkStatus(); + return OkStatus(); } - absl::Status Build(std::unique_ptr* runner) override { + Status Build(std::unique_ptr* runner) override { if (gl_interop_fabric_ && !HasGlObjects()) { // destroy interop layer when there are no GL objects to avoid // extra synchronization cost. @@ -592,7 +588,7 @@ class InferenceBuilderImpl : public InferenceBuilder { RETURN_IF_ERROR( runner_impl->Initialize(inputs_, outputs_, tie_factory_.get())); *runner = std::move(runner_impl); - return absl::OkStatus(); + return OkStatus(); } private: @@ -700,7 +696,7 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { explicit InferenceEnvironmentImpl(const InferenceEnvironmentOptions& options) : options_(options) {} - absl::Status Init() { + Status Init() { RETURN_IF_ERROR(LoadOpenCL()); properties_.is_opencl_available = true; @@ -720,13 +716,13 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { properties_.is_cl_to_gl_fast_sync_supported = IsEglSyncFromClEventSupported(); if (options_.IsGlAware() && !properties_.is_gl_sharing_supported) { - return absl::UnavailableError("GL sharing is not supported"); + return UnavailableError("GL sharing is not supported"); } CLContext context; if (options_.context) { if (options_.IsGlAware()) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "OpenCL context and EGL parameters are set in the same time."); } context = CLContext(options_.context, /* has_ownership = */ false); @@ -758,11 +754,11 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { return environment_.Init(); } - absl::Status NewInferenceBuilder( - const InferenceOptions& options, GraphFloat32 model, - std::unique_ptr* builder) final { + Status NewInferenceBuilder(const InferenceOptions& options, + GraphFloat32 model, + std::unique_ptr* builder) final { if (!IsValid(options)) { - return absl::InvalidArgumentError("InferenceOptions are invalid."); + return InvalidArgumentError("InferenceOptions are invalid."); } InferenceOptions resolved_options = options; ResolveAutoPriority(&resolved_options); @@ -780,7 +776,7 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { RETURN_IF_ERROR( builder_impl->Initialize(resolved_options, options_, model)); *builder = std::move(builder_impl); - return absl::OkStatus(); + return OkStatus(); } std::vector 
GetSerializedBinaryCache() const final { @@ -804,18 +800,18 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { } // namespace -absl::Status NewInferenceEnvironment( +Status NewInferenceEnvironment( const InferenceEnvironmentOptions& options, std::unique_ptr* environment, InferenceEnvironmentProperties* properties) { auto env_impl = absl::make_unique(options); - absl::Status status = env_impl->Init(); + Status status = env_impl->Init(); if (properties) { *properties = env_impl->properties(); } RETURN_IF_ERROR(status); *environment = std::move(env_impl); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/api.h b/tensorflow/lite/delegates/gpu/cl/api.h index 9d3f9f7214c..2ac5ce2e28b 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.h +++ b/tensorflow/lite/delegates/gpu/cl/api.h @@ -70,7 +70,7 @@ class InferenceEnvironment { public: virtual ~InferenceEnvironment() {} - virtual absl::Status NewInferenceBuilder( + virtual Status NewInferenceBuilder( const InferenceOptions& options, GraphFloat32 model, std::unique_ptr* builder) = 0; @@ -112,7 +112,7 @@ struct InferenceEnvironmentOptions { // Creates new OpenCL environment that needs to stay around until all inference // runners are destroyed. -absl::Status NewInferenceEnvironment( +Status NewInferenceEnvironment( const InferenceEnvironmentOptions& options, std::unique_ptr* environment, InferenceEnvironmentProperties* properties /* optional */); diff --git a/tensorflow/lite/delegates/gpu/cl/buffer.cc b/tensorflow/lite/delegates/gpu/cl/buffer.cc index 207cdec5122..51d9a59e888 100644 --- a/tensorflow/lite/delegates/gpu/cl/buffer.cc +++ b/tensorflow/lite/delegates/gpu/cl/buffer.cc @@ -21,10 +21,8 @@ namespace tflite { namespace gpu { namespace cl { namespace { - -absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, - const void* data, CLContext* context, - Buffer* result) { +Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, const void* data, + CLContext* context, Buffer* result) { cl_mem_flags flags = gpu_read_only ? 
CL_MEM_READ_ONLY : CL_MEM_READ_WRITE; if (data != nullptr) { flags |= CL_MEM_COPY_HOST_PTR; @@ -33,14 +31,14 @@ absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, cl_mem buffer = clCreateBuffer(context->context(), flags, size_in_bytes, const_cast(data), &error_code); if (!buffer) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to allocate device memory with clCreateBuffer", CLErrorCodeToString(error_code))); } *result = Buffer(buffer, size_in_bytes); - return absl::OkStatus(); + return OkStatus(); } } // namespace @@ -71,18 +69,18 @@ void Buffer::Release() { } } -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context, - Buffer* result) { +Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context, + Buffer* result) { return CreateBuffer(size_in_bytes, true, nullptr, context, result); } -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, - CLContext* context, Buffer* result) { +Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, + CLContext* context, Buffer* result) { return CreateBuffer(size_in_bytes, true, data, context, result); } -absl::Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context, - Buffer* result) { +Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context, + Buffer* result) { return CreateBuffer(size_in_bytes, false, nullptr, context, result); } diff --git a/tensorflow/lite/delegates/gpu/cl/buffer.h b/tensorflow/lite/delegates/gpu/cl/buffer.h index 84c3292084b..4282d9c0898 100644 --- a/tensorflow/lite/delegates/gpu/cl/buffer.h +++ b/tensorflow/lite/delegates/gpu/cl/buffer.h @@ -51,11 +51,11 @@ class Buffer { // Writes data to a buffer. Data should point to a region that // has exact size in bytes as size_in_bytes(constructor parameter). template - absl::Status WriteData(CLCommandQueue* queue, const absl::Span data); + Status WriteData(CLCommandQueue* queue, const absl::Span data); // Reads data from Buffer into CPU memory. 
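
As a usage sketch of the Buffer helpers declared here (not part of the patch): RoundTripFloats is a hypothetical function that uploads host data into a read-only device buffer and reads it back through the same command queue.

// Hypothetical helper, for illustration only: upload four floats into a
// read-only device buffer and read them back via ReadData.
Status RoundTripFloats(CLContext* context, CLCommandQueue* queue) {
  const std::vector<float> host = {1.0f, 2.0f, 3.0f, 4.0f};
  Buffer buffer;
  RETURN_IF_ERROR(CreateReadOnlyBuffer(host.size() * sizeof(float),
                                       host.data(), context, &buffer));
  std::vector<float> readback;
  RETURN_IF_ERROR(buffer.ReadData(queue, &readback));
  // readback now holds the same four values that were uploaded.
  return OkStatus();
}
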
template - absl::Status ReadData(CLCommandQueue* queue, std::vector* result) const; + Status ReadData(CLCommandQueue* queue, std::vector* result) const; private: void Release(); @@ -64,31 +64,29 @@ class Buffer { size_t size_; }; -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context, - Buffer* result); +Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context, + Buffer* result); -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, - CLContext* context, Buffer* result); +Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, + CLContext* context, Buffer* result); -absl::Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context, - Buffer* result); +Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context, + Buffer* result); template -absl::Status Buffer::WriteData(CLCommandQueue* queue, - const absl::Span data) { +Status Buffer::WriteData(CLCommandQueue* queue, const absl::Span data) { if (size_ != sizeof(T) * data.size()) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "absl::Span data size is different from buffer allocated size."); } RETURN_IF_ERROR(queue->EnqueueWriteBuffer(buffer_, size_, data.data())); - return absl::OkStatus(); + return OkStatus(); } template -absl::Status Buffer::ReadData(CLCommandQueue* queue, - std::vector* result) const { +Status Buffer::ReadData(CLCommandQueue* queue, std::vector* result) const { if (size_ % sizeof(T) != 0) { - return absl::UnknownError("Wrong element size(typename T is not correct?"); + return UnknownError("Wrong element size(typename T is not correct?"); } const int elements_count = size_ / sizeof(T); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc index 7b74840c5e6..328cdaf0a6e 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc @@ -56,9 +56,8 @@ void CLCommandQueue::Release() { } } -absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size, - CLEvent* event) { +Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size, CLEvent* event) { std::vector local(3); std::vector global(3); for (int i = 0; i < 3; ++i) { @@ -73,31 +72,30 @@ absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, *event = CLEvent(resulting_event); } if (error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to clEnqueueNDRangeKernel - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to clEnqueueNDRangeKernel - ", + CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size) { +Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size) { return DispatchImplicit(kernel, grid, work_group_size, nullptr); } -absl::Status CLCommandQueue::EnqueueEvent(CLEvent* event) { +Status CLCommandQueue::EnqueueEvent(CLEvent* event) { cl_event resulting_event; const int error_code = clEnqueueMarker(queue_, &resulting_event); *event = CLEvent(resulting_event); if (error_code != CL_SUCCESS) { - return absl::UnknownError(absl::StrCat("Failed to clEnqueueMarker - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to clEnqueueMarker - ", + 
CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region, - const void* data) { +Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region, + const void* data) { const size_t origin[] = {0, 0, 0}; const size_t r[] = {static_cast(region.x), static_cast(region.y), @@ -105,16 +103,16 @@ absl::Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region, auto error_code = clEnqueueWriteImage(queue_, memory, CL_TRUE, origin, r, 0, 0, data, 0, nullptr, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to upload data to GPU (clEnqueueWriteImage) - ", CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region, - void* data) { +Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region, + void* data) { const size_t origin[] = {0, 0, 0}; const size_t r[] = {static_cast(region.x), static_cast(region.y), @@ -122,47 +120,45 @@ absl::Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region, auto error_code = clEnqueueReadImage(queue_, memory, CL_TRUE, origin, r, 0, 0, data, 0, nullptr, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to read data from GPU (clEnqueueReadImage) - ", CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CLCommandQueue::EnqueueWriteBuffer(cl_mem memory, - size_t size_in_bytes, - const void* data) { +Status CLCommandQueue::EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, + const void* data) { auto error_code = clEnqueueWriteBuffer( queue_, memory, CL_TRUE, 0, size_in_bytes, data, 0, nullptr, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to upload data to GPU (clEnqueueWriteBuffer) - ", CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CLCommandQueue::EnqueueReadBuffer(cl_mem memory, - size_t size_in_bytes, - void* data) { +Status CLCommandQueue::EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, + void* data) { auto error_code = clEnqueueReadBuffer( queue_, memory, CL_TRUE, 0, size_in_bytes, data, 0, nullptr, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to read data from GPU (clEnqueueReadBuffer) - ", CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CLCommandQueue::WaitForCompletion() { +Status CLCommandQueue::WaitForCompletion() { auto error_code = clFinish(queue_); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to clFinish - ", CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } ProfilingCommandQueue::ProfilingCommandQueue(cl_command_queue queue) @@ -191,14 +187,14 @@ void ProfilingCommandQueue::SetEventsLabel(const std::string& name) { void ProfilingCommandQueue::ResetMeasurements() { events_.clear(); } -absl::Status ProfilingCommandQueue::DispatchImplicit(const CLKernel& kernel, - int3 grid, - int3 work_group_size) { +Status ProfilingCommandQueue::DispatchImplicit(const CLKernel& kernel, + int3 grid, + int3 work_group_size) { events_.push_back(CLEvent()); RETURN_IF_ERROR(CLCommandQueue::DispatchImplicit( kernel, grid, 
work_group_size, &events_[events_.size() - 1])); events_.back().SetName(current_label_); - return absl::OkStatus(); + return OkStatus(); } ProfilingInfo ProfilingCommandQueue::GetProfilingInfo() const { @@ -212,7 +208,7 @@ ProfilingInfo ProfilingCommandQueue::GetProfilingInfo() const { return result; } -absl::Status ProfilingCommandQueue::GetBestWorkGroupIndex( +Status ProfilingCommandQueue::GetBestWorkGroupIndex( const CLKernel& kernel, const DeviceInfo& device_info, const int3& grid, const std::vector& work_group_sizes, int* index) { // Some Adreno 3xx can have wrong numbers for some events @@ -272,22 +268,20 @@ absl::Status ProfilingCommandQueue::GetBestWorkGroupIndex( *index = minimum_index; - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateCLCommandQueue(const CLDevice& device, - const CLContext& context, - CLCommandQueue* result) { +Status CreateCLCommandQueue(const CLDevice& device, const CLContext& context, + CLCommandQueue* result) { int error_code; cl_command_queue queue = clCreateCommandQueue(context.context(), device.id(), 0, &error_code); if (!queue) { - return absl::UnknownError( - absl::StrCat("Failed to create a command queue - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to create a command queue - ", + CLErrorCodeToString(error_code))); } *result = CLCommandQueue(queue, true); - return absl::OkStatus(); + return OkStatus(); } double ProfilingCommandQueue::GetQueueExecutionTimeMs() const { @@ -306,20 +300,19 @@ double ProfilingCommandQueue::GetSumOfEventsTimeMs() const { return sum; } -absl::Status CreateProfilingCommandQueue(const CLDevice& device, - const CLContext& context, - ProfilingCommandQueue* result) { +Status CreateProfilingCommandQueue(const CLDevice& device, + const CLContext& context, + ProfilingCommandQueue* result) { int error_code; cl_command_queue queue = clCreateCommandQueue( context.context(), device.id(), CL_QUEUE_PROFILING_ENABLE, &error_code); if (!queue) { - return absl::UnknownError( - absl::StrCat("Failed to create a command queue - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to create a command queue - ", + CLErrorCodeToString(error_code))); } *result = ProfilingCommandQueue(queue); - return absl::OkStatus(); + return OkStatus(); } absl::Duration ProfilingInfo::GetTotalTime() const { diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h index 178e3b21a1e..84ffeca67eb 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h @@ -74,23 +74,22 @@ class CLCommandQueue { cl_command_queue queue() const { return queue_; } - virtual absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size); + virtual Status DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size); - absl::Status EnqueueEvent(CLEvent* event); + Status EnqueueEvent(CLEvent* event); - absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size, CLEvent* event); + Status DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size, CLEvent* event); - absl::Status EnqueueWriteImage(cl_mem memory, int3 region, const void* data); - absl::Status EnqueueReadImage(cl_mem memory, int3 region, void* data); + Status EnqueueWriteImage(cl_mem memory, int3 region, const void* data); + Status EnqueueReadImage(cl_mem memory, int3 region, void* data); - absl::Status EnqueueWriteBuffer(cl_mem 
memory, size_t size_in_bytes, - const void* data); - absl::Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, - void* data); + Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, + const void* data); + Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, void* data); - absl::Status WaitForCompletion(); + Status WaitForCompletion(); protected: void Release(); @@ -110,15 +109,14 @@ class ProfilingCommandQueue : public CLCommandQueue { ProfilingCommandQueue(const ProfilingCommandQueue&) = delete; ProfilingCommandQueue& operator=(const ProfilingCommandQueue&) = delete; - absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size) override; + Status DispatchImplicit(const CLKernel& kernel, int3 grid, + int3 work_group_size) override; // will write index for fastest work_group among work_group_sizes - absl::Status GetBestWorkGroupIndex(const CLKernel& kernel, - const DeviceInfo& device_info, - const int3& grid, - const std::vector& work_group_sizes, - int* index); + Status GetBestWorkGroupIndex(const CLKernel& kernel, + const DeviceInfo& device_info, const int3& grid, + const std::vector& work_group_sizes, + int* index); // call ResetMeasurements() to start new seriese of measurements void ResetMeasurements(); @@ -141,13 +139,12 @@ class ProfilingCommandQueue : public CLCommandQueue { std::string current_label_; }; -absl::Status CreateCLCommandQueue(const CLDevice& device, - const CLContext& context, - CLCommandQueue* result); +Status CreateCLCommandQueue(const CLDevice& device, const CLContext& context, + CLCommandQueue* result); -absl::Status CreateProfilingCommandQueue(const CLDevice& device, - const CLContext& context, - ProfilingCommandQueue* result); +Status CreateProfilingCommandQueue(const CLDevice& device, + const CLContext& context, + ProfilingCommandQueue* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/cl_context.cc b/tensorflow/lite/delegates/gpu/cl/cl_context.cc index e697c78b692..e9e0ddf724b 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_context.cc @@ -43,21 +43,19 @@ std::vector GetSupportedImage2DFormats(cl_context context, return result; } -absl::Status CreateCLContext(const CLDevice& device, - cl_context_properties* properties, - CLContext* result) { +Status CreateCLContext(const CLDevice& device, + cl_context_properties* properties, CLContext* result) { int error_code; cl_device_id device_id = device.id(); cl_context context = clCreateContext(properties, 1, &device_id, nullptr, nullptr, &error_code); if (!context) { - return absl::UnknownError( - absl::StrCat("Failed to create a compute context - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to create a compute context - ", + CLErrorCodeToString(error_code))); } *result = CLContext(context, true); - return absl::OkStatus(); + return OkStatus(); } } // namespace @@ -101,16 +99,15 @@ bool CLContext::IsFloatTexture2DSupported(int num_channels, DataType data_type, return false; } -absl::Status CreateCLContext(const CLDevice& device, CLContext* result) { +Status CreateCLContext(const CLDevice& device, CLContext* result) { return CreateCLContext(device, nullptr, result); } -absl::Status CreateCLGLContext(const CLDevice& device, - cl_context_properties egl_context, - cl_context_properties egl_display, - CLContext* result) { +Status CreateCLGLContext(const CLDevice& device, + cl_context_properties egl_context, + cl_context_properties 
egl_display, CLContext* result) { if (!device.SupportsExtension("cl_khr_gl_sharing")) { - return absl::UnavailableError("Device doesn't support CL-GL sharing."); + return UnavailableError("Device doesn't support CL-GL sharing."); } cl_context_properties platform = reinterpret_cast(device.platform()); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_context.h b/tensorflow/lite/delegates/gpu/cl/cl_context.h index 11922bd3678..20ec35f2b60 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_context.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_context.h @@ -51,11 +51,10 @@ class CLContext { bool has_ownership_ = false; }; -absl::Status CreateCLContext(const CLDevice& device, CLContext* result); -absl::Status CreateCLGLContext(const CLDevice& device, - cl_context_properties egl_context, - cl_context_properties egl_display, - CLContext* result); +Status CreateCLContext(const CLDevice& device, CLContext* result); +Status CreateCLGLContext(const CLDevice& device, + cl_context_properties egl_context, + cl_context_properties egl_display, CLContext* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.cc b/tensorflow/lite/delegates/gpu/cl/cl_device.cc index 5380c9ee653..c47f86a2928 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_device.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_device.cc @@ -516,11 +516,11 @@ void CLDevice::DisableOneLayerTextureArray() { info_.adreno_info.support_one_layer_texture_array = false; } -absl::Status CreateDefaultGPUDevice(CLDevice* result) { +Status CreateDefaultGPUDevice(CLDevice* result) { cl_uint num_platforms; clGetPlatformIDs(0, nullptr, &num_platforms); if (num_platforms == 0) { - return absl::UnknownError("No supported OpenCL platform."); + return UnknownError("No supported OpenCL platform."); } std::vector platforms(num_platforms); clGetPlatformIDs(num_platforms, platforms.data(), nullptr); @@ -529,7 +529,7 @@ absl::Status CreateDefaultGPUDevice(CLDevice* result) { cl_uint num_devices; clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices); if (num_devices == 0) { - return absl::UnknownError("No GPU on current platform."); + return UnknownError("No GPU on current platform."); } std::vector devices(num_devices); @@ -537,7 +537,7 @@ absl::Status CreateDefaultGPUDevice(CLDevice* result) { nullptr); *result = CLDevice(devices[0], platform_id); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.h b/tensorflow/lite/delegates/gpu/cl/cl_device.h index cbc95d485b9..7b3493e3faa 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_device.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_device.h @@ -191,7 +191,7 @@ class CLDevice { DeviceInfo info_; }; -absl::Status CreateDefaultGPUDevice(CLDevice* result); +Status CreateDefaultGPUDevice(CLDevice* result); template T GetDeviceInfo(cl_device_id id, cl_device_info info) { @@ -204,12 +204,12 @@ T GetDeviceInfo(cl_device_id id, cl_device_info info) { } template -absl::Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) { +Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) { cl_int error = clGetDeviceInfo(id, info, sizeof(T), result, nullptr); if (error != CL_SUCCESS) { - return absl::InvalidArgumentError(CLErrorCodeToString(error)); + return InvalidArgumentError(CLErrorCodeToString(error)); } - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/cl_errors.h b/tensorflow/lite/delegates/gpu/cl/cl_errors.h 
index fb59766bd18..8c16b2696d7 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_errors.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_errors.h @@ -27,12 +27,11 @@ namespace cl { // @return if error_code is success, then return OK status. Otherwise translates // error code into a message. -inline absl::Status GetOpenCLError(cl_int error_code) { +inline Status GetOpenCLError(cl_int error_code) { if (error_code == CL_SUCCESS) { - return absl::OkStatus(); + return OkStatus(); } - return absl::InternalError("OpenCL error: " + - CLErrorCodeToString(error_code)); + return InternalError("OpenCL error: " + CLErrorCodeToString(error_code)); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc b/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc index 04bf95d870a..27d4d36c68a 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc @@ -25,34 +25,34 @@ namespace gpu { namespace cl { namespace { -absl::Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id, - int* result) { +Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id, + int* result) { size_t max_work_group_size; cl_int error_code = clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE, sizeof(size_t), &max_work_group_size, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to get info CL_KERNEL_WORK_GROUP_SIZE ", CLErrorCodeToString(error_code))); } *result = static_cast(max_work_group_size); - return absl::OkStatus(); + return OkStatus(); } -absl::Status GetKernelPrivateMemorySize(cl_kernel kernel, - cl_device_id device_id, int* result) { +Status GetKernelPrivateMemorySize(cl_kernel kernel, cl_device_id device_id, + int* result) { cl_ulong private_mem_size; cl_int error_code = clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_PRIVATE_MEM_SIZE, sizeof(cl_ulong), &private_mem_size, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to get info CL_KERNEL_PRIVATE_MEM_SIZE ", CLErrorCodeToString(error_code))); } *result = static_cast(private_mem_size); - return absl::OkStatus(); + return OkStatus(); } } // namespace @@ -82,17 +82,17 @@ CLKernel& CLKernel::operator=(CLKernel&& kernel) { CLKernel::~CLKernel() { Release(); } -absl::Status CLKernel::ReInit() const { +Status CLKernel::ReInit() const { clReleaseKernel(kernel_); cl_kernel* kern_ptr = const_cast(&kernel_); int error_code; *kern_ptr = clCreateKernel(program_, function_name_.c_str(), &error_code); if (!kernel_ || error_code != CL_SUCCESS) { *kern_ptr = nullptr; - return absl::UnknownError(absl::StrCat("Failed to create ", function_name_, - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to create ", function_name_, + CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } void CLKernel::Release() { @@ -103,16 +103,16 @@ void CLKernel::Release() { } } -absl::Status CLKernel::CreateFromProgram(const CLProgram& program, - const std::string& function_name) { +Status CLKernel::CreateFromProgram(const CLProgram& program, + const std::string& function_name) { int error_code; function_name_ = function_name; kernel_ = clCreateKernel(program.program(), function_name.c_str(), &error_code); if (!kernel_ || error_code != CL_SUCCESS) { kernel_ = nullptr; - return absl::UnknownError(absl::StrCat("Failed to create ", function_name, - CLErrorCodeToString(error_code))); + return 
UnknownError(absl::StrCat("Failed to create ", function_name, + CLErrorCodeToString(error_code))); } program_ = program.program(); @@ -122,64 +122,64 @@ absl::Status CLKernel::CreateFromProgram(const CLProgram& program, &private_memory_size_)); RETURN_IF_ERROR(GetKernelMaxWorkGroupSize(kernel_, program.GetDeviceId(), &max_work_group_size_)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CLKernel::SetMemory(int index, cl_mem memory) { +Status CLKernel::SetMemory(int index, cl_mem memory) { return SetBytes(index, &memory, sizeof(cl_mem)); } -absl::Status CLKernel::SetMemoryAuto(cl_mem memory) { +Status CLKernel::SetMemoryAuto(cl_mem memory) { return SetBytesAuto(&memory, sizeof(cl_mem)); } -absl::Status CLKernel::SetBytes(int index, const void* ptr, int length) const { +Status CLKernel::SetBytes(int index, const void* ptr, int length) const { const int error_code = clSetKernelArg(kernel_, index, length, ptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to set kernel arguments - ", + CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CLKernel::SetBytesAuto(const void* ptr, int length) { +Status CLKernel::SetBytesAuto(const void* ptr, int length) { const int error_code = clSetKernelArg(kernel_, binding_counter_, length, ptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError(absl::StrCat( - "Failed to set kernel arguments - ", CLErrorCodeToString(error_code), - "(at index - ", binding_counter_, ")")); + return UnknownError(absl::StrCat("Failed to set kernel arguments - ", + CLErrorCodeToString(error_code), + "(at index - ", binding_counter_, ")")); } binding_counter_++; - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status CLKernel::SetBytes(int index, const FLT& value) const { +Status CLKernel::SetBytes(int index, const FLT& value) const { return SetBytes(index, value.GetData(), value.GetSize()); } template <> -absl::Status CLKernel::SetBytes(int index, const FLT2& value) const { +Status CLKernel::SetBytes(int index, const FLT2& value) const { return SetBytes(index, value.GetData(), value.GetSize()); } template <> -absl::Status CLKernel::SetBytes(int index, const FLT4& value) const { +Status CLKernel::SetBytes(int index, const FLT4& value) const { return SetBytes(index, value.GetData(), value.GetSize()); } template <> -absl::Status CLKernel::SetBytesAuto(const FLT& value) { +Status CLKernel::SetBytesAuto(const FLT& value) { return SetBytesAuto(value.GetData(), value.GetSize()); } template <> -absl::Status CLKernel::SetBytesAuto(const FLT2& value) { +Status CLKernel::SetBytesAuto(const FLT2& value) { return SetBytesAuto(value.GetData(), value.GetSize()); } template <> -absl::Status CLKernel::SetBytesAuto(const FLT4& value) { +Status CLKernel::SetBytesAuto(const FLT4& value) { return SetBytesAuto(value.GetData(), value.GetSize()); } diff --git a/tensorflow/lite/delegates/gpu/cl/cl_kernel.h b/tensorflow/lite/delegates/gpu/cl/cl_kernel.h index b575684d2b4..3b63e43c967 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_kernel.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_kernel.h @@ -48,17 +48,17 @@ class CLKernel { cl_kernel kernel() const { return kernel_; } - absl::Status CreateFromProgram(const CLProgram& program, - const std::string& function_name); + Status CreateFromProgram(const CLProgram& program, + const std::string& function_name); - absl::Status 
SetMemory(int index, cl_mem memory); - absl::Status SetMemoryAuto(cl_mem memory); + Status SetMemory(int index, cl_mem memory); + Status SetMemoryAuto(cl_mem memory); template - absl::Status SetBytes(int index, const T& value) const { + Status SetBytes(int index, const T& value) const { return SetBytes(index, static_cast(&value), sizeof(T)); } template - absl::Status SetBytesAuto(const T& value) { + Status SetBytesAuto(const T& value) { return SetBytesAuto(static_cast(&value), sizeof(T)); } @@ -69,12 +69,12 @@ class CLKernel { // Do not use this function // workaround for Mali memory leak - absl::Status ReInit() const; + Status ReInit() const; private: void Release(); - absl::Status SetBytes(int index, const void* ptr, int length) const; - absl::Status SetBytesAuto(const void* ptr, int length); + Status SetBytes(int index, const void* ptr, int length) const; + Status SetBytesAuto(const void* ptr, int length); int private_memory_size_; int max_work_group_size_; @@ -87,22 +87,22 @@ class CLKernel { }; template <> -absl::Status CLKernel::SetBytes(int index, const FLT& value) const; +Status CLKernel::SetBytes(int index, const FLT& value) const; template <> -absl::Status CLKernel::SetBytes(int index, const FLT2& value) const; +Status CLKernel::SetBytes(int index, const FLT2& value) const; template <> -absl::Status CLKernel::SetBytes(int index, const FLT4& value) const; +Status CLKernel::SetBytes(int index, const FLT4& value) const; template <> -absl::Status CLKernel::SetBytesAuto(const FLT& value); +Status CLKernel::SetBytesAuto(const FLT& value); template <> -absl::Status CLKernel::SetBytesAuto(const FLT2& value); +Status CLKernel::SetBytesAuto(const FLT2& value); template <> -absl::Status CLKernel::SetBytesAuto(const FLT4& value); +Status CLKernel::SetBytesAuto(const FLT4& value); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/cl_program.cc b/tensorflow/lite/delegates/gpu/cl/cl_program.cc index 690bc598777..3592ad895ea 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_program.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_program.cc @@ -49,29 +49,28 @@ std::string GetProgramBuildInfo(cl_program program, cl_device_id id, return result; } -absl::Status GetBinarySize(cl_program program, size_t* binary_size) { +Status GetBinarySize(cl_program program, size_t* binary_size) { cl_int error_code = clGetProgramInfo(program, CL_PROGRAM_BINARY_SIZES, sizeof(size_t), binary_size, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to get program binary size - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to get program binary size - ", + CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status BuildProgram(cl_program program, const CLDevice& device, - const std::string& compiler_options) { +Status BuildProgram(cl_program program, const CLDevice& device, + const std::string& compiler_options) { const int error_code = clBuildProgram( program, 0, nullptr, compiler_options.c_str(), nullptr, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError(absl::StrCat( + return UnknownError(absl::StrCat( "Failed to build program executable - ", CLErrorCodeToString(error_code), GetProgramBuildInfo(program, device.id(), CL_PROGRAM_BUILD_LOG))); } - return absl::OkStatus(); + return OkStatus(); } std::string CompilerOptionToString(const CLDevice& device, @@ -134,7 +133,7 @@ void CLProgram::Release() { } } -absl::Status CLProgram::GetBinary(std::vector* 
result) const { +Status CLProgram::GetBinary(std::vector* result) const { size_t binary_size; RETURN_IF_ERROR(GetBinarySize(program_, &binary_size)); result->resize(result->size() + binary_size); @@ -142,36 +141,35 @@ absl::Status CLProgram::GetBinary(std::vector* result) const { cl_int error_code = clGetProgramInfo(program_, CL_PROGRAM_BINARIES, binary_size, &binary_ptr, nullptr); if (error_code != CL_SUCCESS) { - return absl::UnknownError(absl::StrCat("Failed to get program binary - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to get program binary - ", + CLErrorCodeToString(error_code))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateCLProgram(const std::string& code, - const std::string& compiler_options, - const CLContext& context, const CLDevice& device, - CLProgram* result) { +Status CreateCLProgram(const std::string& code, + const std::string& compiler_options, + const CLContext& context, const CLDevice& device, + CLProgram* result) { int error_code; const char* source = code.c_str(); cl_program program = clCreateProgramWithSource(context.context(), 1, &source, nullptr, &error_code); if (!program || error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to create compute program - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to create compute program - ", + CLErrorCodeToString(error_code))); } *result = CLProgram(program, device.id()); RETURN_IF_ERROR(BuildProgram(program, device, compiler_options)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateCLProgramFromBinary(const CLContext& context, - const CLDevice& device, - absl::Span binary, - CLProgram* result) { +Status CreateCLProgramFromBinary(const CLContext& context, + const CLDevice& device, + absl::Span binary, + CLProgram* result) { cl_int binary_status; cl_int error_code; cl_device_id devices_list[] = {device.id()}; @@ -181,13 +179,13 @@ absl::Status CreateCLProgramFromBinary(const CLContext& context, context.context(), 1, devices_list, &binary_size, &binary_pointer, &binary_status, &error_code); if (binary_status != CL_SUCCESS) { - return absl::UnknownError(absl::StrCat( + return UnknownError(absl::StrCat( "Something wrong with binary after clCreateProgramWithBinary - ", binary_status)); } if (error_code != CL_SUCCESS) { - return absl::UnknownError(absl::StrCat("Failed to create program - ", - CLErrorCodeToString(error_code))); + return UnknownError(absl::StrCat("Failed to create program - ", + CLErrorCodeToString(error_code))); } *result = CLProgram(program, device.id()); return BuildProgram(program, device, ""); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_program.h b/tensorflow/lite/delegates/gpu/cl/cl_program.h index fb2a7edb9c1..b6deb3beb95 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_program.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_program.h @@ -68,7 +68,7 @@ class CLProgram { // was created using clCreateProgramWithBinary. 
cl_device_id GetDeviceId() const { return device_id_; } - absl::Status GetBinary(std::vector* result) const; + Status GetBinary(std::vector* result) const; private: void Release(); @@ -79,15 +79,15 @@ class CLProgram { cl_device_id device_id_ = nullptr; }; -absl::Status CreateCLProgram(const std::string& code, - const std::string& compiler_options, - const CLContext& context, const CLDevice& device, - CLProgram* result); +Status CreateCLProgram(const std::string& code, + const std::string& compiler_options, + const CLContext& context, const CLDevice& device, + CLProgram* result); -absl::Status CreateCLProgramFromBinary(const CLContext& context, - const CLDevice& device, - absl::Span binary, - CLProgram* result); +Status CreateCLProgramFromBinary(const CLContext& context, + const CLDevice& device, + absl::Span binary, + CLProgram* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/egl_sync.cc b/tensorflow/lite/delegates/gpu/cl/egl_sync.cc index ddc373bce31..8493fbb049f 100644 --- a/tensorflow/lite/delegates/gpu/cl/egl_sync.cc +++ b/tensorflow/lite/delegates/gpu/cl/egl_sync.cc @@ -21,15 +21,15 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status EglSync::NewFence(EGLDisplay display, EglSync* sync) { +Status EglSync::NewFence(EGLDisplay display, EglSync* sync) { EGLSyncKHR egl_sync; RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglCreateSyncKHR, &egl_sync, display, EGL_SYNC_FENCE_KHR, nullptr)); if (egl_sync == EGL_NO_SYNC_KHR) { - return absl::InternalError("Returned empty KHR EGL sync"); + return InternalError("Returned empty KHR EGL sync"); } *sync = EglSync(display, egl_sync); - return absl::OkStatus(); + return OkStatus(); } EglSync& EglSync::operator=(EglSync&& sync) { @@ -48,23 +48,22 @@ void EglSync::Invalidate() { } } -absl::Status EglSync::ServerWait() { +Status EglSync::ServerWait() { EGLint result; RETURN_IF_ERROR( TFLITE_GPU_CALL_EGL(eglWaitSyncKHR, &result, display_, sync_, 0)); - return result == EGL_TRUE ? absl::OkStatus() - : absl::InternalError("eglWaitSync failed"); + return result == EGL_TRUE ? OkStatus() : InternalError("eglWaitSync failed"); } -absl::Status EglSync::ClientWait() { +Status EglSync::ClientWait() { EGLint result; // TODO(akulik): make it active wait for better performance RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglClientWaitSyncKHR, &result, display_, sync_, EGL_SYNC_FLUSH_COMMANDS_BIT_KHR, EGL_FOREVER_KHR)); return result == EGL_CONDITION_SATISFIED_KHR - ? absl::OkStatus() - : absl::InternalError("eglClientWaitSync failed"); + ? OkStatus() + : InternalError("eglClientWaitSync failed"); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/egl_sync.h b/tensorflow/lite/delegates/gpu/cl/egl_sync.h index d0943a797ee..27a551c5d59 100644 --- a/tensorflow/lite/delegates/gpu/cl/egl_sync.h +++ b/tensorflow/lite/delegates/gpu/cl/egl_sync.h @@ -32,7 +32,7 @@ class EglSync { // flushed. // // Depends on EGL_KHR_fence_sync extension. - static absl::Status NewFence(EGLDisplay display, EglSync* sync); + static Status NewFence(EGLDisplay display, EglSync* sync); // Creates invalid object. EglSync() : EglSync(EGL_NO_DISPLAY, EGL_NO_SYNC_KHR) {} @@ -50,10 +50,10 @@ class EglSync { // Causes GPU to block and wait until this sync has been signaled. // This call does not block and returns immediately. - absl::Status ServerWait(); + Status ServerWait(); // Causes CPU to block and wait until this sync has been signaled. 
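
A minimal sketch of the EglSync fence flow declared here (illustrative only, not from the patch); WaitForPendingGlWork is a hypothetical helper and assumes a valid EGLDisplay with the EGL_KHR_fence_sync extension noted above.

// Hypothetical helper, for illustration only: fence the GL commands issued so
// far and block the calling thread until the GPU reaches the fence.
Status WaitForPendingGlWork(EGLDisplay display) {
  EglSync sync;
  RETURN_IF_ERROR(EglSync::NewFence(display, &sync));
  return sync.ClientWait();
}
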
- absl::Status ClientWait(); + Status ClientWait(); // Returns the EGLDisplay on which this instance was created. EGLDisplay display() const { return display_; } diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc index 01d034fb1f7..ca13e19f73f 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.cc +++ b/tensorflow/lite/delegates/gpu/cl/environment.cc @@ -26,7 +26,6 @@ namespace tflite { namespace gpu { namespace cl { namespace { - std::string GetKernelOneLayerTextureArray() { return R"( @@ -44,12 +43,12 @@ __kernel void main_function(__write_only image2d_array_t dst) { // texture, we will get zeroes instead of actual values. // The same kernel will work, if we use texture array with more than one layer. // With help of this code we can detect this bug. -absl::Status CheckKernelSupportOfOneLayerTextureArray(Environment* env, - bool* result) { +Status CheckKernelSupportOfOneLayerTextureArray(Environment* env, + bool* result) { // No bug on Adreno 6xx if (env->device().GetInfo().adreno_info.gpu_version >= 600) { *result = true; - return absl::OkStatus(); + return OkStatus(); } CLKernel kernel; RETURN_IF_ERROR(env->program_cache()->GetOrCreateCLKernel( @@ -76,12 +75,12 @@ absl::Status CheckKernelSupportOfOneLayerTextureArray(Environment* env, break; } } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateEnvironment(Environment* result, bool shared, - cl_context_properties egl_context, - cl_context_properties egl_display) { +Status CreateEnvironment(Environment* result, bool shared, + cl_context_properties egl_context, + cl_context_properties egl_display) { CLDevice gpu; RETURN_IF_ERROR(CreateDefaultGPUDevice(&gpu)); @@ -108,9 +107,8 @@ absl::Status CreateEnvironment(Environment* result, bool shared, } } - return absl::OkStatus(); + return OkStatus(); } - } // namespace Environment::Environment(CLDevice&& device, CLContext&& context, @@ -139,7 +137,7 @@ Environment& Environment::operator=(Environment&& environment) { return *this; } -absl::Status Environment::Init() { +Status Environment::Init() { if (device().IsAdreno() && device().SupportsTextureArray()) { bool supports_one_layer; RETURN_IF_ERROR( @@ -148,7 +146,7 @@ absl::Status Environment::Init() { GetDevicePtr()->DisableOneLayerTextureArray(); } } - return absl::OkStatus(); + return OkStatus(); } void Environment::SetHighPerformance() const { @@ -268,7 +266,7 @@ TensorStorageType GetStorageTypeWithMinimalMemoryConsumption( return TensorStorageType::BUFFER; } -absl::Status CreateEnvironment(Environment* result) { +Status CreateEnvironment(Environment* result) { CLDevice gpu; RETURN_IF_ERROR(CreateDefaultGPUDevice(&gpu)); diff --git a/tensorflow/lite/delegates/gpu/cl/environment.h b/tensorflow/lite/delegates/gpu/cl/environment.h index b40d22d3dd6..496d6957623 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.h +++ b/tensorflow/lite/delegates/gpu/cl/environment.h @@ -57,7 +57,7 @@ class Environment { std::vector GetSupportedStorages() const; bool IsSupported(TensorStorageType storage_type) const; - absl::Status Init(); + Status Init(); void SetHighPerformance() const; void SetDefaultPerformance() const; @@ -75,7 +75,7 @@ TensorStorageType GetFastestStorageType(const CLDevice& gpu); TensorStorageType GetStorageTypeWithMinimalMemoryConsumption( const CLDevice& gpu); -absl::Status CreateEnvironment(Environment* result); +Status CreateEnvironment(Environment* result); } // namespace cl } // namespace gpu diff --git 
a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc index 648b772d827..f4db12bf133 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc @@ -41,11 +41,10 @@ PFNEGLCREATESYNCPROC g_eglCreateSync = nullptr; } // namespace -absl::Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, - EglSync* sync) { +Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, + EglSync* sync) { if (!IsEglSyncFromClEventSupported()) { - return absl::UnimplementedError( - "CreateEglSyncFromClEvent is not supported"); + return UnimplementedError("CreateEglSyncFromClEvent is not supported"); } EGLSync egl_sync; const EGLAttrib attributes[] = {EGL_CL_EVENT_HANDLE, @@ -53,10 +52,10 @@ absl::Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(g_eglCreateSync, &egl_sync, display, EGL_SYNC_CL_EVENT, attributes)); if (egl_sync == EGL_NO_SYNC) { - return absl::InternalError("Returned empty EGL sync"); + return InternalError("Returned empty EGL sync"); } *sync = EglSync(display, egl_sync); - return absl::OkStatus(); + return OkStatus(); } bool IsEglSyncFromClEventSupported() { @@ -74,54 +73,52 @@ bool IsEglSyncFromClEventSupported() { return supported; } -absl::Status CreateClEventFromEglSync(cl_context context, - const EglSync& egl_sync, CLEvent* event) { +Status CreateClEventFromEglSync(cl_context context, const EglSync& egl_sync, + CLEvent* event) { cl_int error_code; cl_event new_event = clCreateEventFromEGLSyncKHR( context, egl_sync.sync(), egl_sync.display(), &error_code); if (error_code != CL_SUCCESS) { - return absl::InternalError( + return InternalError( absl::StrCat("Unable to create CL sync from EGL sync. ", CLErrorCodeToString(error_code))); } *event = CLEvent(new_event); - return absl::OkStatus(); + return OkStatus(); } bool IsClEventFromEglSyncSupported(const CLDevice& device) { return device.SupportsExtension("cl_khr_egl_event"); } -absl::Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, - AccessType access_type, - CLContext* context, CLMemory* memory) { +Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, AccessType access_type, + CLContext* context, CLMemory* memory) { cl_int error_code; auto mem = clCreateFromGLBuffer(context->context(), ToClMemFlags(access_type), gl_ssbo_id, &error_code); if (error_code != CL_SUCCESS) { - return absl::InternalError( + return InternalError( absl::StrCat("Unable to acquire CL buffer from GL buffer. ", CLErrorCodeToString(error_code))); } *memory = CLMemory(mem, true); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateClMemoryFromGlTexture(GLenum texture_target, - GLuint texture_id, - AccessType access_type, - CLContext* context, CLMemory* memory) { +Status CreateClMemoryFromGlTexture(GLenum texture_target, GLuint texture_id, + AccessType access_type, CLContext* context, + CLMemory* memory) { cl_int error_code; auto mem = clCreateFromGLTexture(context->context(), ToClMemFlags(access_type), texture_target, 0, texture_id, &error_code); if (error_code != CL_SUCCESS) { - return absl::InternalError( + return InternalError( absl::StrCat("Unable to create CL buffer from GL texture. 
", CLErrorCodeToString(error_code))); } *memory = CLMemory(mem, true); - return absl::OkStatus(); + return OkStatus(); } bool IsGlSharingSupported(const CLDevice& device) { @@ -131,18 +128,19 @@ bool IsGlSharingSupported(const CLDevice& device) { AcquiredGlObjects::~AcquiredGlObjects() { Release({}, nullptr).IgnoreError(); } -absl::Status AcquiredGlObjects::Acquire( - const std::vector& memory, cl_command_queue queue, - const std::vector& wait_events, CLEvent* acquire_event, - AcquiredGlObjects* objects) { +Status AcquiredGlObjects::Acquire(const std::vector& memory, + cl_command_queue queue, + const std::vector& wait_events, + CLEvent* acquire_event, + AcquiredGlObjects* objects) { if (!memory.empty()) { cl_event new_event; cl_int error_code = clEnqueueAcquireGLObjects( queue, memory.size(), memory.data(), wait_events.size(), wait_events.data(), acquire_event ? &new_event : nullptr); if (error_code != CL_SUCCESS) { - return absl::InternalError(absl::StrCat("Unable to acquire GL object. ", - CLErrorCodeToString(error_code))); + return InternalError(absl::StrCat("Unable to acquire GL object. ", + CLErrorCodeToString(error_code))); } if (acquire_event) { *acquire_event = CLEvent(new_event); @@ -150,19 +148,19 @@ absl::Status AcquiredGlObjects::Acquire( clFlush(queue); } *objects = AcquiredGlObjects(memory, queue); - return absl::OkStatus(); + return OkStatus(); } -absl::Status AcquiredGlObjects::Release( - const std::vector& wait_events, CLEvent* release_event) { +Status AcquiredGlObjects::Release(const std::vector& wait_events, + CLEvent* release_event) { if (queue_ && !memory_.empty()) { cl_event new_event; cl_int error_code = clEnqueueReleaseGLObjects( queue_, memory_.size(), memory_.data(), wait_events.size(), wait_events.data(), release_event ? &new_event : nullptr); if (error_code != CL_SUCCESS) { - return absl::InternalError(absl::StrCat("Unable to release GL object. ", - CLErrorCodeToString(error_code))); + return InternalError(absl::StrCat("Unable to release GL object. ", + CLErrorCodeToString(error_code))); } if (release_event) { *release_event = CLEvent(new_event); @@ -170,7 +168,7 @@ absl::Status AcquiredGlObjects::Release( clFlush(queue_); queue_ = nullptr; } - return absl::OkStatus(); + return OkStatus(); } GlInteropFabric::GlInteropFabric(EGLDisplay egl_display, @@ -194,9 +192,9 @@ void GlInteropFabric::UnregisterMemory(cl_mem memory) { } } -absl::Status GlInteropFabric::Start() { +Status GlInteropFabric::Start() { if (!is_enabled()) { - return absl::OkStatus(); + return OkStatus(); } // In GL-CL interoperability, we need to make sure GL finished processing of @@ -237,9 +235,9 @@ absl::Status GlInteropFabric::Start() { nullptr, &gl_objects_); } -absl::Status GlInteropFabric::Finish() { +Status GlInteropFabric::Finish() { if (!is_enabled()) { - return absl::OkStatus(); + return OkStatus(); } RETURN_IF_ERROR(gl_objects_.Release({}, &outbound_event_)); @@ -260,7 +258,7 @@ absl::Status GlInteropFabric::Finish() { // This slow sync is the only working solution right now. We have to debug why // above version is not working fast and reliable. outbound_event_.Wait(); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.h b/tensorflow/lite/delegates/gpu/cl/gl_interop.h index 7ebc3e4bf4f..597bee857c6 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.h +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.h @@ -39,8 +39,8 @@ namespace cl { // returned sync and could be safely destroyed. // // Depends on EGL 1.5. 
-absl::Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, - EglSync* sync); +Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display, + EglSync* sync); // Returns true if 'CreateEglSyncFromClEvent' is supported. bool IsEglSyncFromClEventSupported(); @@ -48,22 +48,20 @@ bool IsEglSyncFromClEventSupported(); // Creates CL event from EGL sync. // Created event could only be consumed by AcquiredGlObject::Acquire call as // a 'wait_event'. -absl::Status CreateClEventFromEglSync(cl_context context, - const EglSync& egl_sync, CLEvent* event); +Status CreateClEventFromEglSync(cl_context context, const EglSync& egl_sync, + CLEvent* event); // Returns true if 'CreateClEventFromEglSync' is supported. bool IsClEventFromEglSyncSupported(const CLDevice& device); // Creates new CL memory object from OpenGL buffer. -absl::Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, - AccessType access_type, - CLContext* context, CLMemory* memory); +Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, AccessType access_type, + CLContext* context, CLMemory* memory); // Creates new CL memory object from OpenGL texture. -absl::Status CreateClMemoryFromGlTexture(GLenum texture_target, - GLuint texture_id, - AccessType access_type, - CLContext* context, CLMemory* memory); +Status CreateClMemoryFromGlTexture(GLenum texture_target, GLuint texture_id, + AccessType access_type, CLContext* context, + CLMemory* memory); // Returns true if GL objects could be shared with OpenCL context. bool IsGlSharingSupported(const CLDevice& device); @@ -83,16 +81,16 @@ class AcquiredGlObjects { // CreateClMemoryFromGlBuffer or CreateClMemoryFromGlTexture calls. // If 'acquire_event' is not nullptr, it will be signared once acquisition is // complete. - static absl::Status Acquire(const std::vector& memory, - cl_command_queue queue, - const std::vector& wait_events, - CLEvent* acquire_event /* optional */, - AcquiredGlObjects* objects); + static Status Acquire(const std::vector& memory, + cl_command_queue queue, + const std::vector& wait_events, + CLEvent* acquire_event /* optional */, + AcquiredGlObjects* objects); // Releases OpenCL memory back to OpenGL context. If 'release_event' is not // nullptr, it will be signalled once release is complete. - absl::Status Release(const std::vector& wait_events, - CLEvent* release_event /* optional */); + Status Release(const std::vector& wait_events, + CLEvent* release_event /* optional */); private: AcquiredGlObjects(const std::vector& memory, cl_command_queue queue) @@ -110,10 +108,10 @@ class GlInteropFabric { // Ensures proper GL->CL synchronization is in place before // GL objects that are mapped to CL objects are used. - absl::Status Start(); + Status Start(); // Puts appropriate CL->GL synchronization after all work is complete. - absl::Status Finish(); + Status Finish(); // Registers memory to be used from GL context. 
Such CL memory object must // be created with CreateClMemoryFromGlBuffer or CreateClMemoryFromGlTexture diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc index 0e2d046eba2..8e2c3308a47 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc +++ b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc @@ -87,8 +87,8 @@ class Delegate { } } - absl::Status Prepare(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params) { + Status Prepare(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { // Extract TFLite delegate execution plan from the context and convert it // into FlowGraph32. GraphFloat32 graph; @@ -98,7 +98,7 @@ class Delegate { NullTransformationReporter reporter; ModelTransformer transformer(&graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + return InternalError("Graph general transformations failed"); } InferenceEnvironmentOptions env_options; @@ -108,7 +108,7 @@ class Delegate { options_.serialized_binary_cache_data, options_.serialized_binary_cache_size}; InferenceEnvironmentProperties properties; - absl::Status status = + Status status = NewInferenceEnvironment(env_options, &environment_, &properties); if (!properties.is_opencl_available) { context->ReportError(context, @@ -200,7 +200,7 @@ class Delegate { return builder->Build(&runner_); } - absl::Status SetInputsAndOutputs(TfLiteContext* context) { + Status SetInputsAndOutputs(TfLiteContext* context) { int i = 0; for (auto index : input_indices_) { RETURN_IF_ERROR( @@ -211,10 +211,10 @@ class Delegate { RETURN_IF_ERROR( runner_->SetOutputObject(i++, GetTensorObject(index, context))); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status Invoke(TfLiteContext* context) { + Status Invoke(TfLiteContext* context) { RETURN_IF_ERROR(SetInputsAndOutputs(context)); return runner_->Run(); } @@ -310,7 +310,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = gpu_delegate->Prepare(context, params); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Init: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return nullptr; } return gpu_delegate; @@ -335,7 +335,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = GetDelegate(node)->Invoke(context); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Invoke: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index 2ec911813e6..47998bf8c99 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -169,9 +169,9 @@ CLNode& CLNode::operator=(CLNode&& node) { return *this; } -absl::Status InferenceContext::InitFromGraph( - const CreateInferenceInfo& create_info, const GraphFloat32& graph, - Environment* env) { +Status InferenceContext::InitFromGraph(const CreateInferenceInfo& create_info, + const GraphFloat32& graph, + Environment* env) { CreationContext creation_context; creation_context.device = env->GetDevicePtr(); creation_context.context = &env->context(); @@ -206,15 +206,15 @@ absl::Status InferenceContext::InitFromGraph( 
tuning_parameters.tuning_type = TuningType::FAST; } RETURN_IF_ERROR(Tune(tuning_parameters)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status InferenceContext::InitFromGraphWithTransforms( +Status InferenceContext::InitFromGraphWithTransforms( const CreateInferenceInfo& create_info, GraphFloat32* graph, Environment* env) { RETURN_IF_ERROR(RunGraphTransforms(graph)); RETURN_IF_ERROR(InitFromGraph(create_info, *graph, env)); - return absl::OkStatus(); + return OkStatus(); } void InferenceContext::CopyInAndOutIds(const GraphFloat32& graph) { @@ -258,7 +258,7 @@ void InferenceContext::ReserveGraphTensors( tensor_reserver_.SetNext(max_id + 1); } -absl::Status InferenceContext::ConvertOperations( +Status InferenceContext::ConvertOperations( const CreationContext& creation_context, const GraphFloat32& graph, ModelHints hints) { std::vector graph_nodes = graph.nodes(); @@ -343,7 +343,7 @@ absl::Status InferenceContext::ConvertOperations( } } - return absl::OkStatus(); + return OkStatus(); } void InferenceContext::Merge() { @@ -424,15 +424,15 @@ void InferenceContext::GetUsages( } } -absl::Status InferenceContext::AllocateMemory(const CLDevice& device, - CLContext* context) { +Status InferenceContext::AllocateMemory(const CLDevice& device, + CLContext* context) { RETURN_IF_ERROR(AllocateMemoryForBuffers(device, context)); RETURN_IF_ERROR(AllocateMemoryForStrongShapes(device, context)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device, - CLContext* context) { +Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device, + CLContext* context) { std::map buffer_usages; GetUsages( [](const TensorDescriptor& t) { return IsBufferBased(t.storage_type); }, @@ -480,11 +480,11 @@ absl::Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device, created_tensors[tensor_index] = true; } } - return absl::OkStatus(); + return OkStatus(); } -absl::Status InferenceContext::AllocateMemoryForStrongShapes( - const CLDevice& device, CLContext* context) { +Status InferenceContext::AllocateMemoryForStrongShapes(const CLDevice& device, + CLContext* context) { std::map usages; GetUsages( [](const TensorDescriptor& t) { return !IsBufferBased(t.storage_type); }, @@ -517,7 +517,7 @@ absl::Status InferenceContext::AllocateMemoryForStrongShapes( } } } - return absl::OkStatus(); + return OkStatus(); } void InferenceContext::BindMemoryToOperations() { @@ -539,22 +539,21 @@ void InferenceContext::BindMemoryToOperations() { } } -absl::Status InferenceContext::Compile( - const CreationContext& creation_context) { +Status InferenceContext::Compile(const CreationContext& creation_context) { for (auto& node : nodes_) { RETURN_IF_ERROR(node.operations[0]->Compile(creation_context)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status InferenceContext::Tune(const TuningParameters& tuning_parameters) { +Status InferenceContext::Tune(const TuningParameters& tuning_parameters) { for (auto& node : nodes_) { RETURN_IF_ERROR(node.operations[0]->Tune(tuning_parameters)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status InferenceContext::AddToQueue(CLCommandQueue* queue) { +Status InferenceContext::AddToQueue(CLCommandQueue* queue) { if (need_manual_release_) { if (prev_enqueue_start_point_.is_valid()) { prev_enqueue_start_point_.Wait(); @@ -572,11 +571,11 @@ absl::Status InferenceContext::AddToQueue(CLCommandQueue* queue) { if (need_flush_) { clFlush(queue->queue()); } - return 
absl::OkStatus(); + return OkStatus(); } -absl::Status InferenceContext::Profile(ProfilingCommandQueue* queue, - ProfilingInfo* result) { +Status InferenceContext::Profile(ProfilingCommandQueue* queue, + ProfilingInfo* result) { queue->ResetMeasurements(); for (auto& node : nodes_) { queue->SetEventsLabel(node.name); @@ -584,7 +583,7 @@ absl::Status InferenceContext::Profile(ProfilingCommandQueue* queue, } RETURN_IF_ERROR(queue->WaitForCompletion()); *result = queue->GetProfilingInfo(); - return absl::OkStatus(); + return OkStatus(); } uint64_t InferenceContext::GetSizeOfMemoryAllocatedForIntermediateTensors() @@ -609,15 +608,13 @@ Tensor* InferenceContext::GetTensor(ValueId id) { } } -absl::Status InferenceContext::SetInputTensor(ValueId id, - const TensorFloat32& tensor, - CLCommandQueue* queue) { +Status InferenceContext::SetInputTensor(ValueId id, const TensorFloat32& tensor, + CLCommandQueue* queue) { return GetTensor(id)->WriteData(queue, tensor); } -absl::Status InferenceContext::GetOutputTensor(ValueId id, - CLCommandQueue* queue, - TensorFloat32* result) { +Status InferenceContext::GetOutputTensor(ValueId id, CLCommandQueue* queue, + TensorFloat32* result) { const auto& gpu_tensor = *GetTensor(id); const auto dst_shape = BHWC(gpu_tensor.Batch(), gpu_tensor.Height(), gpu_tensor.Width(), gpu_tensor.Channels()); @@ -627,17 +624,17 @@ absl::Status InferenceContext::GetOutputTensor(ValueId id, return gpu_tensor.ReadData(queue, result); } -absl::Status RunGraphTransforms(GraphFloat32* graph) { +Status RunGraphTransforms(GraphFloat32* graph) { auto merge_padding_transform = NewMergePaddingWithAdd(); auto add_bias_transform = NewAddBias(); ModelTransformer transformer(graph, /*reporter=*/nullptr); if (!transformer.Apply("add_bias", add_bias_transform.get())) { - return absl::InternalError("Invalid add_bias transform"); + return InternalError("Invalid add_bias transform"); } if (!transformer.Apply("merge_padding", merge_padding_transform.get())) { - return absl::InternalError("Invalid merge_padding transform"); + return InternalError("Invalid merge_padding transform"); } - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.h b/tensorflow/lite/delegates/gpu/cl/inference_context.h index 75365258e41..40b20e8806a 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.h +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.h @@ -65,55 +65,53 @@ class InferenceContext { TensorStorageType storage_type; ModelHints hints; }; - absl::Status InitFromGraph(const CreateInferenceInfo& create_info, - const GraphFloat32& graph, Environment* env); + Status InitFromGraph(const CreateInferenceInfo& create_info, + const GraphFloat32& graph, Environment* env); // Applies OpenCL-specific transformations to the graph before the // initialization. These transformations are either impossible or useless in // other backends. 
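A minimal sketch of driving one inference with the InferenceContext API above, assuming the context was already initialized through InitFromGraph with a valid Environment; input_id and output_id are placeholder ValueIds and the helper name is hypothetical:

    Status RunOnce(InferenceContext* inference, ValueId input_id,
                   ValueId output_id, const TensorFloat32& input,
                   TensorFloat32* output, CLCommandQueue* queue) {
      RETURN_IF_ERROR(inference->SetInputTensor(input_id, input, queue));
      RETURN_IF_ERROR(inference->AddToQueue(queue));
      // GetOutputTensor resizes `output` to the tensor shape and reads the
      // data back from the GPU.
      return inference->GetOutputTensor(output_id, queue, output);
    }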
- absl::Status InitFromGraphWithTransforms( - const CreateInferenceInfo& create_info, GraphFloat32* graph, - Environment* env); + Status InitFromGraphWithTransforms(const CreateInferenceInfo& create_info, + GraphFloat32* graph, Environment* env); - absl::Status AddToQueue(CLCommandQueue* queue); - absl::Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result); + Status AddToQueue(CLCommandQueue* queue); + Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result); // for profiling and memory statistics uint64_t GetSizeOfMemoryAllocatedForIntermediateTensors() const; - absl::Status SetInputTensor(ValueId id, const TensorFloat32& tensor, - CLCommandQueue* queue); + Status SetInputTensor(ValueId id, const TensorFloat32& tensor, + CLCommandQueue* queue); // It will work only with input/output tensor ids. For all other ids we don't // have any guarantees. Tensor* GetTensor(ValueId id); - absl::Status GetOutputTensor(ValueId id, CLCommandQueue* queue, - TensorFloat32* result); + Status GetOutputTensor(ValueId id, CLCommandQueue* queue, + TensorFloat32* result); private: void CopyInAndOutIds(const GraphFloat32& graph); - absl::Status ConvertOperations(const CreationContext& creation_context, - const GraphFloat32& graph, ModelHints hints); + Status ConvertOperations(const CreationContext& creation_context, + const GraphFloat32& graph, ModelHints hints); void CreateLinks(); void ReserveGraphTensors(const CreateInferenceInfo& create_info, const CreationContext& creation_context, const GraphFloat32& graph); void Merge(); - absl::Status AllocateMemory(const CLDevice& device, CLContext* context); + Status AllocateMemory(const CLDevice& device, CLContext* context); - absl::Status AllocateMemoryForBuffers(const CLDevice& device, - CLContext* context); + Status AllocateMemoryForBuffers(const CLDevice& device, CLContext* context); - absl::Status AllocateMemoryForStrongShapes(const CLDevice& device, - CLContext* context); + Status AllocateMemoryForStrongShapes(const CLDevice& device, + CLContext* context); // utility function void GetUsages(const std::function& functor, std::map* usages); void BindMemoryToOperations(); - absl::Status Compile(const CreationContext& creation_context); - absl::Status Tune(const TuningParameters& tuning_parameters); + Status Compile(const CreationContext& creation_context); + Status Tune(const TuningParameters& tuning_parameters); // performance hacks bool need_flush_ = false; @@ -177,7 +175,7 @@ class InferenceContext { }; // Runs OpenCL specific transforms for the graph. 
-absl::Status RunGraphTransforms(GraphFloat32* graph); +Status RunGraphTransforms(GraphFloat32* graph); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc index 0c96f4316ec..b5c37c5987f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc @@ -143,17 +143,17 @@ std::string Add::GetArgsDeclaration() const { return args; } -absl::Status Add::BindArguments(CLKernel* kernel) { +Status Add::BindArguments(CLKernel* kernel) { for (int i = 1; i < src_depthes_.size(); ++i) { RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[i]->GetMemoryPtr())); } for (int i = 1; i < src_depthes_.size(); ++i) { RETURN_IF_ERROR(kernel->SetBytesAuto(src_[i]->GetWBatchedHSB())); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Add::Compile(const CreationContext& creation_context) { +Status Add::Compile(const CreationContext& creation_context) { const auto code = GetElementWiseCode(definition_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.h b/tensorflow/lite/delegates/gpu/cl/kernels/add.h index d47954748c7..ac6243cc5e4 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/add.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.h @@ -36,7 +36,7 @@ class Add : public ElementwiseOperation { Add(const OperationDef& definition, const std::vector& channels, int dst_channels); - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Add(Add&& operation); @@ -47,7 +47,7 @@ class Add : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - absl::Status BindArguments(CLKernel* kernel) override; + Status BindArguments(CLKernel* kernel) override; private: std::string GetElementWiseCode( diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc index deb0ebf67c4..ad4b54853e1 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc @@ -21,17 +21,17 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status ExecuteGPUOperation(const std::vector& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, - const std::vector& dst_sizes, - const std::vector& dst_cpu) { +Status ExecuteGPUOperation(const std::vector& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, + const std::vector& dst_sizes, + const std::vector& dst_cpu) { const OperationDef& op_def = operation->GetDefinition(); std::vector src(src_cpu.size()); for (int i = 0; i < src_cpu.size(); ++i) { auto src_shape = src_cpu[i].shape; if (src_shape.b != 1 && !op_def.IsBatchSupported()) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Layout doesn't have Batch dimension, but shape.b != 1"); } RETURN_IF_ERROR(CreateTensor(*creation_context.context, @@ -45,7 +45,7 @@ absl::Status ExecuteGPUOperation(const std::vector& src_cpu, for (int i = 0; i < dst_cpu.size(); ++i) { auto dst_shape = dst_sizes[i]; if (dst_shape.b != 1 && !op_def.IsBatchSupported()) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Layout doesn't have Batch 
dimension, but shape.b != 1"); } RETURN_IF_ERROR(CreateTensor(*creation_context.context, @@ -64,22 +64,22 @@ absl::Status ExecuteGPUOperation(const std::vector& src_cpu, dst_cpu[i]->data = std::vector(dst_sizes[i].DimensionsProduct(), 0); RETURN_IF_ERROR(dst[i].ReadData(creation_context.queue, dst_cpu[i])); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status ExecuteGPUOperation(const std::vector& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, const BHWC& dst_size, - TensorFloat32* result) { +Status ExecuteGPUOperation(const std::vector& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, const BHWC& dst_size, + TensorFloat32* result) { return ExecuteGPUOperation( std::vector{src_cpu}, creation_context, operation, std::vector{dst_size}, std::vector{result}); } -absl::Status ExecuteGPUOperation(const TensorFloat32& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, const BHWC& dst_size, - TensorFloat32* result) { +Status ExecuteGPUOperation(const TensorFloat32& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, const BHWC& dst_size, + TensorFloat32* result) { return ExecuteGPUOperation(std::vector{src_cpu}, creation_context, operation, dst_size, result); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h index 4d3636d0384..c127d1bacd3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h @@ -51,21 +51,21 @@ class OpenCLOperationTest : public ::testing::Test { CreationContext creation_context_; }; -absl::Status ExecuteGPUOperation(const TensorFloat32& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, const BHWC& dst_size, - TensorFloat32* result); +Status ExecuteGPUOperation(const TensorFloat32& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, const BHWC& dst_size, + TensorFloat32* result); -absl::Status ExecuteGPUOperation(const std::vector& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, const BHWC& dst_size, - TensorFloat32* result); +Status ExecuteGPUOperation(const std::vector& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, const BHWC& dst_size, + TensorFloat32* result); -absl::Status ExecuteGPUOperation(const std::vector& src_cpu, - const CreationContext& creation_context, - GPUOperation* operation, - const std::vector& dst_sizes, - const std::vector& dst_cpu); +Status ExecuteGPUOperation(const std::vector& src_cpu, + const CreationContext& creation_context, + GPUOperation* operation, + const std::vector& dst_sizes, + const std::vector& dst_cpu); } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc index ef7915afba5..141a19de6e1 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc @@ -96,7 +96,7 @@ ConcatXY& ConcatXY::operator=(ConcatXY&& operation) { return *this; } -absl::Status ConcatXY::Compile(const CreationContext& creation_context) { +Status ConcatXY::Compile(const CreationContext& creation_context) { const auto code = GetConcatKernelCode(definition_, tensors_count_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -104,7 +104,7 @@ absl::Status ConcatXY::Compile(const 
CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status ConcatXY::BindArguments() { +Status ConcatXY::BindArguments() { kernel_.ResetBindingCounter(); for (int i = 0; i < tensors_count_; ++i) { RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr())); @@ -122,7 +122,7 @@ absl::Status ConcatXY::BindArguments() { y_offset += attr_.axis == Axis::HEIGHT ? height : 0; } RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 ConcatXY::GetGridSize() const { @@ -140,12 +140,12 @@ int3 ConcatXY::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConcatXY::Tune(const TuningParameters& params) { +Status ConcatXY::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status ConcatXY::AddToQueue(CLCommandQueue* queue) { +Status ConcatXY::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h index a170b593cf0..6bc0c87a51f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h @@ -31,10 +31,10 @@ class ConcatXY : public GPUOperation { ConcatXY(const OperationDef& definition, const ConcatAttributes& attr, int tensors_count) : GPUOperation(definition), attr_(attr), tensors_count_(tensors_count) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConcatXY(ConcatXY&& operation); @@ -43,7 +43,7 @@ class ConcatXY : public GPUOperation { ConcatXY& operator=(const ConcatXY&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; ConcatAttributes attr_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc index 3a7ec1c0cb7..039fac0d0e3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc @@ -25,8 +25,8 @@ limitations under the License. 
namespace tflite { namespace gpu { namespace cl { -namespace { +namespace { bool IsAllChannelsX4(const std::vector& channels) { for (int channel : channels) { if (channel % 4 != 0) { @@ -146,7 +146,6 @@ std::string GetConcatKernelCode( c += "}\n"; return c; } - } // namespace ConcatZ::ConcatZ(ConcatZ&& kernel) @@ -165,7 +164,7 @@ ConcatZ& ConcatZ::operator=(ConcatZ&& kernel) { return *this; } -absl::Status ConcatZ::Compile(const CreationContext& creation_context) { +Status ConcatZ::Compile(const CreationContext& creation_context) { const auto code = GetConcatKernelCode(definition_, channels_, linked_operations_); std::vector options; @@ -187,7 +186,7 @@ absl::Status ConcatZ::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status ConcatZ::BindArguments() { +Status ConcatZ::BindArguments() { kernel_.ResetBindingCounter(); for (int i = 0; i < channels_.size(); ++i) { RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr())); @@ -198,7 +197,7 @@ absl::Status ConcatZ::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[i]->Slices())); } RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 ConcatZ::GetGridSize() const { @@ -208,12 +207,12 @@ int3 ConcatZ::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConcatZ::Tune(const TuningParameters& params) { +Status ConcatZ::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status ConcatZ::AddToQueue(CLCommandQueue* queue) { +Status ConcatZ::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h index ec25f6e4ed9..9fc0fcc1fdb 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h @@ -32,10 +32,10 @@ class ConcatZ : public GPUOperation { public: ConcatZ(const OperationDef& definition, const std::vector& channels) : GPUOperation(definition), channels_(channels) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConcatZ(ConcatZ&& kernel); @@ -44,7 +44,7 @@ class ConcatZ : public GPUOperation { ConcatZ& operator=(const ConcatZ&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; std::vector channels_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc index b79599d8e95..e6015357bfc 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc @@ -76,7 +76,7 @@ Conv3D& Conv3D::operator=(Conv3D&& operation) { return *this; } -absl::Status Conv3D::Compile(const CreationContext& creation_context) { +Status Conv3D::Compile(const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; const std::string code = @@ -92,7 +92,7 @@ absl::Status Conv3D::Compile(const 
CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Conv3D::BindArguments() { +Status Conv3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (conv_params_.AreWeightsBuffer()) { @@ -131,7 +131,7 @@ absl::Status Conv3D::BindArguments() { IntegralDivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w))); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS())); - return absl::OkStatus(); + return OkStatus(); } int3 Conv3D::GetGridSize() const { @@ -154,12 +154,12 @@ int3 Conv3D::GetGridSize() const { conv_params_.work_group_size.z); } -absl::Status Conv3D::Tune(const TuningParameters& params) { +Status Conv3D::Tune(const TuningParameters& params) { if (conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP || conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_BY_THREADS) { - return absl::OkStatus(); + return OkStatus(); } if (conv_params_.work_group_launch_order[0] == 0 && conv_params_.work_group_launch_order[1] == 1 && @@ -168,10 +168,10 @@ absl::Status Conv3D::Tune(const TuningParameters& params) { return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &conv_params_.work_group_size); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Conv3D::AddToQueue(CLCommandQueue* queue) { +Status Conv3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), conv_params_.work_group_size); @@ -903,9 +903,9 @@ Conv3D::ConvParams Conv3D::GuessBestParams( x_kernel_is_1, y_kernel_is_1, z_kernel_is_1); } -absl::Status CreateConv3D(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution3DAttributes& attr, Conv3D* result) { +Status CreateConv3D(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution3DAttributes& attr, Conv3D* result) { *result = Conv3D(definition, attr, *creation_context.device); return result->UploadData(attr.weights, attr.bias, creation_context.context); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h index 00b1e868e5d..8fc48c4114a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h @@ -39,9 +39,9 @@ namespace cl { class Conv3D : public GPUOperation { public: Conv3D() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; + Status Compile(const CreationContext& creation_context) override; // Move only Conv3D(Conv3D&& operation); @@ -75,21 +75,21 @@ class Conv3D : public GPUOperation { const CLDevice& device); template - absl::Status UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + Status UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, 
absl::Span dst); - friend absl::Status CreateConv3D(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution3DAttributes& attr, - Conv3D* result); + friend Status CreateConv3D(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution3DAttributes& attr, + Conv3D* result); friend std::string GenerateConv3D( const OperationDef& op_def, const LinearStorage& biases, @@ -105,7 +105,7 @@ class Conv3D : public GPUOperation { int dst_slices, bool x_kernel_is_1, bool y_kernel_is_1, bool z_kernel_is_1) const; - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Texture2D weights_0_; @@ -125,9 +125,9 @@ class Conv3D : public GPUOperation { }; template -absl::Status Conv3D::UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context) { +Status Conv3D::UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context) { RETURN_IF_ERROR(UploadWeights(weights, context)); LinearStorageCreateInfo create_info; create_info.storage_type = conv_params_.AreWeightsBuffer() @@ -139,12 +139,12 @@ absl::Status Conv3D::UploadData(const ::tflite::gpu::Tensor& weights, create_info.name = "biases"; create_info.aligned_size = weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_)); - return absl::OkStatus(); + return OkStatus(); } template -absl::Status Conv3D::UploadWeights( - const ::tflite::gpu::Tensor& weights, CLContext* context) { +Status Conv3D::UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context) { const int block_size = conv_params_.block_size.w; const int dst_slices = AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size); @@ -211,7 +211,7 @@ absl::Status Conv3D::UploadWeights( } } - return absl::OkStatus(); + return OkStatus(); } template @@ -271,9 +271,9 @@ void Conv3D::RearrangeWeightsData( } } -absl::Status CreateConv3D(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution3DAttributes& attr, Conv3D* result); +Status CreateConv3D(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution3DAttributes& attr, Conv3D* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc index 70bd1b5249f..3a8c726021c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc @@ -291,16 +291,16 @@ ConvBuffer1x1& ConvBuffer1x1::operator=(ConvBuffer1x1&& operation) { return *this; } -absl::Status ConvBuffer1x1::Compile(const CreationContext& creation_context) { +Status ConvBuffer1x1::Compile(const CreationContext& creation_context) { std::string code = GenerateConvBuffer1x1(definition_, conv_params_, linked_operations_); RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status ConvBuffer1x1::BindArguments() { +Status ConvBuffer1x1::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -313,7 +313,7 @@ absl::Status ConvBuffer1x1::BindArguments() { src_width_elements * src_[0]->Height()); 
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_size)); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 ConvBuffer1x1::GetGridSize() const { @@ -328,13 +328,13 @@ int3 ConvBuffer1x1::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConvBuffer1x1::Tune(const TuningParameters& params) { +Status ConvBuffer1x1::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &conv_params_.work_group_size); } -absl::Status ConvBuffer1x1::AddToQueue(CLCommandQueue* queue) { +Status ConvBuffer1x1::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), conv_params_.work_group_size); @@ -351,12 +351,12 @@ bool IsConvBuffer1x1Supported(const OperationDef& definition, attr.padding.appended.w == 0 && attr.padding.appended.h == 0; } -absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvBuffer1x1* result, const BHWC* shape) { +Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvBuffer1x1* result, const BHWC* shape) { if (!IsConvBuffer1x1Supported(definition, attr)) { - return absl::InvalidArgumentError("ConvBuffer1x1 doesn't supported"); + return InvalidArgumentError("ConvBuffer1x1 doesn't supported"); } const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); @@ -372,10 +372,10 @@ absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, return result->UploadData(attr.weights, attr.bias, creation_context.context); } -absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvBuffer1x1* result, const BHWC* shape) { +Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvBuffer1x1* result, const BHWC* shape) { const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); ConvBuffer1x1::ConvParams conv_params; @@ -392,10 +392,11 @@ absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, return result->UploadData(attr.weights, attr.bias, creation_context.context); } -absl::Status CreateConvBuffer1x1Wino4x4To6x6( - const CreationContext& creation_context, const OperationDef& definition, - const Convolution2DAttributes& attr, ConvBuffer1x1* result, - const BHWC* shape) { +Status CreateConvBuffer1x1Wino4x4To6x6(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvBuffer1x1* result, + const BHWC* shape) { const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); ConvBuffer1x1::ConvParams conv_params; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h index 07da846107e..54e99d29ec7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h @@ -45,10 +45,10 @@ class ConvBuffer1x1 : public GPUOperation { 
ConvBuffer1x1(const ConvBuffer1x1&) = delete; ConvBuffer1x1& operator=(const ConvBuffer1x1&) = delete; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; struct ConvParams { int3 block_size = int3(1, 1, 1); @@ -64,33 +64,33 @@ class ConvBuffer1x1 : public GPUOperation { private: ConvBuffer1x1(const OperationDef& definition, const ConvParams& conv_params); - friend absl::Status CreateConvBuffer1x1( - const CreationContext& creation_context, const OperationDef& definition, - const Convolution2DAttributes& attr, ConvBuffer1x1* result, - const BHWC* shape); - friend absl::Status CreateConvBuffer1x1( - const CreationContext& creation_context, const OperationDef& definition, - const FullyConnectedAttributes& attr, ConvBuffer1x1* result, - const BHWC* shape); - friend absl::Status CreateConvBuffer1x1Wino4x4To6x6( + friend Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvBuffer1x1* result, const BHWC* shape); + friend Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvBuffer1x1* result, const BHWC* shape); + friend Status CreateConvBuffer1x1Wino4x4To6x6( const CreationContext& creation_context, const OperationDef& definition, const Convolution2DAttributes& attr, ConvBuffer1x1* result, const BHWC* shape); template - absl::Status UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + Status UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); template - absl::Status UploadDataForWinograd4x4To6x6( + Status UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -101,20 +101,20 @@ class ConvBuffer1x1 : public GPUOperation { }; template -absl::Status ConvBuffer1x1::UploadData( - const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, CLContext* context) { +Status ConvBuffer1x1::UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context) { RETURN_IF_ERROR(UploadWeights(weights, context)); LinearStorageCreateInfo create_info; create_info.storage_type = LinearStorageType::BUFFER; create_info.data_type = definition_.GetDataType(); create_info.aligned_size = weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_)); - return absl::OkStatus(); + return OkStatus(); } template -absl::Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6( +Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context) { ::tflite::gpu::Tensor wino_weights; @@ -132,7 +132,7 @@ absl::Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6( } template -absl::Status ConvBuffer1x1::UploadWeights( +Status 
ConvBuffer1x1::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -162,22 +162,21 @@ absl::Status ConvBuffer1x1::UploadWeights( bool IsConvBuffer1x1Supported(const OperationDef& definition, const Convolution2DAttributes& attr); -absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvBuffer1x1* result, - const BHWC* shape = nullptr); +Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvBuffer1x1* result, const BHWC* shape = nullptr); -absl::Status CreateConvBuffer1x1(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvBuffer1x1* result, - const BHWC* shape = nullptr); +Status CreateConvBuffer1x1(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvBuffer1x1* result, const BHWC* shape = nullptr); -absl::Status CreateConvBuffer1x1Wino4x4To6x6( - const CreationContext& creation_context, const OperationDef& definition, - const Convolution2DAttributes& attr, ConvBuffer1x1* result, - const BHWC* shape = nullptr); +Status CreateConvBuffer1x1Wino4x4To6x6(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvBuffer1x1* result, + const BHWC* shape = nullptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc index 07d2da9d641..ceb3b8985e8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc @@ -219,7 +219,7 @@ ConvConstants& ConvConstants::operator=(ConvConstants&& kernel) { return *this; } -absl::Status ConvConstants::Compile(const CreationContext& creation_context) { +Status ConvConstants::Compile(const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; const auto code = GenerateConvolutionConstantCode( @@ -240,7 +240,7 @@ absl::Status ConvConstants::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status ConvConstants::BindArguments() { +Status ConvConstants::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -254,7 +254,7 @@ absl::Status ConvConstants::BindArguments() { kernel_.SetBytesAuto(int2(dilation_.x * src_[0]->Batch(), dilation_.y))); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 ConvConstants::GetGridSize() const { @@ -263,12 +263,12 @@ int3 ConvConstants::GetGridSize() const { return int3(grid_x, grid_y, 1); } -absl::Status ConvConstants::Tune(const TuningParameters& params) { +Status ConvConstants::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status ConvConstants::AddToQueue(CLCommandQueue* queue) { +Status ConvConstants::AddToQueue(CLCommandQueue* 
queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -294,12 +294,12 @@ bool IsConvConstantsSupported(const CLDevice& device, return filters_buffer_size <= kConstantMaxSize && flt4_registers <= 8; } -absl::Status CreateConvConstants(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvConstants* result) { +Status CreateConvConstants(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvConstants* result) { if (!IsConvConstantsSupported(*creation_context.device, definition, attr)) { - return absl::InvalidArgumentError("ConvConstants doesn't supported"); + return InvalidArgumentError("ConvConstants doesn't supported"); } *result = ConvConstants(definition, attr); RETURN_IF_ERROR( @@ -310,7 +310,8 @@ absl::Status CreateConvConstants(const CreationContext& creation_context, create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h index fc0e66b5e86..b4830d20fd1 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h @@ -35,10 +35,10 @@ namespace cl { class ConvConstants : public GPUOperation { public: ConvConstants() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvConstants(ConvConstants&& kernel); @@ -47,9 +47,10 @@ class ConvConstants : public GPUOperation { ConvConstants& operator=(const ConvConstants&) = delete; private: - friend absl::Status CreateConvConstants( - const CreationContext& creation_context, const OperationDef& definition, - const Convolution2DAttributes& attr, ConvConstants* result); + friend Status CreateConvConstants(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvConstants* result); explicit ConvConstants(const OperationDef& definition, const Convolution2DAttributes& attr) : GPUOperation(definition), @@ -61,14 +62,14 @@ class ConvConstants : public GPUOperation { dst_channels_(attr.weights.shape.o) {} template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -86,7 +87,7 @@ class ConvConstants : public GPUOperation { }; template -absl::Status ConvConstants::UploadWeights( +Status ConvConstants::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); const int kernel_x = weights.shape.w; @@ -156,10 +157,10 @@ bool IsConvConstantsSupported(const CLDevice& device, const OperationDef& definition, 
const Convolution2DAttributes& attr); -absl::Status CreateConvConstants(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvConstants* result); +Status CreateConvConstants(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvConstants* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc index bd4f53395f3..c1860d6452f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc @@ -173,7 +173,7 @@ ConvPowerVR& ConvPowerVR::operator=(ConvPowerVR&& operation) { return *this; } -absl::Status ConvPowerVR::Compile(const CreationContext& creation_context) { +Status ConvPowerVR::Compile(const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_padding_.x != 1; const std::string code = @@ -189,7 +189,7 @@ absl::Status ConvPowerVR::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status ConvPowerVR::BindArguments() { +Status ConvPowerVR::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -211,7 +211,7 @@ absl::Status ConvPowerVR::BindArguments() { } RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 ConvPowerVR::GetGridSize() const { @@ -245,13 +245,13 @@ int3 ConvPowerVR::GetGridSize() const { } } -absl::Status ConvPowerVR::Tune(const TuningParameters& params) { +Status ConvPowerVR::Tune(const TuningParameters& params) { if (conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP || conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_BY_THREADS || conv_params_.fixed_work_group_size) { - return absl::OkStatus(); + return OkStatus(); } if (conv_params_.work_group_launch_order[0] == 0 && conv_params_.work_group_launch_order[1] == 1 && @@ -260,10 +260,10 @@ absl::Status ConvPowerVR::Tune(const TuningParameters& params) { return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &conv_params_.work_group_size); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status ConvPowerVR::AddToQueue(CLCommandQueue* queue) { +Status ConvPowerVR::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), conv_params_.work_group_size); @@ -848,26 +848,27 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd( return params; } -absl::Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvPowerVR* result, const BHWC* dst_shape) { +Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvPowerVR* result, const BHWC* dst_shape) { *result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape); return result->UploadData(attr.weights, attr.bias, creation_context.context); } -absl::Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - 
ConvPowerVR* result, const BHWC* dst_shape) { +Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvPowerVR* result, const BHWC* dst_shape) { *result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape); return result->UploadData(attr.weights, attr.bias, creation_context.context); } -absl::Status CreateConvPowerVRWino4x4To6x6( - const CreationContext& creation_context, const OperationDef& definition, - const Convolution2DAttributes& attr, ConvPowerVR* result, - const BHWC* dst_shape) { +Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvPowerVR* result, + const BHWC* dst_shape) { *result = ConvPowerVR(definition); result->conv_params_ = result->GuessBestParamsWinograd( *creation_context.device, definition, attr, dst_shape); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h index 954205f1ca3..44145c585da 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h @@ -39,9 +39,9 @@ namespace cl { class ConvPowerVR : public GPUOperation { public: ConvPowerVR() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvPowerVR(ConvPowerVR&& operation); @@ -87,31 +87,29 @@ class ConvPowerVR : public GPUOperation { explicit ConvPowerVR(const OperationDef& definition); template - absl::Status UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + Status UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); template - absl::Status UploadDataForWinograd4x4To6x6( + Status UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); - friend absl::Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvPowerVR* result, - const BHWC* dst_shape); + friend Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvPowerVR* result, const BHWC* dst_shape); - friend absl::Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvPowerVR* result, - const BHWC* dst_shape); + friend Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvPowerVR* result, const BHWC* dst_shape); - friend absl::Status CreateConvPowerVRWino4x4To6x6( + friend Status CreateConvPowerVRWino4x4To6x6( const CreationContext& creation_context, const OperationDef& definition, const Convolution2DAttributes& attr, ConvPowerVR* result, const BHWC* dst_shape); 
@@ -140,7 +138,7 @@ class ConvPowerVR : public GPUOperation { bool different_weights_for_height, const BHWC* dst_shape = nullptr) const; - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -154,20 +152,20 @@ class ConvPowerVR : public GPUOperation { }; template -absl::Status ConvPowerVR::UploadData( - const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, CLContext* context) { +Status ConvPowerVR::UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context) { RETURN_IF_ERROR(UploadWeights(weights, context)); LinearStorageCreateInfo create_info; create_info.storage_type = LinearStorageType::BUFFER; create_info.data_type = conv_params_.weights_data_type; create_info.aligned_size = weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_)); - return absl::OkStatus(); + return OkStatus(); } template -absl::Status ConvPowerVR::UploadDataForWinograd4x4To6x6( +Status ConvPowerVR::UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context) { ::tflite::gpu::Tensor wino_weights; @@ -181,12 +179,12 @@ absl::Status ConvPowerVR::UploadDataForWinograd4x4To6x6( bias.shape = Linear(weights.shape.o); bias.data.resize(weights.shape.o, 0.0f); RETURN_IF_ERROR(CreateLinearStorage(create_info, bias, context, &biases_)); - return absl::OkStatus(); + return OkStatus(); } template -absl::Status ConvPowerVR::UploadWeights( - const ::tflite::gpu::Tensor& weights, CLContext* context) { +Status ConvPowerVR::UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context) { const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -212,22 +210,21 @@ absl::Status ConvPowerVR::UploadWeights( } } -absl::Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvPowerVR* result, - const BHWC* dst_shape = nullptr); +Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvPowerVR* result, const BHWC* dst_shape = nullptr); -absl::Status CreateConvPowerVR(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvPowerVR* result, - const BHWC* dst_shape = nullptr); +Status CreateConvPowerVR(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvPowerVR* result, const BHWC* dst_shape = nullptr); -absl::Status CreateConvPowerVRWino4x4To6x6( - const CreationContext& creation_context, const OperationDef& definition, - const Convolution2DAttributes& attr, ConvPowerVR* result, - const BHWC* dst_shape = nullptr); +Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvPowerVR* result, + const BHWC* dst_shape = nullptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc index 953f564c40a..780d6646ea8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc @@ -30,7 +30,6 @@ namespace tflite { namespace gpu { namespace cl { namespace { - std::string 
GenerateConvCode( const OperationDef& op_def, const int3& block_size, bool is1x1, bool adreno4xx_optimization, bool stride_correction, @@ -385,7 +384,7 @@ ConvTexture& ConvTexture::operator=(ConvTexture&& operation) { return *this; } -absl::Status ConvTexture::Compile(const CreationContext& creation_context) { +Status ConvTexture::Compile(const CreationContext& creation_context) { auto storage_type = definition_.GetPrimaryStorageType(); bool is1x1 = kernel_size_.x == 1 && kernel_size_.y == 1; bool adreno4xx_optimization = @@ -408,7 +407,7 @@ absl::Status ConvTexture::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status ConvTexture::BindArguments() { +Status ConvTexture::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_0_.GetMemoryPtr())); @@ -428,7 +427,7 @@ absl::Status ConvTexture::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_)); RETURN_IF_ERROR( kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y))); - return absl::OkStatus(); + return OkStatus(); } int3 ConvTexture::GetGridSize() const { @@ -439,36 +438,37 @@ int3 ConvTexture::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConvTexture::Tune(const TuningParameters& params) { +Status ConvTexture::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status ConvTexture::AddToQueue(CLCommandQueue* queue) { +Status ConvTexture::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvTexture* result) { +Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvTexture* result) { *result = ConvTexture(definition, attr); return result->UploadData(attr.weights, attr.bias, creation_context.context); } -absl::Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvTexture* result) { +Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvTexture* result) { *result = ConvTexture(definition); return result->UploadData(attr.weights, attr.bias, creation_context.context); } -absl::Status CreateConvTextureWino4x4To6x6( - const CreationContext& creation_context, const OperationDef& definition, - const Convolution2DAttributes& attr, ConvTexture* result) { +Status CreateConvTextureWino4x4To6x6(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvTexture* result) { *result = ConvTexture(definition); result->different_weights_for_height_ = true; result->block_size_ = {4, 1, 2}; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h index b7fbac91cf2..fb25f655057 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h @@ -41,10 +41,10 @@ namespace cl { class ConvTexture : public GPUOperation { public: ConvTexture() = 
default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvTexture(ConvTexture&& operation); @@ -53,16 +53,16 @@ class ConvTexture : public GPUOperation { ConvTexture& operator=(const ConvTexture&) = delete; private: - friend absl::Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvTexture* result); - friend absl::Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvTexture* result); + friend Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvTexture* result); + friend Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvTexture* result); - friend absl::Status CreateConvTextureWino4x4To6x6( + friend Status CreateConvTextureWino4x4To6x6( const CreationContext& creation_context, const OperationDef& definition, const Convolution2DAttributes& attr, ConvTexture* result); @@ -70,25 +70,25 @@ class ConvTexture : public GPUOperation { const Convolution2DAttributes& attr); explicit ConvTexture(const OperationDef& definition); template - absl::Status UploadData(const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, - CLContext* context); + Status UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); template - absl::Status UploadDataForWinograd4x4To6x6( + Status UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst_0, absl::Span dst_1, absl::Span dst_2, absl::Span dst_3); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Texture2D weights_0_; @@ -114,20 +114,20 @@ class ConvTexture : public GPUOperation { }; template -absl::Status ConvTexture::UploadData( - const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, CLContext* context) { +Status ConvTexture::UploadData(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context) { RETURN_IF_ERROR(UploadWeights(weights, context)); LinearStorageCreateInfo create_info; create_info.storage_type = LinearStorageType::TEXTURE_2D; create_info.data_type = definition_.GetDataType(); create_info.aligned_size = weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_)); - return absl::OkStatus(); + return OkStatus(); } template -absl::Status ConvTexture::UploadDataForWinograd4x4To6x6( +Status ConvTexture::UploadDataForWinograd4x4To6x6( const ::tflite::gpu::Tensor& weights, const CLDevice& device, CLContext* context) { ::tflite::gpu::Tensor wino_weights; @@ -145,8 +145,8 @@ absl::Status ConvTexture::UploadDataForWinograd4x4To6x6( } 
template -absl::Status ConvTexture::UploadWeights( - const ::tflite::gpu::Tensor& weights, CLContext* context) { +Status ConvTexture::UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context) { int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); dst_depth = AlignByN(dst_depth, block_size_.z); const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -246,19 +246,20 @@ void ConvTexture::RearrangeWeightsData( } } -absl::Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const Convolution2DAttributes& attr, - ConvTexture* result); +Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvTexture* result); -absl::Status CreateConvTexture(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - ConvTexture* result); +Status CreateConvTexture(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + ConvTexture* result); -absl::Status CreateConvTextureWino4x4To6x6( - const CreationContext& creation_context, const OperationDef& definition, - const Convolution2DAttributes& attr, ConvTexture* result); +Status CreateConvTextureWino4x4To6x6(const CreationContext& creation_context, + const OperationDef& definition, + const Convolution2DAttributes& attr, + ConvTexture* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc index e3170f068e9..947c39cd299 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc @@ -35,12 +35,12 @@ namespace { class OpenClConverterImpl : public TensorObjectConverter { public: - virtual absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) = 0; + virtual Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) = 0; protected: - absl::Status DispatchKernel(cl_mem input, cl_mem output) { + Status DispatchKernel(cl_mem input, cl_mem output) { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(input)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(output)); @@ -119,9 +119,9 @@ class FromTensorConverter : public OpenClConverterImpl { })"); } - absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { + Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { auto params_kernel = output_def.object_def.data_layout == DataLayout::BHWC ? 
GetToBhwcKernel(input_def, output_def) : GetToDhwc4Kernel(input_def, output_def); @@ -157,12 +157,11 @@ __kernel void from_tensor()" + environment->device(), &kernel_); } - absl::Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto output = absl::get_if(&output_obj); if (!output || !output->memobj) { - return absl::InvalidArgumentError( - "Missing output in from_tensor converter"); + return InvalidArgumentError("Missing output in from_tensor converter"); } auto input_texture = absl::get_if(&input_obj); if (input_texture && input_texture->memobj) { @@ -172,7 +171,7 @@ __kernel void from_tensor()" + if (input_buffer && input_buffer->memobj) { return DispatchKernel(input_buffer->memobj, output->memobj); } - return absl::InvalidArgumentError("Missing input in from_tensor converter"); + return InvalidArgumentError("Missing input in from_tensor converter"); } }; @@ -226,9 +225,9 @@ class ToTensorConverter : public OpenClConverterImpl { )"); } - absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { + Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { auto params_kernel = input_def.object_def.data_layout == DataLayout::BHWC ? GetFromBhwcKernel(input_def, output_def) : GetFromDhwc4Kernel(input_def, output_def); @@ -262,11 +261,11 @@ __kernel void to_tensor()" + &kernel_); } - absl::Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto input = absl::get_if(&input_obj); if (!input || !input->memobj) { - return absl::InvalidArgumentError("Missing input in to_tensor converter"); + return InvalidArgumentError("Missing input in to_tensor converter"); } auto output_texture = absl::get_if(&output_obj); if (output_texture && output_texture->memobj) { @@ -276,7 +275,7 @@ __kernel void to_tensor()" + if (output_buffer && output_buffer->memobj) { return DispatchKernel(input->memobj, output_buffer->memobj); } - return absl::InvalidArgumentError("Missing input in to_tensor converter"); + return InvalidArgumentError("Missing input in to_tensor converter"); } }; @@ -319,18 +318,18 @@ class TrivialCopier : public OpenClConverterImpl { input.data_layout == output.data_layout; } - absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { + Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { dims_ = input_def.dimensions; data_type_ = input_def.object_def.data_type; queue_ = environment->queue(); region_ = CalculateTextureRegion(output_def); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto texture_input = absl::get_if(&input_obj); auto texture_output = absl::get_if(&output_obj); if (texture_input && texture_output) { @@ -341,12 +340,12 @@ class TrivialCopier : public OpenClConverterImpl { if (buffer_input && buffer_output) { return Copy(*buffer_input, *buffer_output); } - return absl::InternalError("Unexpected object"); + return InternalError("Unexpected object"); } - absl::Status Copy(const 
OpenClBuffer& input, const OpenClBuffer& output) { + Status Copy(const OpenClBuffer& input, const OpenClBuffer& output) { if (input.memobj == output.memobj) { - return absl::OkStatus(); + return OkStatus(); } return GetOpenCLError(clEnqueueCopyBuffer( queue_->queue(), input.memobj, output.memobj, 0, 0, @@ -354,9 +353,9 @@ class TrivialCopier : public OpenClConverterImpl { nullptr)); } - absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) { + Status Copy(const OpenClTexture& input, const OpenClTexture& output) { if (input.memobj == output.memobj) { - return absl::OkStatus(); + return OkStatus(); } size_t origin[3] = {0, 0, 0}; return GetOpenCLError( @@ -381,18 +380,18 @@ class CpuCopier : public OpenClConverterImpl { IsOpenClTextureOrBuffer(input.object_type))); } - absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { + Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { region_ = CalculateTextureRegion( input_def.object_def.object_type == ObjectType::CPU_MEMORY ? output_def : input_def); queue_ = environment->queue(); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto cpu_input = absl::get_if(&input_obj); auto cpu_output = absl::get_if(&output_obj); if (cpu_input) { @@ -420,7 +419,7 @@ class CpuCopier : public OpenClConverterImpl { buffer_input->memobj, cpu_output->size_bytes, cpu_output->data); } } - return absl::InternalError("Unexpected object"); + return InternalError("Unexpected object"); } private: @@ -443,7 +442,7 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder { ToTensorConverter::IsSupported(input_def, output_def)); } - absl::Status MakeConverter( + Status MakeConverter( const TensorObjectDef& input, const TensorObjectDef& output, std::unique_ptr* converter) final { std::unique_ptr impl; @@ -458,11 +457,11 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder { } else if (ToTensorConverter::IsSupported(input_def, output_def)) { impl = absl::make_unique(); } else { - return absl::UnimplementedError("Unsupported conversion"); + return UnimplementedError("Unsupported conversion"); } RETURN_IF_ERROR(impl->Init(input, output, environment_)); *converter = std::move(impl); - return absl::OkStatus(); + return OkStatus(); } Environment* environment_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc index 417fb63e820..921a257aa7e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc @@ -368,8 +368,7 @@ ConvolutionTransposed& ConvolutionTransposed::operator=( return *this; } -absl::Status ConvolutionTransposed::Compile( - const CreationContext& creation_context) { +Status ConvolutionTransposed::Compile(const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, biases_, *creation_context.device, weights_are_buffer_, block_size_, linked_operations_); @@ -381,7 +380,7 @@ absl::Status ConvolutionTransposed::Compile( *creation_context.device, &kernel_); } -absl::Status ConvolutionTransposed::BindArguments() { +Status 
ConvolutionTransposed::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (weights_are_buffer_) { @@ -400,7 +399,7 @@ absl::Status ConvolutionTransposed::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 ConvolutionTransposed::GetGridSize() const { @@ -413,21 +412,21 @@ int3 ConvolutionTransposed::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConvolutionTransposed::Tune(const TuningParameters& params) { +Status ConvolutionTransposed::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status ConvolutionTransposed::AddToQueue(CLCommandQueue* queue) { +Status ConvolutionTransposed::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateConvolutionTransposed( - const CreationContext& creation_context, const OperationDef& definition, - const ConvolutionTransposedAttributes& attr, - ConvolutionTransposed* result) { +Status CreateConvolutionTransposed(const CreationContext& creation_context, + const OperationDef& definition, + const ConvolutionTransposedAttributes& attr, + ConvolutionTransposed* result) { *result = ConvolutionTransposed(definition, attr, *creation_context.device); RETURN_IF_ERROR( result->UploadWeights(attr.weights, creation_context.context)); @@ -439,7 +438,8 @@ absl::Status CreateConvolutionTransposed( create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h index 7545b9091e2..73fce020f5a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h @@ -38,10 +38,10 @@ namespace cl { class ConvolutionTransposed : public GPUOperation { public: ConvolutionTransposed() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed(ConvolutionTransposed&& operation); @@ -50,7 +50,7 @@ class ConvolutionTransposed : public GPUOperation { ConvolutionTransposed& operator=(const ConvolutionTransposed&) = delete; private: - friend absl::Status CreateConvolutionTransposed( + friend Status CreateConvolutionTransposed( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed* result); @@ -58,14 +58,14 @@ class ConvolutionTransposed : public GPUOperation { const ConvolutionTransposedAttributes& attr, const CLDevice& device); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const 
::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; LinearStorage biases_; @@ -88,7 +88,7 @@ class ConvolutionTransposed : public GPUOperation { }; template -absl::Status ConvolutionTransposed::UploadWeights( +Status ConvolutionTransposed::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size_.z); @@ -153,7 +153,7 @@ absl::Status ConvolutionTransposed::UploadWeights( } } - return absl::OkStatus(); + return OkStatus(); } template @@ -208,9 +208,10 @@ void ConvolutionTransposed::RearrangeWeightsData( } } -absl::Status CreateConvolutionTransposed( - const CreationContext& creation_context, const OperationDef& definition, - const ConvolutionTransposedAttributes& attr, ConvolutionTransposed* result); +Status CreateConvolutionTransposed(const CreationContext& creation_context, + const OperationDef& definition, + const ConvolutionTransposedAttributes& attr, + ConvolutionTransposed* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc index 9d3f0b2639c..147674b7eff 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc @@ -396,7 +396,7 @@ ConvolutionTransposed3D& ConvolutionTransposed3D::operator=( return *this; } -absl::Status ConvolutionTransposed3D::Compile( +Status ConvolutionTransposed3D::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposed3DCode( definition_, biases_, *creation_context.device, weights_are_buffer_, @@ -417,7 +417,7 @@ absl::Status ConvolutionTransposed3D::Compile( *creation_context.device, &kernel_); } -absl::Status ConvolutionTransposed3D::BindArguments() { +Status ConvolutionTransposed3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (weights_are_buffer_) { @@ -444,7 +444,7 @@ absl::Status ConvolutionTransposed3D::BindArguments() { IntegralDivideRoundUp(dst_[0]->Slices(), block_size_.w))); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHDS())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHDS())); - return absl::OkStatus(); + return OkStatus(); } int3 ConvolutionTransposed3D::GetGridSize() const { @@ -459,18 +459,18 @@ int3 ConvolutionTransposed3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConvolutionTransposed3D::Tune(const TuningParameters& params) { +Status ConvolutionTransposed3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroupConv(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status ConvolutionTransposed3D::AddToQueue(CLCommandQueue* queue) { +Status ConvolutionTransposed3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateConvolutionTransposed3D( +Status CreateConvolutionTransposed3D( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposed3DAttributes& attr, ConvolutionTransposed3D* result) { @@ -485,7 +485,8 @@ absl::Status CreateConvolutionTransposed3D( 
create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h index 763494efce6..c3fbd87a240 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h @@ -38,10 +38,10 @@ namespace cl { class ConvolutionTransposed3D : public GPUOperation { public: ConvolutionTransposed3D() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed3D(ConvolutionTransposed3D&& operation); @@ -50,7 +50,7 @@ class ConvolutionTransposed3D : public GPUOperation { ConvolutionTransposed3D& operator=(const ConvolutionTransposed3D&) = delete; private: - friend absl::Status CreateConvolutionTransposed3D( + friend Status CreateConvolutionTransposed3D( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposed3DAttributes& attr, ConvolutionTransposed3D* result); @@ -58,14 +58,14 @@ class ConvolutionTransposed3D : public GPUOperation { const ConvolutionTransposed3DAttributes& attr, const CLDevice& device); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; LinearStorage biases_; @@ -88,7 +88,7 @@ class ConvolutionTransposed3D : public GPUOperation { }; template -absl::Status ConvolutionTransposed3D::UploadWeights( +Status ConvolutionTransposed3D::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_depth = AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size_.z); @@ -155,7 +155,7 @@ absl::Status ConvolutionTransposed3D::UploadWeights( } } - return absl::OkStatus(); + return OkStatus(); } template @@ -214,7 +214,7 @@ void ConvolutionTransposed3D::RearrangeWeightsData( } } -absl::Status CreateConvolutionTransposed3D( +Status CreateConvolutionTransposed3D( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposed3DAttributes& attr, ConvolutionTransposed3D* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc index 4be593be57b..7b19ac0ba38 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc @@ -304,11 +304,12 @@ ConvolutionTransposed3x3& ConvolutionTransposed3x3::operator=( return *this; } -absl::Status ConvolutionTransposed3x3::Compile( +Status ConvolutionTransposed3x3::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, biases_, linked_operations_, 
weights_upload_type_, padding_, work_group_launch_order_); + std::vector options; if (definition_.precision == CalculationsPrecision::F16 && creation_context.device->IsPowerVR()) { @@ -317,10 +318,11 @@ absl::Status ConvolutionTransposed3x3::Compile( RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); - return absl::OkStatus(); + + return OkStatus(); } -absl::Status ConvolutionTransposed3x3::BindArguments() { +Status ConvolutionTransposed3x3::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -335,7 +337,10 @@ absl::Status ConvolutionTransposed3x3::BindArguments() { padding_.x >= 1 ? (padding_.x - 1) / 2 : (padding_.x - 2) / 2; const int padding_y = padding_.y >= 1 ? (padding_.y - 1) / 2 : (padding_.y - 2) / 2; - return kernel_.SetBytesAuto(int2(padding_x * src_[0]->Batch(), padding_y)); + RETURN_IF_ERROR( + kernel_.SetBytesAuto(int2(padding_x * src_[0]->Batch(), padding_y))); + + return OkStatus(); } int3 ConvolutionTransposed3x3::GetGridSize() const { @@ -353,7 +358,7 @@ int3 ConvolutionTransposed3x3::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConvolutionTransposed3x3::AddToQueue(CLCommandQueue* queue) { +Status ConvolutionTransposed3x3::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -365,13 +370,13 @@ bool IsConvolutionTransposed3x3Supported( attr.stride.w == 2 && attr.stride.h == 2; } -absl::Status CreateConvolutionTransposed3x3( +Status CreateConvolutionTransposed3x3( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3* result) { if (!IsConvolutionTransposed3x3Supported(*creation_context.device, definition, attr)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "ConvolutionTransposed3x3 doesn't support this attributes"); } const int2 padding = int2(attr.padding.prepended.w, attr.padding.prepended.h); @@ -386,7 +391,7 @@ absl::Status CreateConvolutionTransposed3x3( create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h index 5da112e19c0..9e12d884719 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h @@ -37,8 +37,8 @@ namespace cl { class ConvolutionTransposed3x3 : public GPUOperation { public: ConvolutionTransposed3x3() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed3x3(ConvolutionTransposed3x3&& operation); @@ -56,19 +56,19 @@ class ConvolutionTransposed3x3 : public GPUOperation { private: ConvolutionTransposed3x3(const OperationDef& definition, const CLDevice& device, int2 padding); - friend absl::Status CreateConvolutionTransposed3x3( + 
friend Status CreateConvolutionTransposed3x3( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3* result); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; int2 padding_; @@ -82,7 +82,7 @@ class ConvolutionTransposed3x3 : public GPUOperation { }; template -absl::Status ConvolutionTransposed3x3::UploadWeights( +Status ConvolutionTransposed3x3::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); @@ -165,7 +165,7 @@ bool IsConvolutionTransposed3x3Supported( const CLDevice& device, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); -absl::Status CreateConvolutionTransposed3x3( +Status CreateConvolutionTransposed3x3( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc index b8e4b25443e..40838d28eed 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc @@ -221,18 +221,19 @@ ConvolutionTransposed3x3Thin& ConvolutionTransposed3x3Thin::operator=( return *this; } -absl::Status ConvolutionTransposed3x3Thin::Compile( +Status ConvolutionTransposed3x3Thin::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, biases_, IntegralDivideRoundUp(src_channels_, 4), IntegralDivideRoundUp(dst_channels_, 4), *creation_context.device, linked_operations_); + return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -absl::Status ConvolutionTransposed3x3Thin::BindArguments() { +Status ConvolutionTransposed3x3Thin::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -241,7 +242,7 @@ absl::Status ConvolutionTransposed3x3Thin::BindArguments() { RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 ConvolutionTransposed3x3Thin::GetGridSize() const { @@ -251,13 +252,12 @@ int3 ConvolutionTransposed3x3Thin::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConvolutionTransposed3x3Thin::Tune( - const TuningParameters& params) { +Status ConvolutionTransposed3x3Thin::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status ConvolutionTransposed3x3Thin::AddToQueue(CLCommandQueue* queue) { +Status ConvolutionTransposed3x3Thin::AddToQueue(CLCommandQueue* queue) { 
RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -271,13 +271,13 @@ bool IsConvolutionTransposed3x3ThinSupported( attr.padding.appended.h == 1; } -absl::Status CreateConvolutionTransposed3x3Thin( +Status CreateConvolutionTransposed3x3Thin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3Thin* result) { if (!IsConvolutionTransposed3x3ThinSupported(*creation_context.device, attr)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "ConvolutionTransposed3x3Thin doesn't support this attributes"); } *result = ConvolutionTransposed3x3Thin(definition, attr); @@ -291,7 +291,8 @@ absl::Status CreateConvolutionTransposed3x3Thin( create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h index f2a0d586bd1..f8d10d6c6b8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h @@ -37,10 +37,10 @@ namespace cl { class ConvolutionTransposed3x3Thin : public GPUOperation { public: ConvolutionTransposed3x3Thin() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed3x3Thin(ConvolutionTransposed3x3Thin&& operation); @@ -51,7 +51,7 @@ class ConvolutionTransposed3x3Thin : public GPUOperation { delete; private: - friend absl::Status CreateConvolutionTransposed3x3Thin( + friend Status CreateConvolutionTransposed3x3Thin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3Thin* result); @@ -59,14 +59,14 @@ class ConvolutionTransposed3x3Thin : public GPUOperation { const OperationDef& definition, const ConvolutionTransposedAttributes& attr); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -80,7 +80,7 @@ class ConvolutionTransposed3x3Thin : public GPUOperation { }; template -absl::Status ConvolutionTransposed3x3Thin::UploadWeights( +Status ConvolutionTransposed3x3Thin::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(src_channels_, 4); const int dst_depth = IntegralDivideRoundUp(dst_channels_, 4); @@ -150,7 +150,7 @@ void ConvolutionTransposed3x3Thin::RearrangeWeightsData( bool IsConvolutionTransposed3x3ThinSupported( const CLDevice& device, const ConvolutionTransposedAttributes& attr); -absl::Status CreateConvolutionTransposed3x3Thin( +Status 
CreateConvolutionTransposed3x3Thin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed3x3Thin* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc index a558fe6cb3c..1e36be17778 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc @@ -301,7 +301,7 @@ ConvolutionTransposed4x4& ConvolutionTransposed4x4::operator=( return *this; } -absl::Status ConvolutionTransposed4x4::Compile( +Status ConvolutionTransposed4x4::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, biases_, linked_operations_, weights_upload_type_); @@ -314,10 +314,11 @@ absl::Status ConvolutionTransposed4x4::Compile( RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); - return absl::OkStatus(); + + return OkStatus(); } -absl::Status ConvolutionTransposed4x4::BindArguments() { +Status ConvolutionTransposed4x4::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -328,7 +329,8 @@ absl::Status ConvolutionTransposed4x4::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); const int32_t filters_offset = 4 * 16 * src_[0]->Slices(); RETURN_IF_ERROR(kernel_.SetBytesAuto(filters_offset)); - return absl::OkStatus(); + + return OkStatus(); } int3 ConvolutionTransposed4x4::GetGridSize() const { @@ -339,7 +341,7 @@ int3 ConvolutionTransposed4x4::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ConvolutionTransposed4x4::AddToQueue(CLCommandQueue* queue) { +Status ConvolutionTransposed4x4::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -352,13 +354,13 @@ bool IsConvolutionTransposed4x4Supported( attr.padding.prepended.w == 1 && attr.padding.prepended.h == 1; } -absl::Status CreateConvolutionTransposed4x4( +Status CreateConvolutionTransposed4x4( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed4x4* result) { if (!IsConvolutionTransposed4x4Supported(*creation_context.device, definition, attr)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "ConvolutionTransposed4x4 doesn't support this attributes"); } *result = ConvolutionTransposed4x4(definition, *creation_context.device); @@ -371,7 +373,7 @@ absl::Status CreateConvolutionTransposed4x4( create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h index 7bf37c56119..8d92542c908 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h @@ -37,8 +37,8 @@ namespace cl { class ConvolutionTransposed4x4 : public 
GPUOperation { public: ConvolutionTransposed4x4() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposed4x4(ConvolutionTransposed4x4&& operation); @@ -56,19 +56,19 @@ class ConvolutionTransposed4x4 : public GPUOperation { private: ConvolutionTransposed4x4(const OperationDef& definition, const CLDevice& device); - friend absl::Status CreateConvolutionTransposed4x4( + friend Status CreateConvolutionTransposed4x4( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed4x4* result); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Buffer weights_; @@ -80,7 +80,7 @@ class ConvolutionTransposed4x4 : public GPUOperation { }; template -absl::Status ConvolutionTransposed4x4::UploadWeights( +Status ConvolutionTransposed4x4::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); @@ -150,7 +150,7 @@ bool IsConvolutionTransposed4x4Supported( const CLDevice& device, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); -absl::Status CreateConvolutionTransposed4x4( +Status CreateConvolutionTransposed4x4( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposed4x4* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc index 8ea40bedd7d..03b9ab0eb6c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc @@ -184,7 +184,7 @@ ConvolutionTransposedThin& ConvolutionTransposedThin::operator=( return *this; } -absl::Status ConvolutionTransposedThin::Compile( +Status ConvolutionTransposedThin::Compile( const CreationContext& creation_context) { const auto code = GenerateConvolutionTransposedCode( definition_, IntegralDivideRoundUp(src_channels_, 4), dst_channels_, @@ -201,7 +201,7 @@ absl::Status ConvolutionTransposedThin::Compile( *creation_context.device, &kernel_); } -absl::Status ConvolutionTransposedThin::BindArguments() { +Status ConvolutionTransposedThin::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_buf_.GetMemoryPtr())); @@ -210,7 +210,7 @@ absl::Status ConvolutionTransposedThin::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(bias_value_)); - return absl::OkStatus(); + return OkStatus(); } int3 ConvolutionTransposedThin::GetGridSize() const { @@ -220,12 +220,12 @@ int3 ConvolutionTransposedThin::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status 
ConvolutionTransposedThin::Tune(const TuningParameters& params) { +Status ConvolutionTransposedThin::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status ConvolutionTransposedThin::AddToQueue(CLCommandQueue* queue) { +Status ConvolutionTransposedThin::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -238,18 +238,18 @@ bool IsConvolutionTransposedThinSupported( attr.padding.appended.w == 0 && attr.padding.appended.h == 0; } -absl::Status CreateConvolutionTransposedThin( +Status CreateConvolutionTransposedThin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposedThin* result) { if (!IsConvolutionTransposedThinSupported(*creation_context.device, attr)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "ConvolutionTransposedThin doesn't support this attributes"); } *result = ConvolutionTransposedThin(definition, attr); RETURN_IF_ERROR( result->UploadWeights(attr.weights, creation_context.context)); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h index 573772965ae..0642a7c928b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h @@ -38,10 +38,10 @@ namespace cl { class ConvolutionTransposedThin : public GPUOperation { public: ConvolutionTransposedThin() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only ConvolutionTransposedThin(ConvolutionTransposedThin&& operation); @@ -51,21 +51,21 @@ class ConvolutionTransposedThin : public GPUOperation { delete; private: - friend absl::Status CreateConvolutionTransposedThin( + friend Status CreateConvolutionTransposedThin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposedThin* result); ConvolutionTransposedThin(const OperationDef& definition, const ConvolutionTransposedAttributes& attr); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Buffer weights_buf_; @@ -80,7 +80,7 @@ class ConvolutionTransposedThin : public GPUOperation { }; template -absl::Status ConvolutionTransposedThin::UploadWeights( +Status ConvolutionTransposedThin::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(src_channels_, 4); const int elements_count = @@ -136,7 +136,7 @@ void ConvolutionTransposedThin::RearrangeWeightsData( bool IsConvolutionTransposedThinSupported( const CLDevice& device, const 
ConvolutionTransposedAttributes& attr); -absl::Status CreateConvolutionTransposedThin( +Status CreateConvolutionTransposedThin( const CreationContext& creation_context, const OperationDef& definition, const ConvolutionTransposedAttributes& attr, ConvolutionTransposedThin* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc index 99bec18c7f8..e7bf31b0d37 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc @@ -226,8 +226,7 @@ DepthWiseConvolution& DepthWiseConvolution::operator=( return *this; } -absl::Status DepthWiseConvolution::Compile( - const CreationContext& creation_context) { +Status DepthWiseConvolution::Compile(const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; const auto code = GenerateDepthWiseConvolutionCode( @@ -238,7 +237,7 @@ absl::Status DepthWiseConvolution::Compile( *creation_context.device, &kernel_); } -absl::Status DepthWiseConvolution::BindArguments() { +Status DepthWiseConvolution::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_)); @@ -256,7 +255,7 @@ absl::Status DepthWiseConvolution::BindArguments() { } RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 DepthWiseConvolution::GetGridSize() const { @@ -266,20 +265,20 @@ int3 DepthWiseConvolution::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status DepthWiseConvolution::Tune(const TuningParameters& params) { +Status DepthWiseConvolution::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status DepthWiseConvolution::AddToQueue(CLCommandQueue* queue) { +Status DepthWiseConvolution::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateDepthWiseConvolution( - const CreationContext& creation_context, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, - DepthWiseConvolution* result) { +Status CreateDepthWiseConvolution(const CreationContext& creation_context, + const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr, + DepthWiseConvolution* result) { bool weights_are_buffer = creation_context.device->IsMali(); *result = DepthWiseConvolution(definition, attr, weights_are_buffer); RETURN_IF_ERROR( @@ -292,7 +291,7 @@ absl::Status CreateDepthWiseConvolution( create_info.aligned_size = attr.weights.shape.o * attr.weights.shape.i; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h index 8f3320ae57b..5915ed94502 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h @@ -38,10 +38,10 @@ namespace cl { class DepthWiseConvolution : public GPUOperation { public: DepthWiseConvolution() = default; - absl::Status 
AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only DepthWiseConvolution(DepthWiseConvolution&& operation); @@ -50,7 +50,7 @@ class DepthWiseConvolution : public GPUOperation { DepthWiseConvolution& operator=(const DepthWiseConvolution&) = delete; private: - friend absl::Status CreateDepthWiseConvolution( + friend Status CreateDepthWiseConvolution( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr, DepthWiseConvolution* result); @@ -58,14 +58,14 @@ class DepthWiseConvolution : public GPUOperation { const DepthwiseConvolution2DAttributes& attr, bool weights_are_buffer); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; bool weights_are_buffer_; @@ -86,7 +86,7 @@ class DepthWiseConvolution : public GPUOperation { }; template -absl::Status DepthWiseConvolution::UploadWeights( +Status DepthWiseConvolution::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); @@ -130,7 +130,7 @@ absl::Status DepthWiseConvolution::UploadWeights( weights_ = weights_tex2d_.GetMemoryPtr(); } - return absl::OkStatus(); + return OkStatus(); } template @@ -162,9 +162,10 @@ void DepthWiseConvolution::RearrangeWeightsData( } } -absl::Status CreateDepthWiseConvolution( - const CreationContext& creation_context, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, DepthWiseConvolution* result); +Status CreateDepthWiseConvolution(const CreationContext& creation_context, + const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr, + DepthWiseConvolution* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc index 57d30dd2734..e3297cb6814 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc @@ -256,7 +256,7 @@ DepthWiseConvolution3D& DepthWiseConvolution3D::operator=( return *this; } -absl::Status DepthWiseConvolution3D::Compile( +Status DepthWiseConvolution3D::Compile( const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; @@ -268,7 +268,7 @@ absl::Status DepthWiseConvolution3D::Compile( *creation_context.device, &kernel_); } -absl::Status DepthWiseConvolution3D::BindArguments() { +Status DepthWiseConvolution3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (weights_are_buffer_) { @@ -295,7 +295,7 @@ absl::Status DepthWiseConvolution3D::BindArguments() { } RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS())); - return 
absl::OkStatus(); + return OkStatus(); } int3 DepthWiseConvolution3D::GetGridSize() const { @@ -305,17 +305,17 @@ int3 DepthWiseConvolution3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status DepthWiseConvolution3D::Tune(const TuningParameters& params) { +Status DepthWiseConvolution3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status DepthWiseConvolution3D::AddToQueue(CLCommandQueue* queue) { +Status DepthWiseConvolution3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateDepthWiseConvolution3D( +Status CreateDepthWiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, DepthWiseConvolution3D* result) { @@ -330,7 +330,7 @@ absl::Status CreateDepthWiseConvolution3D( create_info.aligned_size = attr.weights.shape.o * attr.weights.shape.i; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h index 78ca6862416..e3c565422af 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h @@ -38,10 +38,10 @@ namespace cl { class DepthWiseConvolution3D : public GPUOperation { public: DepthWiseConvolution3D() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only DepthWiseConvolution3D(DepthWiseConvolution3D&& operation); @@ -50,7 +50,7 @@ class DepthWiseConvolution3D : public GPUOperation { DepthWiseConvolution3D& operator=(const DepthWiseConvolution3D&) = delete; private: - friend absl::Status CreateDepthWiseConvolution3D( + friend Status CreateDepthWiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, DepthWiseConvolution3D* result); @@ -58,14 +58,14 @@ class DepthWiseConvolution3D : public GPUOperation { const DepthwiseConvolution3DAttributes& attr, const CLDevice& device); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeightsData(const ::tflite::gpu::Tensor& weights, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Texture2D weights_tex2d_; @@ -85,7 +85,7 @@ class DepthWiseConvolution3D : public GPUOperation { }; template -absl::Status DepthWiseConvolution3D::UploadWeights( +Status DepthWiseConvolution3D::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_slices = IntegralDivideRoundUp(dst_channels, 4); @@ -123,7 +123,7 @@ absl::Status DepthWiseConvolution3D::UploadWeights( gpu_data.data(), context, 
&weights_tex2d_)); } } - return absl::OkStatus(); + return OkStatus(); } template @@ -158,7 +158,7 @@ void DepthWiseConvolution3D::RearrangeWeightsData( } } -absl::Status CreateDepthWiseConvolution3D( +Status CreateDepthWiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, DepthWiseConvolution3D* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc index 3324adada3b..704df26f2ba 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc @@ -297,8 +297,7 @@ DepthWiseConv3x3& DepthWiseConv3x3::operator=(DepthWiseConv3x3&& operation) { return *this; } -absl::Status DepthWiseConv3x3::Compile( - const CreationContext& creation_context) { +Status DepthWiseConv3x3::Compile(const CreationContext& creation_context) { std::string code = GenerateDepthWiseConvCode( definition_, linked_operations_, *creation_context.device, weights_are_buffer_, local_mem_uploads_); @@ -312,14 +311,15 @@ absl::Status DepthWiseConv3x3::Compile( *creation_context.device, &kernel_); } -absl::Status DepthWiseConv3x3::BindArguments() { +Status DepthWiseConv3x3::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_)); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - return absl::OkStatus(); + + return OkStatus(); } int3 DepthWiseConv3x3::GetGridSize() const { @@ -329,15 +329,15 @@ int3 DepthWiseConv3x3::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status DepthWiseConv3x3::Tune(const TuningParameters& params) { +Status DepthWiseConv3x3::Tune(const TuningParameters& params) { if (local_mem_uploads_) { - return absl::OkStatus(); + return OkStatus(); } RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status DepthWiseConv3x3::AddToQueue(CLCommandQueue* queue) { +Status DepthWiseConv3x3::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -351,11 +351,12 @@ bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr) { attr.padding.appended.h == 1; } -absl::Status CreateDepthWiseConv3x3( - const CreationContext& creation_context, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result) { +Status CreateDepthWiseConv3x3(const CreationContext& creation_context, + const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr, + DepthWiseConv3x3* result) { if (!IsDepthWiseConv3x3Supported(attr)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "DepthWiseConv3x3 doesn't support this attributes"); } bool weights_are_buffer = @@ -363,8 +364,9 @@ absl::Status CreateDepthWiseConv3x3( bool local_mem_uploads = weights_are_buffer && creation_context.device->IsPowerVR(); *result = DepthWiseConv3x3(definition, weights_are_buffer, local_mem_uploads); - return result->UploadWeightsAndBiases(attr.weights, attr.bias, - creation_context.context); + RETURN_IF_ERROR(result->UploadWeightsAndBiases(attr.weights, attr.bias, + creation_context.context)); + 
return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h index 936ab773229..1630557afc9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h @@ -38,10 +38,10 @@ namespace cl { class DepthWiseConv3x3 : public GPUOperation { public: DepthWiseConv3x3() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only DepthWiseConv3x3(DepthWiseConv3x3&& operation); @@ -53,11 +53,11 @@ class DepthWiseConv3x3 : public GPUOperation { explicit DepthWiseConv3x3(const OperationDef& definition, bool weights_are_buffer, bool local_mem_uploads); template - absl::Status UploadWeightsAndBiases( - const ::tflite::gpu::Tensor& weights, - const ::tflite::gpu::Tensor& biases, CLContext* context); + Status UploadWeightsAndBiases(const ::tflite::gpu::Tensor& weights, + const ::tflite::gpu::Tensor& biases, + CLContext* context); - friend absl::Status CreateDepthWiseConv3x3( + friend Status CreateDepthWiseConv3x3( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result); @@ -66,7 +66,7 @@ class DepthWiseConv3x3 : public GPUOperation { const ::tflite::gpu::Tensor& weights, const ::tflite::gpu::Tensor& biases, absl::Span dst); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; bool weights_are_buffer_; @@ -80,7 +80,7 @@ class DepthWiseConv3x3 : public GPUOperation { }; template -absl::Status DepthWiseConv3x3::UploadWeightsAndBiases( +Status DepthWiseConv3x3::UploadWeightsAndBiases( const ::tflite::gpu::Tensor& weights, const ::tflite::gpu::Tensor& biases, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -122,7 +122,7 @@ absl::Status DepthWiseConv3x3::UploadWeightsAndBiases( weights_ = weights_tex2d_.GetMemoryPtr(); } - return absl::OkStatus(); + return OkStatus(); } template @@ -160,9 +160,10 @@ void DepthWiseConv3x3::RearrangeWeightsAndBiasesData( bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr); -absl::Status CreateDepthWiseConv3x3( - const CreationContext& creation_context, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result); +Status CreateDepthWiseConv3x3(const CreationContext& creation_context, + const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr, + DepthWiseConv3x3* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc index e435bccef03..7c394a45669 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc @@ -203,14 +203,14 @@ std::string ElementwiseTwoInput::GetArgsDeclaration() const { return args; } -absl::Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) { +Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) { if (use_scalar_para_) { RETURN_IF_ERROR(kernel->SetBytesAuto(scalar_para_)); } else { 
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr())); RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB())); } - return absl::OkStatus(); + return OkStatus(); } ElementwiseTwoInput CreateElementwiseTwoInput( diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h index 4c85fee6071..8bf33b0c128 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h @@ -75,7 +75,7 @@ class ElementwiseTwoInput : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - absl::Status BindArguments(CLKernel* kernel) override; + Status BindArguments(CLKernel* kernel) override; inline void SetScalarPara(FLT scalar) { scalar_para_ = scalar; use_scalar_para_ = true; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc index f93648f82fc..44a3e97554c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc @@ -113,7 +113,7 @@ FullyConnected& FullyConnected::operator=(FullyConnected&& kernel) { return *this; } -absl::Status FullyConnected::Compile(const CreationContext& creation_context) { +Status FullyConnected::Compile(const CreationContext& creation_context) { int wg_width = 32; int wg_height = 4; int work_items; @@ -134,10 +134,10 @@ absl::Status FullyConnected::Compile(const CreationContext& creation_context) { } work_items = work_group_size_.x * work_group_size_.y * work_group_size_.z; } while (work_items > kernel_.GetMaxWorkGroupSize()); - return absl::OkStatus(); + return OkStatus(); } -absl::Status FullyConnected::AddToQueue(CLCommandQueue* queue) { +Status FullyConnected::AddToQueue(CLCommandQueue* queue) { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr())); @@ -146,14 +146,15 @@ absl::Status FullyConnected::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR( kernel_.SetBytesAuto(int2(src_[0]->Slices(), dst_[0]->Slices()))); + return queue->DispatchImplicit(kernel_, {dst_[0]->Slices(), 1, 1}, work_group_size_); } -absl::Status CreateFullyConnected(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - FullyConnected* result) { +Status CreateFullyConnected(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + FullyConnected* result) { *result = FullyConnected(definition); RETURN_IF_ERROR( result->UploadWeights(attr.weights, creation_context.context)); @@ -164,7 +165,7 @@ absl::Status CreateFullyConnected(const CreationContext& creation_context, create_info.aligned_size = attr.weights.shape.o; RETURN_IF_ERROR(CreateLinearStorage( create_info, attr.bias, creation_context.context, &result->biases_)); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h index bc7cbd32fb0..83ac279a71b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h +++ 
b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h @@ -37,9 +37,9 @@ namespace cl { class FullyConnected : public GPUOperation { public: FullyConnected() = default; - absl::Status AddToQueue(CLCommandQueue* queue) override; + Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only FullyConnected(FullyConnected&& kernel); @@ -49,13 +49,14 @@ class FullyConnected : public GPUOperation { private: explicit FullyConnected(const OperationDef& definition); - friend absl::Status CreateFullyConnected( - const CreationContext& creation_context, const OperationDef& definition, - const FullyConnectedAttributes& attr, FullyConnected* result); + friend Status CreateFullyConnected(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + FullyConnected* result); template - absl::Status UploadWeights(const ::tflite::gpu::Tensor& weights, - CLContext* context); + Status UploadWeights(const ::tflite::gpu::Tensor& weights, + CLContext* context); template void RearrangeWeights(const ::tflite::gpu::Tensor& weights, @@ -68,7 +69,7 @@ class FullyConnected : public GPUOperation { }; template -absl::Status FullyConnected::UploadWeights( +Status FullyConnected::UploadWeights( const ::tflite::gpu::Tensor& weights, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); @@ -122,10 +123,10 @@ void FullyConnected::RearrangeWeights( } } -absl::Status CreateFullyConnected(const CreationContext& creation_context, - const OperationDef& definition, - const FullyConnectedAttributes& attr, - FullyConnected* result); +Status CreateFullyConnected(const CreationContext& creation_context, + const OperationDef& definition, + const FullyConnectedAttributes& attr, + FullyConnected* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc index 9f4c9871123..4972bb9f737 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc @@ -154,7 +154,7 @@ ElementwiseOperation& ElementwiseOperation::operator=( return *this; } -absl::Status ElementwiseOperation::BindArguments() { +Status ElementwiseOperation::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArguments(&kernel_)); @@ -162,7 +162,7 @@ absl::Status ElementwiseOperation::BindArguments() { RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 ElementwiseOperation::GetGridSize() const { @@ -172,20 +172,19 @@ int3 ElementwiseOperation::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status ElementwiseOperation::Compile( - const CreationContext& creation_context) { +Status ElementwiseOperation::Compile(const CreationContext& creation_context) { const auto code = GetElementWiseCode(definition_, *this, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -absl::Status 
ElementwiseOperation::AddToQueue(CLCommandQueue* queue) {
+Status ElementwiseOperation::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
 
-absl::Status ElementwiseOperation::Tune(const TuningParameters& params) {
+Status ElementwiseOperation::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
 
@@ -210,12 +209,12 @@ std::string PostProcess(const std::vector& linked_ops,
   return code;
 }
 
-absl::Status BindArgs(CLKernel* kernel,
-                      const std::vector& linked_ops) {
+Status BindArgs(CLKernel* kernel,
+                const std::vector& linked_ops) {
   for (auto linked_op : linked_ops) {
     RETURN_IF_ERROR(linked_op->BindArguments(kernel));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h
index 17817682bce..4507f0eb81d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h
@@ -96,15 +96,11 @@ class GPUOperation {
   void SetSrc(Tensor* ptr, int index = 0);
   void SetDst(Tensor* ptr, int index = 0);
 
-  virtual absl::Status AddToQueue(CLCommandQueue* queue) {
-    return absl::OkStatus();
-  }
-  virtual absl::Status Tune(const TuningParameters& params) {
-    return absl::OkStatus();
-  }
+  virtual Status AddToQueue(CLCommandQueue* queue) { return OkStatus(); }
+  virtual Status Tune(const TuningParameters& params) { return OkStatus(); }
 
-  virtual absl::Status Compile(const CreationContext& creation_context) {
-    return absl::OkStatus();
+  virtual Status Compile(const CreationContext& creation_context) {
+    return OkStatus();
   }
 
   const OperationDef& GetDefinition() const { return definition_; }
@@ -131,10 +127,10 @@ class ElementwiseOperation : public GPUOperation {
       : GPUOperation(definition) {}
   virtual ~ElementwiseOperation() {}
 
-  absl::Status AddToQueue(CLCommandQueue* queue) override;
-  absl::Status Tune(const TuningParameters& params) override;
+  Status AddToQueue(CLCommandQueue* queue) override;
+  Status Tune(const TuningParameters& params) override;
 
-  absl::Status Compile(const CreationContext& creation_context) override;
+  Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   ElementwiseOperation(ElementwiseOperation&& operation);
@@ -154,12 +150,10 @@ class ElementwiseOperation : public GPUOperation {
   virtual std::string GetCoreCode(const LinkingContext& context) const = 0;
   virtual std::string GetArgsDeclaration() const { return ""; }
-  virtual absl::Status BindArguments(CLKernel* kernel) {
-    return absl::OkStatus();
-  }
+  virtual Status BindArguments(CLKernel* kernel) { return OkStatus(); }
 
  protected:
-  absl::Status BindArguments();
+  Status BindArguments();
   int3 GetGridSize() const;
   CLKernel kernel_;
   int3 work_group_size_ = int3(8, 4, 1);
@@ -177,8 +171,8 @@ std::string PostProcess(const std::vector& linked_ops,
 // Binds arguments to given kernel for elementwise operations in
 // linked_ops.
 // Every ElementwiseOperation can bind her arguments.
-absl::Status BindArgs(CLKernel* kernel, - const std::vector& linked_ops); +Status BindArgs(CLKernel* kernel, + const std::vector& linked_ops); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc index 77eea07f278..f2e53a06908 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc @@ -121,14 +121,14 @@ LSTM& LSTM::operator=(LSTM&& kernel) { return *this; } -absl::Status LSTM::Compile(const CreationContext& creation_context) { +Status LSTM::Compile(const CreationContext& creation_context) { const auto code = GetLSTMCode(definition_, *creation_context.device); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -absl::Status LSTM::BindArguments() { +Status LSTM::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr())); @@ -137,7 +137,8 @@ absl::Status LSTM::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch())); - return absl::OkStatus(); + + return OkStatus(); } int3 LSTM::GetGridSize() const { @@ -147,12 +148,12 @@ int3 LSTM::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status LSTM::Tune(const TuningParameters& params) { +Status LSTM::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status LSTM::AddToQueue(CLCommandQueue* queue) { +Status LSTM::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h index 27b072ed001..3e84887cdc2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h @@ -28,9 +28,9 @@ namespace cl { class LSTM : public GPUOperation { public: explicit LSTM(const OperationDef& definition); - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; + Status Compile(const CreationContext& creation_context) override; // Move only LSTM(LSTM&& kernel); @@ -39,7 +39,7 @@ class LSTM : public GPUOperation { LSTM& operator=(const LSTM&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc index 56109fc713b..194daee5f1e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc @@ -218,7 +218,7 @@ MaxUnpooling& MaxUnpooling::operator=(MaxUnpooling&& kernel) { return *this; } -absl::Status MaxUnpooling::Compile(const CreationContext& creation_context) { +Status MaxUnpooling::Compile(const CreationContext& creation_context) { const auto code = GetMaxUnpoolingKernelCode( definition_, 
*creation_context.device, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -226,7 +226,7 @@ absl::Status MaxUnpooling::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status MaxUnpooling::BindArguments() { +Status MaxUnpooling::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr())); @@ -237,7 +237,8 @@ absl::Status MaxUnpooling::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_)); - return absl::OkStatus(); + + return OkStatus(); } int3 MaxUnpooling::GetGridSize() const { @@ -247,12 +248,12 @@ int3 MaxUnpooling::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status MaxUnpooling::Tune(const TuningParameters& params) { +Status MaxUnpooling::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status MaxUnpooling::AddToQueue(CLCommandQueue* queue) { +Status MaxUnpooling::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -290,7 +291,7 @@ MaxUnpooling3D& MaxUnpooling3D::operator=(MaxUnpooling3D&& kernel) { return *this; } -absl::Status MaxUnpooling3D::Compile(const CreationContext& creation_context) { +Status MaxUnpooling3D::Compile(const CreationContext& creation_context) { const auto code = GetMaxUnpooling3DKernelCode( definition_, *creation_context.device, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -298,7 +299,7 @@ absl::Status MaxUnpooling3D::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status MaxUnpooling3D::BindArguments() { +Status MaxUnpooling3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr())); @@ -315,7 +316,8 @@ absl::Status MaxUnpooling3D::BindArguments() { kernel_.SetBytesAuto(int4(padding_.x, padding_.y, padding_.z, 1))); RETURN_IF_ERROR( kernel_.SetBytesAuto(int4(stride_.x, stride_.y, stride_.z, 1))); - return absl::OkStatus(); + + return OkStatus(); } int3 MaxUnpooling3D::GetGridSize() const { @@ -325,12 +327,12 @@ int3 MaxUnpooling3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status MaxUnpooling3D::Tune(const TuningParameters& params) { +Status MaxUnpooling3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status MaxUnpooling3D::AddToQueue(CLCommandQueue* queue) { +Status MaxUnpooling3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h index 19184ee1e89..c7479acb728 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h @@ -29,10 +29,10 @@ class MaxUnpooling : public GPUOperation { public: MaxUnpooling(const OperationDef& definition, const MaxUnpooling2DAttributes& attr); - absl::Status 
AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only MaxUnpooling(MaxUnpooling&& kernel); @@ -41,7 +41,7 @@ class MaxUnpooling : public GPUOperation { MaxUnpooling& operator=(const MaxUnpooling&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; int2 stride_; @@ -59,10 +59,10 @@ class MaxUnpooling3D : public GPUOperation { public: MaxUnpooling3D(const OperationDef& definition, const MaxUnpooling3DAttributes& attr); - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only MaxUnpooling3D(MaxUnpooling3D&& kernel); @@ -71,7 +71,7 @@ class MaxUnpooling3D : public GPUOperation { MaxUnpooling3D& operator=(const MaxUnpooling3D&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; int3 stride_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc index f79a30e33dd..9dd0546c059 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc @@ -103,7 +103,7 @@ Mean& Mean::operator=(Mean&& operation) { return *this; } -absl::Status Mean::Compile(const CreationContext& creation_context) { +Status Mean::Compile(const CreationContext& creation_context) { if (creation_context.device->IsAdreno3xx()) { work_group_size_ = int3(16, 8, 1); } @@ -114,7 +114,7 @@ absl::Status Mean::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Mean::BindArguments() { +Status Mean::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -124,7 +124,7 @@ absl::Status Mean::BindArguments() { const double size_0 = work_group_size_.x * work_group_size_.y; const double size_1 = total_size / size_0; RETURN_IF_ERROR(kernel_.SetBytesAuto(float2(1.0 / size_1, 1.0 / size_0))); - return absl::OkStatus(); + return OkStatus(); } int3 Mean::GetGridSize() const { @@ -134,7 +134,7 @@ int3 Mean::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Mean::AddToQueue(CLCommandQueue* queue) { +Status Mean::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean.h b/tensorflow/lite/delegates/gpu/cl/kernels/mean.h index 4525551b5f2..0c0d3fff81c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean.h @@ -30,9 +30,9 @@ class Mean : public GPUOperation { public: Mean() = default; explicit Mean(const OperationDef& definition) : GPUOperation(definition) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; + Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Compile(const CreationContext& 
creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Mean(Mean&& operation); @@ -41,7 +41,7 @@ class Mean : public GPUOperation { Mean& operator=(const Mean&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc index fde0712a412..45f48246078 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc @@ -89,7 +89,7 @@ std::string MultiplyAdd::GetArgsDeclaration() const { return args; } -absl::Status MultiplyAdd::BindArguments(CLKernel* kernel) { +Status MultiplyAdd::BindArguments(CLKernel* kernel) { if (use_mul_vec_) { RETURN_IF_ERROR(kernel->SetMemoryAuto(mul_vec_.GetMemoryPtr())); } @@ -102,12 +102,12 @@ absl::Status MultiplyAdd::BindArguments(CLKernel* kernel) { if (scalar_add_.Active()) { RETURN_IF_ERROR(kernel->SetBytesAuto(scalar_add_)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr, - CalculationsPrecision scalar_precision, - CLContext* context) { +Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr, + CalculationsPrecision scalar_precision, + CLContext* context) { auto mul = absl::get_if<::tflite::gpu::Tensor>( &attr.param); auto mul_scalar = absl::get_if(&attr.param); @@ -116,12 +116,12 @@ absl::Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr, } else { scalar_mul_ = FLT(scalar_precision, *mul_scalar); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status MultiplyAdd::UploadAdd(const AddAttributes& attr, - CalculationsPrecision scalar_precision, - CLContext* context) { +Status MultiplyAdd::UploadAdd(const AddAttributes& attr, + CalculationsPrecision scalar_precision, + CLContext* context) { auto add = absl::get_if<::tflite::gpu::Tensor>( &attr.param); auto add_scalar = absl::get_if(&attr.param); @@ -130,13 +130,12 @@ absl::Status MultiplyAdd::UploadAdd(const AddAttributes& attr, } else { scalar_add_ = FLT(scalar_precision, *add_scalar); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& attr, - MultiplyAdd* result) { +Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& attr, MultiplyAdd* result) { const auto scalar_precision = creation_context.device->IsPowerVR() ? CalculationsPrecision::F32 : definition.precision; @@ -144,12 +143,12 @@ absl::Status CreateMultiplyAdd(const CreationContext& creation_context, RETURN_IF_ERROR( result->UploadMul(attr, scalar_precision, creation_context.context)); result->SetLinkIndex(0); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const AddAttributes& attr, MultiplyAdd* result) { +Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const AddAttributes& attr, MultiplyAdd* result) { const auto scalar_precision = creation_context.device->IsPowerVR() ? 
CalculationsPrecision::F32 : definition.precision; @@ -157,14 +156,13 @@ absl::Status CreateMultiplyAdd(const CreationContext& creation_context, RETURN_IF_ERROR( result->UploadAdd(attr, scalar_precision, creation_context.context)); result->SetLinkIndex(0); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& mul_attr, - const AddAttributes& add_attr, - MultiplyAdd* result) { +Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& mul_attr, + const AddAttributes& add_attr, MultiplyAdd* result) { const auto scalar_precision = creation_context.device->IsPowerVR() ? CalculationsPrecision::F32 : definition.precision; @@ -174,7 +172,7 @@ absl::Status CreateMultiplyAdd(const CreationContext& creation_context, RETURN_IF_ERROR( result->UploadAdd(add_attr, scalar_precision, creation_context.context)); result->SetLinkIndex(0); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h index 4047a7e5c1b..83bb6e11216 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h @@ -40,42 +40,40 @@ class MultiplyAdd : public ElementwiseOperation { MultiplyAdd(const MultiplyAdd&) = delete; MultiplyAdd& operator=(const MultiplyAdd&) = delete; - absl::Status UploadMul(const MultiplyAttributes& attr, - CalculationsPrecision scalar_precision, - CLContext* context); - absl::Status UploadAdd(const AddAttributes& attr, - CalculationsPrecision scalar_precision, - CLContext* context); + Status UploadMul(const MultiplyAttributes& attr, + CalculationsPrecision scalar_precision, CLContext* context); + Status UploadAdd(const AddAttributes& attr, + CalculationsPrecision scalar_precision, CLContext* context); template - absl::Status UploadMul(const ::tflite::gpu::Tensor& mul, - CLContext* context); + Status UploadMul(const ::tflite::gpu::Tensor& mul, + CLContext* context); template - absl::Status UploadAdd(const ::tflite::gpu::Tensor& add, - CLContext* context); + Status UploadAdd(const ::tflite::gpu::Tensor& add, + CLContext* context); void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - absl::Status BindArguments(CLKernel* kernel) override; + Status BindArguments(CLKernel* kernel) override; - friend absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& attr, - MultiplyAdd* result); + friend Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& attr, + MultiplyAdd* result); - friend absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const AddAttributes& attr, - MultiplyAdd* result); + friend Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const AddAttributes& attr, + MultiplyAdd* result); - friend absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& mul_attr, - const AddAttributes& add_attr, - MultiplyAdd* result); + friend Status CreateMultiplyAdd(const CreationContext& 
creation_context, + const OperationDef& definition, + const MultiplyAttributes& mul_attr, + const AddAttributes& add_attr, + MultiplyAdd* result); private: explicit MultiplyAdd(const OperationDef& definition) @@ -91,43 +89,41 @@ class MultiplyAdd : public ElementwiseOperation { FLT scalar_add_; }; -absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& attr, - MultiplyAdd* result); +Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& attr, MultiplyAdd* result); -absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const AddAttributes& attr, MultiplyAdd* result); +Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const AddAttributes& attr, MultiplyAdd* result); -absl::Status CreateMultiplyAdd(const CreationContext& creation_context, - const OperationDef& definition, - const MultiplyAttributes& mul_attr, - const AddAttributes& add_attr, - MultiplyAdd* result); +Status CreateMultiplyAdd(const CreationContext& creation_context, + const OperationDef& definition, + const MultiplyAttributes& mul_attr, + const AddAttributes& add_attr, MultiplyAdd* result); template -absl::Status MultiplyAdd::UploadMul(const ::tflite::gpu::Tensor& mul, - CLContext* context) { +Status MultiplyAdd::UploadMul(const ::tflite::gpu::Tensor& mul, + CLContext* context) { LinearStorageCreateInfo create_info; create_info.storage_type = DeduceLinearStorageType(definition_.GetPrimaryStorageType()); create_info.data_type = definition_.GetDataType(); RETURN_IF_ERROR(CreateLinearStorage(create_info, mul, context, &mul_vec_)); use_mul_vec_ = true; - return absl::OkStatus(); + return OkStatus(); } template -absl::Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor& add, - CLContext* context) { +Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor& add, + CLContext* context) { LinearStorageCreateInfo create_info; create_info.storage_type = DeduceLinearStorageType(definition_.GetPrimaryStorageType()); create_info.data_type = definition_.GetDataType(); RETURN_IF_ERROR(CreateLinearStorage(create_info, add, context, &add_vec_)); use_add_vec_ = true; - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc index 48edcb448a1..1443f5958db 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc @@ -169,7 +169,7 @@ Padding& Padding::operator=(Padding&& kernel) { return *this; } -absl::Status Padding::Compile(const CreationContext& creation_context) { +Status Padding::Compile(const CreationContext& creation_context) { const auto code = GetPaddingCode(definition_, linked_operations_, attributes_); return creation_context.cache->GetOrCreateCLKernel( @@ -177,7 +177,7 @@ absl::Status Padding::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Padding::BindArguments() { +Status Padding::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -187,7 +187,7 @@ absl::Status Padding::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); const auto& prep = attributes_.prepended; 
RETURN_IF_ERROR(kernel_.SetBytesAuto(int4(prep.w, prep.h, prep.c, prep.b))); - return absl::OkStatus(); + return OkStatus(); } int3 Padding::GetGridSize() const { @@ -197,12 +197,12 @@ int3 Padding::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Padding::Tune(const TuningParameters& params) { +Status Padding::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status Padding::AddToQueue(CLCommandQueue* queue) { +Status Padding::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.h b/tensorflow/lite/delegates/gpu/cl/kernels/padding.h index ddf9f9583be..38e78d4a461 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.h @@ -28,10 +28,10 @@ namespace cl { class Padding : public GPUOperation { public: Padding(const OperationDef& definition, const PadAttributes& attr); - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Padding(Padding&& kernel); @@ -40,7 +40,7 @@ class Padding : public GPUOperation { Padding& operator=(const Padding&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; PadAttributes attributes_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc index fb985461c02..17705782f93 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc @@ -408,7 +408,7 @@ Pooling& Pooling::operator=(Pooling&& kernel) { return *this; } -absl::Status Pooling::Compile(const CreationContext& creation_context) { +Status Pooling::Compile(const CreationContext& creation_context) { std::string code; const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; @@ -423,7 +423,7 @@ absl::Status Pooling::Compile(const CreationContext& creation_context) { linked_operations_, output_indices_); break; default: - return absl::InvalidArgumentError( + return InvalidArgumentError( "You should create another kernel with this params"); break; } @@ -432,7 +432,7 @@ absl::Status Pooling::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Pooling::BindArguments() { +Status Pooling::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -447,7 +447,7 @@ absl::Status Pooling::BindArguments() { kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y))); RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_)); - return absl::OkStatus(); + return OkStatus(); } int3 Pooling::GetGridSize() const { @@ -457,12 +457,12 @@ int3 Pooling::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Pooling::Tune(const TuningParameters& params) { +Status Pooling::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, 
kernel_, GetGridSize(), &work_group_size_); } -absl::Status Pooling::AddToQueue(CLCommandQueue* queue) { +Status Pooling::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } @@ -506,7 +506,7 @@ Pooling3D& Pooling3D::operator=(Pooling3D&& kernel) { return *this; } -absl::Status Pooling3D::Compile(const CreationContext& creation_context) { +Status Pooling3D::Compile(const CreationContext& creation_context) { std::string code; const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; @@ -521,7 +521,7 @@ absl::Status Pooling3D::Compile(const CreationContext& creation_context) { linked_operations_, output_indices_); break; default: - return absl::InvalidArgumentError( + return InvalidArgumentError( "You should create another kernel with this params"); break; } @@ -530,7 +530,7 @@ absl::Status Pooling3D::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Pooling3D::BindArguments() { +Status Pooling3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -550,7 +550,7 @@ absl::Status Pooling3D::BindArguments() { RETURN_IF_ERROR( kernel_.SetBytesAuto(int4(stride_.x, stride_.y, stride_.z, 1))); - return absl::OkStatus(); + return OkStatus(); } int3 Pooling3D::GetGridSize() const { @@ -560,12 +560,12 @@ int3 Pooling3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Pooling3D::Tune(const TuningParameters& params) { +Status Pooling3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status Pooling3D::AddToQueue(CLCommandQueue* queue) { +Status Pooling3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h index 09d2d5260f7..eaeb188f19e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h @@ -30,10 +30,10 @@ namespace cl { class Pooling : public GPUOperation { public: Pooling(const OperationDef& definition, const Pooling2DAttributes& attr); - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Pooling(Pooling&& kernel); @@ -42,7 +42,7 @@ class Pooling : public GPUOperation { Pooling& operator=(const Pooling&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; int2 stride_; @@ -62,10 +62,10 @@ Pooling CreatePooling(const OperationDef& definition, class Pooling3D : public GPUOperation { public: Pooling3D(const OperationDef& definition, const Pooling3DAttributes& attr); - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const 
CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Pooling3D(Pooling3D&& kernel); @@ -74,7 +74,7 @@ class Pooling3D : public GPUOperation { Pooling3D& operator=(const Pooling3D&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; int3 stride_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc index 1879d390ad6..8aa357b91b4 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc @@ -73,21 +73,21 @@ std::string PReLU::GetArgsDeclaration() const { return args; } -absl::Status PReLU::BindArguments(CLKernel* kernel) { +Status PReLU::BindArguments(CLKernel* kernel) { RETURN_IF_ERROR(kernel->SetMemoryAuto(alpha_.GetMemoryPtr())); if (clip_.Active()) { RETURN_IF_ERROR(kernel->SetBytesAuto(clip_)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreatePReLU(const CreationContext& creation_context, - const OperationDef& definition, - const PReLUAttributes& attr, PReLU* result) { +Status CreatePReLU(const CreationContext& creation_context, + const OperationDef& definition, const PReLUAttributes& attr, + PReLU* result) { auto alpha = absl::get_if<::tflite::gpu::Tensor>( &attr.alpha); if (!alpha) { - return absl::InvalidArgumentError("Alpha is missing"); + return InvalidArgumentError("Alpha is missing"); } const auto scalar_precision = creation_context.device->IsPowerVR() ? CalculationsPrecision::F32 @@ -95,7 +95,7 @@ absl::Status CreatePReLU(const CreationContext& creation_context, *result = PReLU(definition, attr, scalar_precision); RETURN_IF_ERROR(result->UploadParameters(*alpha, creation_context.context)); result->SetLinkIndex(0); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h index 4ba0a92158f..0feb387e644 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h @@ -44,30 +44,30 @@ class PReLU : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - absl::Status BindArguments(CLKernel* kernel) override; + Status BindArguments(CLKernel* kernel) override; - friend absl::Status CreatePReLU(const CreationContext& creation_context, - const OperationDef& definition, - const PReLUAttributes& attr, PReLU* result); + friend Status CreatePReLU(const CreationContext& creation_context, + const OperationDef& definition, + const PReLUAttributes& attr, PReLU* result); private: PReLU(const OperationDef& definition, const PReLUAttributes& attr, CalculationsPrecision scalar_precision); template - absl::Status UploadParameters( - const ::tflite::gpu::Tensor& parameters, CLContext* context); + Status UploadParameters(const ::tflite::gpu::Tensor& parameters, + CLContext* context); FLT clip_; LinearStorage alpha_; }; -absl::Status CreatePReLU(const CreationContext& creation_context, - const OperationDef& definition, - const PReLUAttributes& attr, PReLU* result); +Status CreatePReLU(const CreationContext& creation_context, + const OperationDef& definition, const PReLUAttributes& attr, + PReLU* result); template -absl::Status PReLU::UploadParameters( +Status PReLU::UploadParameters( const ::tflite::gpu::Tensor& parameters, 
CLContext* context) { LinearStorageCreateInfo create_info; create_info.storage_type = @@ -75,7 +75,7 @@ absl::Status PReLU::UploadParameters( create_info.data_type = definition_.GetPrimaryDataType(); RETURN_IF_ERROR( CreateLinearStorage(create_info, parameters, context, &alpha_)); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc index e0346a66ff9..f7751fac6ff 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc @@ -92,17 +92,17 @@ std::string QuantizeAndDequantize::GetArgsDeclaration() const { scale_.GetDeclaration()); } -absl::Status QuantizeAndDequantize::BindArguments(CLKernel* kernel) { +Status QuantizeAndDequantize::BindArguments(CLKernel* kernel) { RETURN_IF_ERROR(kernel->SetBytesAuto(min_)); RETURN_IF_ERROR(kernel->SetBytesAuto(max_)); RETURN_IF_ERROR(kernel->SetBytesAuto(scale_)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateQuantizeAndDequantize( - const CreationContext& creation_context, const OperationDef& definition, - const QuantizeAndDequantizeAttributes& attr, - QuantizeAndDequantize* result) { +Status CreateQuantizeAndDequantize(const CreationContext& creation_context, + const OperationDef& definition, + const QuantizeAndDequantizeAttributes& attr, + QuantizeAndDequantize* result) { const auto scalar_precision = creation_context.device->IsPowerVR() ? CalculationsPrecision::F32 : definition.precision; @@ -120,7 +120,7 @@ absl::Status CreateQuantizeAndDequantize( *result = QuantizeAndDequantize(definition, attr, scalar_precision); } result->SetLinkIndex(0); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h index 41c295e881d..07fa8f21773 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h @@ -57,9 +57,9 @@ class QuantizeAndDequantize : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - absl::Status BindArguments(CLKernel* kernel) override; + Status BindArguments(CLKernel* kernel) override; - friend absl::Status CreateQuantizeAndDequantize( + friend Status CreateQuantizeAndDequantize( const CreationContext& creation_context, const OperationDef& definition, const QuantizeAndDequantizeAttributes& attr, QuantizeAndDequantize* result); @@ -70,26 +70,27 @@ class QuantizeAndDequantize : public ElementwiseOperation { CalculationsPrecision scalar_precision); template - absl::Status UploadParameters( - const ::tflite::gpu::Tensor& parameters, CLContext* context); + Status UploadParameters(const ::tflite::gpu::Tensor& parameters, + CLContext* context); FLT min_; FLT max_; FLT scale_; }; -absl::Status CreateQuantizeAndDequantize( - const CreationContext& creation_context, const OperationDef& definition, - const QuantizeAndDequantizeAttributes& attr, QuantizeAndDequantize* result); +Status CreateQuantizeAndDequantize(const CreationContext& creation_context, + const OperationDef& definition, + const QuantizeAndDequantizeAttributes& attr, + QuantizeAndDequantize* result); template -absl::Status 
QuantizeAndDequantize::UploadParameters( +Status QuantizeAndDequantize::UploadParameters( const ::tflite::gpu::Tensor& parameters, CLContext* context) { LinearStorageCreateInfo create_info; create_info.storage_type = DeduceLinearStorageType(definition_.GetPrimaryStorageType()); create_info.data_type = definition_.GetPrimaryDataType(); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/relu.cc b/tensorflow/lite/delegates/gpu/cl/kernels/relu.cc index a96db2aa45e..ce903972c35 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/relu.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/relu.cc @@ -80,14 +80,14 @@ std::string ReLU::GetArgsDeclaration() const { return args; } -absl::Status ReLU::BindArguments(CLKernel* kernel) { +Status ReLU::BindArguments(CLKernel* kernel) { if (alpha_.Active()) { RETURN_IF_ERROR(kernel->SetBytesAuto(alpha_)); } if (clip_.Active()) { RETURN_IF_ERROR(kernel->SetBytesAuto(clip_)); } - return absl::OkStatus(); + return OkStatus(); } ReLU CreateReLU(const CreationContext& creation_context, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/relu.h b/tensorflow/lite/delegates/gpu/cl/kernels/relu.h index c8260a33faf..c4fb68588d3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/relu.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/relu.h @@ -37,7 +37,7 @@ class ReLU : public ElementwiseOperation { void SetLinkIndex(int index) override; std::string GetCoreCode(const LinkingContext& context) const override; std::string GetArgsDeclaration() const override; - absl::Status BindArguments(CLKernel* kernel) override; + Status BindArguments(CLKernel* kernel) override; friend ReLU CreateReLU(const CreationContext& creation_context, const OperationDef& definition, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc index e1589e9d682..3bb3cdd5d22 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc @@ -156,7 +156,7 @@ Reshape& Reshape::operator=(Reshape&& operation) { return *this; } -absl::Status Reshape::Compile(const CreationContext& creation_context) { +Status Reshape::Compile(const CreationContext& creation_context) { const auto code = definition_.IsBatchSupported() ? 
GetReshapeBatchedCode(definition_, linked_operations_) : GetReshapeCode(definition_, linked_operations_); @@ -165,7 +165,7 @@ absl::Status Reshape::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Reshape::BindArguments() { +Status Reshape::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -174,7 +174,8 @@ absl::Status Reshape::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Channels())); - return absl::OkStatus(); + + return OkStatus(); } int3 Reshape::GetGridSize() const { @@ -184,12 +185,12 @@ int3 Reshape::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Reshape::Tune(const TuningParameters& params) { +Status Reshape::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status Reshape::AddToQueue(CLCommandQueue* queue) { +Status Reshape::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h index e11c066ebd3..2117ef05907 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h @@ -29,10 +29,10 @@ class Reshape : public GPUOperation { public: explicit Reshape(const OperationDef& definition) : GPUOperation(definition), work_group_size_(8, 4, 1) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Reshape(Reshape&& operation); @@ -41,7 +41,7 @@ class Reshape : public GPUOperation { Reshape& operator=(const Reshape&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc index de6813e741f..3741a02aa5b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc @@ -120,7 +120,7 @@ Reshapex4& Reshapex4::operator=(Reshapex4&& operation) { return *this; } -absl::Status Reshapex4::Compile(const CreationContext& creation_context) { +Status Reshapex4::Compile(const CreationContext& creation_context) { const auto code = definition_.IsBatchSupported() ? 
GetReshapeBatchedCode(definition_, linked_operations_) : GetReshapeCode(definition_, linked_operations_); @@ -129,14 +129,15 @@ absl::Status Reshapex4::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Reshapex4::BindArguments() { +Status Reshapex4::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - return absl::OkStatus(); + + return OkStatus(); } int3 Reshapex4::GetGridSize() const { @@ -146,12 +147,12 @@ int3 Reshapex4::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Reshapex4::Tune(const TuningParameters& params) { +Status Reshapex4::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status Reshapex4::AddToQueue(CLCommandQueue* queue) { +Status Reshapex4::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h index d61224a7367..656e299b547 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h @@ -30,10 +30,10 @@ class Reshapex4 : public GPUOperation { public: explicit Reshapex4(const OperationDef& definition) : GPUOperation(definition), work_group_size_(8, 4, 1) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Reshapex4(Reshapex4&& operation); @@ -42,7 +42,7 @@ class Reshapex4 : public GPUOperation { Reshapex4& operator=(const Reshapex4&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc index 5d578fe6e09..bd109020004 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc @@ -209,7 +209,7 @@ Resize& Resize::operator=(Resize&& operation) { return *this; } -absl::Status Resize::Compile(const CreationContext& creation_context) { +Status Resize::Compile(const CreationContext& creation_context) { const auto code = GetResizeCode(definition_, attr_.type, attr_.half_pixel_centers, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -217,7 +217,7 @@ absl::Status Resize::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Resize::BindArguments() { +Status Resize::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -230,7 +230,7 @@ absl::Status Resize::BindArguments() { float2(CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_), 
CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor)); - return absl::OkStatus(); + return OkStatus(); } int3 Resize::GetGridSize() const { @@ -240,12 +240,12 @@ int3 Resize::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Resize::AddToQueue(CLCommandQueue* queue) { +Status Resize::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status Resize::Tune(const TuningParameters& params) { +Status Resize::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } @@ -271,7 +271,7 @@ Resize3D& Resize3D::operator=(Resize3D&& operation) { return *this; } -absl::Status Resize3D::Compile(const CreationContext& creation_context) { +Status Resize3D::Compile(const CreationContext& creation_context) { const auto code = GetResize3DCode(definition_, attr_.type, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -279,7 +279,7 @@ absl::Status Resize3D::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status Resize3D::BindArguments() { +Status Resize3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -296,7 +296,7 @@ absl::Status Resize3D::BindArguments() { CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_), CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_), 1.0f); RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor)); - return absl::OkStatus(); + return OkStatus(); } int3 Resize3D::GetGridSize() const { @@ -306,12 +306,12 @@ int3 Resize3D::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Resize3D::AddToQueue(CLCommandQueue* queue) { +Status Resize3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status Resize3D::Tune(const TuningParameters& params) { +Status Resize3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h index 04459e12ff9..a80f9a98382 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h @@ -27,10 +27,10 @@ namespace cl { class Resize : public GPUOperation { public: - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Resize(Resize&& operation); @@ -45,7 +45,7 @@ class Resize : public GPUOperation { Resize(const OperationDef& definition, const Resize2DAttributes& attr) : GPUOperation(definition), attr_(attr) {} - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Resize2DAttributes attr_; @@ -58,10 +58,10 @@ Resize CreateResize(const OperationDef& definition, class Resize3D : public GPUOperation { public: - absl::Status 
AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Resize3D(Resize3D&& operation); @@ -76,7 +76,7 @@ class Resize3D : public GPUOperation { Resize3D(const OperationDef& definition, const Resize3DAttributes& attr) : GPUOperation(definition), attr_(attr) {} - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; Resize3DAttributes attr_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc index 0f9fcb03097..350abf7f64e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc @@ -79,14 +79,14 @@ Softmax& Softmax::operator=(Softmax&& kernel) { return *this; } -absl::Status Softmax::Compile(const CreationContext& creation_context) { +Status Softmax::Compile(const CreationContext& creation_context) { const auto code = GetSoftmaxKernelCode(definition_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -absl::Status Softmax::BindArguments() { +Status Softmax::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -94,7 +94,7 @@ absl::Status Softmax::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR( kernel_.SetBytesAuto(GetMaskForLastPlane(src_[0]->Channels()))); - return absl::OkStatus(); + return OkStatus(); } int3 Softmax::GetGridSize() const { @@ -104,12 +104,12 @@ int3 Softmax::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Softmax::Tune(const TuningParameters& params) { +Status Softmax::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status Softmax::AddToQueue(CLCommandQueue* queue) { +Status Softmax::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h index 703a40a4e89..b8b7846e8de 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h @@ -30,10 +30,10 @@ class Softmax : public GPUOperation { public: Softmax() = default; explicit Softmax(const OperationDef& definition) : GPUOperation(definition) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Softmax(Softmax&& kernel); @@ -44,7 +44,7 @@ class Softmax : public GPUOperation { friend Softmax CreateSoftmax(); private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; CLKernel kernel_; int3 work_group_size_ = int3(8, 4, 1); 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc index 09e6c978026..168dc6ce4a9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc @@ -115,14 +115,14 @@ Softmax1x1& Softmax1x1::operator=(Softmax1x1&& kernel) { return *this; } -absl::Status Softmax1x1::Compile(const CreationContext& creation_context) { +Status Softmax1x1::Compile(const CreationContext& creation_context) { const auto code = GetSoftmaxKernelCode(definition_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -absl::Status Softmax1x1::AddToQueue(CLCommandQueue* queue) { +Status Softmax1x1::AddToQueue(CLCommandQueue* queue) { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h index 0d28145ca03..0fd5325a863 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h @@ -30,9 +30,9 @@ class Softmax1x1 : public GPUOperation { Softmax1x1() = default; explicit Softmax1x1(const OperationDef& definition) : GPUOperation(definition) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; + Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only Softmax1x1(Softmax1x1&& kernel); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc index b763684516a..db6882ce4f4 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc @@ -96,14 +96,14 @@ SpaceToDepth& SpaceToDepth::operator=(SpaceToDepth&& operation) { return *this; } -absl::Status SpaceToDepth::Compile(const CreationContext& creation_context) { +Status SpaceToDepth::Compile(const CreationContext& creation_context) { const auto code = GetSpaceToDepthCode(definition_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -absl::Status SpaceToDepth::BindArguments() { +Status SpaceToDepth::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -121,12 +121,12 @@ int3 SpaceToDepth::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status SpaceToDepth::Tune(const TuningParameters& params) { +Status SpaceToDepth::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status SpaceToDepth::AddToQueue(CLCommandQueue* queue) { +Status SpaceToDepth::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h index 9dd257a4c4d..3d316569fcb 100644 --- 
a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h @@ -30,9 +30,9 @@ class SpaceToDepth : public GPUOperation { public: SpaceToDepth(const OperationDef& op_def, const SpaceToDepthAttributes& attr) : GPUOperation(op_def), attr_(attr), work_group_size_(8, 4, 1) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; + Status Compile(const CreationContext& creation_context) override; SpaceToDepth(SpaceToDepth&& operation); SpaceToDepth& operator=(SpaceToDepth&& operation); @@ -40,7 +40,7 @@ class SpaceToDepth : public GPUOperation { SpaceToDepth& operator=(const SpaceToDepth&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; SpaceToDepthAttributes attr_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc index 19f1b185d3c..4f5cf9b26c7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc @@ -166,7 +166,7 @@ StridedSlice& StridedSlice::operator=(StridedSlice&& operation) { return *this; } -absl::Status StridedSlice::Compile(const CreationContext& creation_context) { +Status StridedSlice::Compile(const CreationContext& creation_context) { const auto code = GetStridedSliceCode(definition_, Is4Aligned(attributes_), linked_operations_); return creation_context.cache->GetOrCreateCLKernel( @@ -174,7 +174,7 @@ absl::Status StridedSlice::Compile(const CreationContext& creation_context) { *creation_context.device, &kernel_); } -absl::Status StridedSlice::BindArguments() { +Status StridedSlice::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -187,7 +187,7 @@ absl::Status StridedSlice::BindArguments() { attributes_.strides.c, attributes_.strides.b))); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - return absl::OkStatus(); + return OkStatus(); } int3 StridedSlice::GetGridSize() const { @@ -197,12 +197,12 @@ int3 StridedSlice::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status StridedSlice::Tune(const TuningParameters& params) { +Status StridedSlice::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status StridedSlice::AddToQueue(CLCommandQueue* queue) { +Status StridedSlice::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h index ee6f18fdacb..f30f6777134 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h @@ -27,10 +27,10 @@ namespace cl { class StridedSlice : public GPUOperation { public: StridedSlice(const OperationDef& definition, const SliceAttributes& attr); - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const 
TuningParameters& params) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status Compile(const CreationContext& creation_context) override; // Move only StridedSlice(StridedSlice&& operation); @@ -39,7 +39,7 @@ class StridedSlice : public GPUOperation { StridedSlice& operator=(const StridedSlice&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; SliceAttributes attributes_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc index 66a272fa2da..cab9b728866 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc @@ -125,14 +125,14 @@ Transpose& Transpose::operator=(Transpose&& operation) { return *this; } -absl::Status Transpose::Compile(const CreationContext& creation_context) { +Status Transpose::Compile(const CreationContext& creation_context) { const auto code = GetTransposeCode(definition_, attr_, linked_operations_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } -absl::Status Transpose::BindArguments() { +Status Transpose::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); @@ -141,7 +141,8 @@ absl::Status Transpose::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Channels())); - return absl::OkStatus(); + + return OkStatus(); } int3 Transpose::GetGridSize() const { @@ -151,12 +152,12 @@ int3 Transpose::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Transpose::Tune(const TuningParameters& params) { +Status Transpose::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status Transpose::AddToQueue(CLCommandQueue* queue) { +Status Transpose::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h index 61038b1e0ca..22c155a79ba 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h @@ -28,9 +28,9 @@ class Transpose : public GPUOperation { public: Transpose(const OperationDef& definition, const TransposeAttributes& attr) : GPUOperation(definition), attr_(attr), work_group_size_(8, 4, 1) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; + Status Compile(const CreationContext& creation_context) override; // Move only Transpose(Transpose&& operation); @@ -39,7 +39,7 @@ class Transpose : public GPUOperation { Transpose& operator=(const Transpose&) = delete; private: - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; 
TransposeAttributes attr_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc index 81a8fc690c4..9bb89874c3d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc @@ -381,7 +381,7 @@ Winograd4x4To36& Winograd4x4To36::operator=(Winograd4x4To36&& operation) { return *this; } -absl::Status Winograd4x4To36::Compile(const CreationContext& creation_context) { +Status Winograd4x4To36::Compile(const CreationContext& creation_context) { std::vector options; if (creation_context.device->IsAdreno()) { options.push_back(CompilerOptions::ADRENO_MORE_WAVES); @@ -397,10 +397,10 @@ absl::Status Winograd4x4To36::Compile(const CreationContext& creation_context) { code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); work_group_size_ = SelectBestWorkGroup(); - return absl::OkStatus(); + return OkStatus(); } -absl::Status Winograd4x4To36::UploadBt(CLContext* context) { +Status Winograd4x4To36::UploadBt(CLContext* context) { ::tflite::gpu::Tensor bt_aligned; bt_aligned.shape = Linear(6 * 8); bt_aligned.data.resize(6 * 8); @@ -427,7 +427,7 @@ int3 Winograd4x4To36::SelectBestWorkGroup() { return GetFirstSuitableWorkGroup(wgs, kernel_.GetMaxWorkGroupSize()); } -absl::Status Winograd4x4To36::BindArguments() { +Status Winograd4x4To36::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(bt_.GetMemoryPtr())); @@ -444,7 +444,8 @@ absl::Status Winograd4x4To36::BindArguments() { kernel_.SetBytesAuto(int2(-padding_.prepended.w, -padding_.prepended.h))); RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_total)); RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_x)); - return absl::OkStatus(); + + return OkStatus(); } int3 Winograd4x4To36::GetGridSize() const { @@ -454,7 +455,7 @@ int3 Winograd4x4To36::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Winograd4x4To36::Tune(const TuningParameters& params) { +Status Winograd4x4To36::Tune(const TuningParameters& params) { switch (params.tuning_type) { case TuningType::EXHAUSTIVE: RETURN_IF_ERROR(BindArguments()); @@ -463,19 +464,19 @@ absl::Status Winograd4x4To36::Tune(const TuningParameters& params) { case TuningType::FAST: default: work_group_size_ = SelectBestWorkGroup(); - return absl::OkStatus(); + return OkStatus(); } } -absl::Status Winograd4x4To36::AddToQueue(CLCommandQueue* queue) { +Status Winograd4x4To36::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateWinograd4x4To36(const CreationContext& creation_context, - const OperationDef& definition, - const Padding2D& padding, - Winograd4x4To36* result) { +Status CreateWinograd4x4To36(const CreationContext& creation_context, + const OperationDef& definition, + const Padding2D& padding, + Winograd4x4To36* result) { *result = Winograd4x4To36(definition, padding); return result->UploadBt(creation_context.context); } @@ -498,7 +499,7 @@ Winograd36To4x4& Winograd36To4x4::operator=(Winograd36To4x4&& operation) { return *this; } -absl::Status Winograd36To4x4::Compile(const CreationContext& creation_context) { +Status Winograd36To4x4::Compile(const CreationContext& creation_context) { std::vector options; if (definition_.precision == CalculationsPrecision::F16 && creation_context.device->IsPowerVR()) { @@ -510,10 +511,10 @@ 
absl::Status Winograd36To4x4::Compile(const CreationContext& creation_context) { code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); work_group_size_ = SelectBestWorkGroup(); - return absl::OkStatus(); + return OkStatus(); } -absl::Status Winograd36To4x4::UploadAt(CLContext* context) { +Status Winograd36To4x4::UploadAt(CLContext* context) { ::tflite::gpu::Tensor at_aligned; at_aligned.shape = Linear(4 * 8); at_aligned.data.resize(4 * 8); @@ -540,7 +541,7 @@ int3 Winograd36To4x4::SelectBestWorkGroup() { return GetFirstSuitableWorkGroup(wgs, kernel_.GetMaxWorkGroupSize()); } -absl::Status Winograd36To4x4::BindArguments() { +Status Winograd36To4x4::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(at_.GetMemoryPtr())); @@ -551,7 +552,8 @@ absl::Status Winograd36To4x4::BindArguments() { RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); const int tiles_x = IntegralDivideRoundUp(dst_[0]->Width(), 4); RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_x)); - return absl::OkStatus(); + + return OkStatus(); } int3 Winograd36To4x4::GetGridSize() const { @@ -563,7 +565,7 @@ int3 Winograd36To4x4::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } -absl::Status Winograd36To4x4::Tune(const TuningParameters& params) { +Status Winograd36To4x4::Tune(const TuningParameters& params) { switch (params.tuning_type) { case TuningType::EXHAUSTIVE: RETURN_IF_ERROR(BindArguments()); @@ -572,16 +574,16 @@ absl::Status Winograd36To4x4::Tune(const TuningParameters& params) { case TuningType::FAST: default: work_group_size_ = SelectBestWorkGroup(); - return absl::OkStatus(); + return OkStatus(); } } -absl::Status Winograd36To4x4::AddToQueue(CLCommandQueue* queue) { +Status Winograd36To4x4::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateWinograd36To4x4( +Status CreateWinograd36To4x4( const CreationContext& creation_context, const OperationDef& definition, const ::tflite::gpu::Tensor& biases, Winograd36To4x4* result) { @@ -592,6 +594,7 @@ absl::Status CreateWinograd36To4x4( create_info.name = "biases"; RETURN_IF_ERROR(CreateLinearStorage( create_info, biases, creation_context.context, &result->biases_)); + return result->UploadAt(creation_context.context); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h index 5a0444c4be5..f6b80b67f32 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h @@ -36,9 +36,9 @@ class Winograd4x4To36 : public GPUOperation { Winograd4x4To36() = default; Winograd4x4To36(const OperationDef& definition, const Padding2D& padding) : GPUOperation(definition), padding_(padding) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; + Status Compile(const CreationContext& creation_context) override; // Move only Winograd4x4To36(Winograd4x4To36&& operation); @@ -47,16 +47,17 @@ class Winograd4x4To36 : public GPUOperation { Winograd4x4To36& operator=(const Winograd4x4To36&) = delete; private: - friend absl::Status CreateWinograd4x4To36( - const 
CreationContext& creation_context, const OperationDef& definition, - const Padding2D& padding, Winograd4x4To36* result); + friend Status CreateWinograd4x4To36(const CreationContext& creation_context, + const OperationDef& definition, + const Padding2D& padding, + Winograd4x4To36* result); - absl::Status UploadBt(CLContext* context); + Status UploadBt(CLContext* context); // Must be called after kernel compilation int3 SelectBestWorkGroup(); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; LinearStorage bt_; @@ -66,19 +67,18 @@ class Winograd4x4To36 : public GPUOperation { int3 work_group_size_ = int3(128, 1, 1); }; -absl::Status CreateWinograd4x4To36(const CreationContext& creation_context, - const OperationDef& definition, - const Padding2D& padding, - Winograd4x4To36* result); +Status CreateWinograd4x4To36(const CreationContext& creation_context, + const OperationDef& definition, + const Padding2D& padding, Winograd4x4To36* result); class Winograd36To4x4 : public GPUOperation { public: Winograd36To4x4() = default; explicit Winograd36To4x4(const OperationDef& definition) : GPUOperation(definition) {} - absl::Status AddToQueue(CLCommandQueue* queue) override; - absl::Status Tune(const TuningParameters& params) override; - absl::Status Compile(const CreationContext& creation_context) override; + Status AddToQueue(CLCommandQueue* queue) override; + Status Tune(const TuningParameters& params) override; + Status Compile(const CreationContext& creation_context) override; // Move only Winograd36To4x4(Winograd36To4x4&& operation); @@ -87,17 +87,17 @@ class Winograd36To4x4 : public GPUOperation { Winograd36To4x4& operator=(const Winograd36To4x4&) = delete; private: - friend absl::Status CreateWinograd36To4x4( + friend Status CreateWinograd36To4x4( const CreationContext& creation_context, const OperationDef& definition, const ::tflite::gpu::Tensor& biases, Winograd36To4x4* result); - absl::Status UploadAt(CLContext* context); + Status UploadAt(CLContext* context); // Must be called after kernel compilation int3 SelectBestWorkGroup(); - absl::Status BindArguments(); + Status BindArguments(); int3 GetGridSize() const; LinearStorage at_; @@ -107,7 +107,7 @@ class Winograd36To4x4 : public GPUOperation { int3 work_group_size_ = int3(128, 1, 1); }; -absl::Status CreateWinograd36To4x4( +Status CreateWinograd36To4x4( const CreationContext& creation_context, const OperationDef& definition, const ::tflite::gpu::Tensor& biases, Winograd36To4x4* result); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc index 683116091b8..7a2e54840b9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc @@ -75,10 +75,9 @@ std::vector GenerateWorkGroupSizesXY128Linear( return work_groups; } -absl::Status GetBestWorkGroupAlignedToGrid(const TuningParameters& params, - const CLKernel& kernel, - const int3& grid, - int3* best_work_group) { +Status GetBestWorkGroupAlignedToGrid(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + int3* best_work_group) { std::vector work_groups; RETURN_IF_ERROR(GenerateWorkGroupSizesAlignedToGrid( grid, params.info->max_work_group_sizes, kernel.GetMaxWorkGroupSize(), @@ -87,7 +86,7 @@ absl::Status GetBestWorkGroupAlignedToGrid(const TuningParameters& params, RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex( kernel, *params.info, grid, work_groups, 
&best_work_group_index)); *best_work_group = work_groups[best_work_group_index]; - return absl::OkStatus(); + return OkStatus(); } int GetPenalty(int grid_size, int group_size) { @@ -203,31 +202,30 @@ int3 GetWorkGroupConv(const int3& grid, int max_size, int max_z_size) { return int3(wg_x, wg_y, wg_z); } -absl::Status GetBestWorkGroupXY128(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - WorkGroupSizeAlignment z_alignment, - int3* best_work_group) { +Status GetBestWorkGroupXY128(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + WorkGroupSizeAlignment z_alignment, + int3* best_work_group) { std::vector work_groups = GenerateWorkGroupSizesXY128( grid, kernel.GetMaxWorkGroupSize(), z_alignment); int best_work_group_index; RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex( kernel, *params.info, grid, work_groups, &best_work_group_index)); *best_work_group = work_groups[best_work_group_index]; - return absl::OkStatus(); + return OkStatus(); } -absl::Status GetBestWorkGroupXY128Linear(const TuningParameters& params, - const CLKernel& kernel, - const int3& grid, - WorkGroupSizeAlignment z_alignment, - int3* best_work_group) { +Status GetBestWorkGroupXY128Linear(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + WorkGroupSizeAlignment z_alignment, + int3* best_work_group) { std::vector work_groups = GenerateWorkGroupSizesXY128Linear( grid, kernel.GetMaxWorkGroupSize(), z_alignment); int best_work_group_index; RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex( kernel, *params.info, grid, work_groups, &best_work_group_index)); *best_work_group = work_groups[best_work_group_index]; - return absl::OkStatus(); + return OkStatus(); } bool XY128RequiresMoreWorkGroupsThenXY128Linear(int width, int height) { @@ -246,25 +244,24 @@ bool XY128RequiresMoreWorkGroupsThenXY128Linear(int width, int height) { return !have_equal_work_groups; } -absl::Status GetBestWorkGroup(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - int3* best_work_group) { +Status GetBestWorkGroup(const TuningParameters& params, const CLKernel& kernel, + const int3& grid, int3* best_work_group) { switch (params.tuning_type) { case TuningType::FAST: *best_work_group = GetWorkGroup(grid, kernel.GetMaxWorkGroupSize()); - return absl::OkStatus(); + return OkStatus(); case TuningType::EXHAUSTIVE: return GetBestWorkGroupAlignedToGrid(params, kernel, grid, best_work_group); default: *best_work_group = {8, 4, 1}; - return absl::OkStatus(); + return OkStatus(); } } -absl::Status GetBestWorkGroupConv(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - int3* best_work_group) { +Status GetBestWorkGroupConv(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + int3* best_work_group) { switch (params.tuning_type) { case TuningType::FAST: { int max_z_size = 16; @@ -274,14 +271,14 @@ absl::Status GetBestWorkGroupConv(const TuningParameters& params, max_z_size = std::min(max_z_size, params.info->max_work_group_sizes.z); *best_work_group = GetWorkGroupConv(grid, kernel.GetMaxWorkGroupSize(), max_z_size); - return absl::OkStatus(); + return OkStatus(); } case TuningType::EXHAUSTIVE: return GetBestWorkGroupAlignedToGrid(params, kernel, grid, best_work_group); default: *best_work_group = {8, 4, 1}; - return absl::OkStatus(); + return OkStatus(); } } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h index 
7cc60f4723f..4b9801e6009 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h @@ -31,17 +31,16 @@ namespace cl { // Here and later you can find XY128, this is because 128 is SIMD width of A6xx // And XY128 means that work_group_size.x * work_group_size.y % 128 = 0 // We need it to correctly work with constants uploading on A6xx -absl::Status GetBestWorkGroupXY128(const TuningParameters& params, +Status GetBestWorkGroupXY128(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + WorkGroupSizeAlignment z_alignment, + int3* best_work_group); + +Status GetBestWorkGroupXY128Linear(const TuningParameters& params, const CLKernel& kernel, const int3& grid, WorkGroupSizeAlignment z_alignment, int3* best_work_group); -absl::Status GetBestWorkGroupXY128Linear(const TuningParameters& params, - const CLKernel& kernel, - const int3& grid, - WorkGroupSizeAlignment z_alignment, - int3* best_work_group); - int3 GetWorkGroupXY128ConvLinear(const int3& grid); int3 GetWorkGroupXY128Simple(const int3& grid); @@ -49,13 +48,12 @@ int3 GetWorkGroupXY128Conv(const int3& grid); bool XY128RequiresMoreWorkGroupsThenXY128Linear(int width, int height); -absl::Status GetBestWorkGroup(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - int3* best_work_group); +Status GetBestWorkGroup(const TuningParameters& params, const CLKernel& kernel, + const int3& grid, int3* best_work_group); -absl::Status GetBestWorkGroupConv(const TuningParameters& params, - const CLKernel& kernel, const int3& grid, - int3* best_work_group); +Status GetBestWorkGroupConv(const TuningParameters& params, + const CLKernel& kernel, const int3& grid, + int3* best_work_group); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc index 4fb21d0ec6a..cd7fe729c7d 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" - namespace tflite { namespace gpu { namespace cl { @@ -75,31 +73,29 @@ LinearStorageType DeduceLinearStorageType( } } -absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data, - CLContext* context, - LinearStorage* result) { +Status CreateBufferLinearStorage(int size, DataType data_type, void* data, + CLContext* context, LinearStorage* result) { const int float4_size = data_type == DataType::FLOAT32 ? 
sizeof(float4) : sizeof(half4); *result = LinearStorage(size, LinearStorageType::BUFFER, data_type); RETURN_IF_ERROR(CreateReadOnlyBuffer(float4_size * size, data, context, &result->buffer_storage_)); result->memory_ = result->buffer_storage_.GetMemoryPtr(); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateTextureLinearStorage(int size, DataType data_type, - void* data, CLContext* context, - LinearStorage* result) { +Status CreateTextureLinearStorage(int size, DataType data_type, void* data, + CLContext* context, LinearStorage* result) { *result = LinearStorage(size, LinearStorageType::TEXTURE_2D, data_type); RETURN_IF_ERROR(CreateTexture2DRGBA(data_type, size, 1, data, context, &result->texture_storage_)); result->memory_ = result->texture_storage_.GetMemoryPtr(); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, - int size, void* data, CLContext* context, - LinearStorage* result) { +Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, + int size, void* data, CLContext* context, + LinearStorage* result) { if (creation_info.storage_type == LinearStorageType::BUFFER) { return CreateBufferLinearStorage(size, creation_info.data_type, data, context, result); diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.h b/tensorflow/lite/delegates/gpu/cl/linear_storage.h index 93aecd57854..3d3d9d5222f 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.h +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.h @@ -64,12 +64,12 @@ class LinearStorage { std::string GetDeclaration() const; private: - friend absl::Status CreateTextureLinearStorage(int size, DataType data_type, - void* data, CLContext* context, - LinearStorage* result); - friend absl::Status CreateBufferLinearStorage(int size, DataType data_type, - void* data, CLContext* context, - LinearStorage* result); + friend Status CreateTextureLinearStorage(int size, DataType data_type, + void* data, CLContext* context, + LinearStorage* result); + friend Status CreateBufferLinearStorage(int size, DataType data_type, + void* data, CLContext* context, + LinearStorage* result); LinearStorage(int depth, LinearStorageType storage_type, DataType data_type); @@ -83,22 +83,20 @@ class LinearStorage { DataType data_type_; }; -absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data, - CLContext* context, - LinearStorage* result); +Status CreateBufferLinearStorage(int size, DataType data_type, void* data, + CLContext* context, LinearStorage* result); -absl::Status CreateTextureLinearStorage(int size, DataType data_type, - void* data, CLContext* context, - LinearStorage* result); +Status CreateTextureLinearStorage(int size, DataType data_type, void* data, + CLContext* context, LinearStorage* result); -absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, - int size, void* data, CLContext* context, - LinearStorage* result); +Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, + int size, void* data, CLContext* context, + LinearStorage* result); template -absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, - const ::tflite::gpu::Tensor& tensor, - CLContext* context, LinearStorage* result) { +Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info, + const ::tflite::gpu::Tensor& tensor, + CLContext* context, LinearStorage* result) { int size = creation_info.aligned_size != 0 ? 
creation_info.aligned_size : tensor.shape.v;
   const int depth = IntegralDivideRoundUp(size, 4);
@@ -114,7 +112,7 @@ absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info,
                                        context, result));
   }
   result->SetName(creation_info.name);
-  return absl::OkStatus();
+  return OkStatus();
 }
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc
index be551bc9973..3b471ce816c 100644
--- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc
+++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc
@@ -31,11 +31,11 @@ namespace cl {
     function = reinterpret_cast(dlsym(libopencl, #function)); \
   }
 
-absl::Status LoadOpenCL() {
+Status LoadOpenCL() {
   void* libopencl = dlopen("libOpenCL.so", RTLD_NOW | RTLD_LOCAL);
   if (libopencl) {
     LoadOpenCLFunctions(libopencl, false);
-    return absl::OkStatus();
+    return OkStatus();
   } else {
     // Pixel phone?
     libopencl = dlopen("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL);
@@ -45,9 +45,9 @@ absl::Status LoadOpenCL() {
           reinterpret_cast(dlsym(libopencl, "enableOpenCL"));
       enableOpenCL();
       LoadOpenCLFunctions(libopencl, true);
-      return absl::OkStatus();
+      return OkStatus();
     } else {
-      return absl::UnknownError(
+      return UnknownError(
           absl::StrCat("OpenCL library not loaded - ", dlerror()));
     }
   }
diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h
index 2201b4c1e5d..16ae24437a3 100644
--- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h
+++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h
@@ -27,7 +27,7 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 
-absl::Status LoadOpenCL();
+Status LoadOpenCL();
 void LoadOpenCLFunctions(void *libopencl, bool is_pixel);
 
 typedef cl_int(CL_API_CALL *PFN_clGetPlatformIDs)(
diff --git a/tensorflow/lite/delegates/gpu/cl/program_cache.cc b/tensorflow/lite/delegates/gpu/cl/program_cache.cc
index 285aa06d99b..e6735b448de 100644
--- a/tensorflow/lite/delegates/gpu/cl/program_cache.cc
+++ b/tensorflow/lite/delegates/gpu/cl/program_cache.cc
@@ -56,7 +56,7 @@ ProgramCache& ProgramCache::operator=(ProgramCache&& program_cache) {
   return *this;
 }
 
-absl::Status ProgramCache::GetOrCreateCLKernel(
+Status ProgramCache::GetOrCreateCLKernel(
     const std::string& code, const std::string& function_name,
     const std::vector& compiler_options, const CLContext& context,
     const CLDevice& device, CLKernel* result) {
@@ -64,31 +64,32 @@ absl::Status ProgramCache::GetOrCreateCLKernel(
   ProgramDescriptor desc{code, options, use_fingerprints_};
   auto it = programs_.find(desc);
   if (it != programs_.end()) {
-    return result->CreateFromProgram(it->second, function_name);
+    RETURN_IF_ERROR(result->CreateFromProgram(it->second, function_name));
+    return OkStatus();
   }
   CLProgram program;
   RETURN_IF_ERROR(CreateCLProgram(code, options, context, device, &program));
   RETURN_IF_ERROR(result->CreateFromProgram(program, function_name));
   programs_.insert(std::make_pair(std::move(desc), std::move(program)));
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status ProgramCache::GetOrCreateCLKernel(const std::string& code,
-                                               const std::string& function_name,
-                                               const CLContext& context,
-                                               const CLDevice& device,
-                                               CLKernel* result) {
+Status ProgramCache::GetOrCreateCLKernel(const std::string& code,
+                                         const std::string& function_name,
+                                         const CLContext& context,
+                                         const CLDevice& device,
+                                         CLKernel* result) {
   return GetOrCreateCLKernel(code, function_name, {}, context, device, result);
 }
 
-absl::Status ProgramCache::AddSerializedCache(
+Status ProgramCache::AddSerializedCache(
     const CLContext& context, const CLDevice& device,
     absl::Span serialized_cache) {
   flatbuffers::Verifier verifier(serialized_cache.data(),
                                  serialized_cache.size());
   if (!data::VerifyCompiledCacheBuffer(verifier)) {
-    return absl::InvalidArgumentError("Serialized model is corrupted.");
+    return InvalidArgumentError("Serialized model is corrupted.");
   }
 
   auto model = data::GetCompiledCache(serialized_cache.data());
@@ -96,7 +97,7 @@ absl::Status ProgramCache::AddSerializedCache(
                                        model->driver_version()->size());
   if (device.GetPlatformVersion() != platform_version) {
-    return absl::InvalidArgumentError(
+    return InvalidArgumentError(
         "OpenCL driver changed, cache invalid, should be regenerated");
   }
 
@@ -115,10 +116,10 @@ absl::Status ProgramCache::AddSerializedCache(
       programs_.insert(std::make_pair(std::move(desc), std::move(program)));
     }
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status ProgramCache::GetSerializedCache(
+Status ProgramCache::GetSerializedCache(
     const CLDevice& device, std::vector* serialized_cache) const {
   ::flatbuffers::FlatBufferBuilder builder;
   std::vector> serialized_programs;
@@ -139,9 +140,9 @@ absl::Status ProgramCache::GetSerializedCache(
   data::FinishCompiledCacheBuffer(builder, cache_builder.Finish());
   size_t next_element = serialized_cache->size();
   serialized_cache->resize(serialized_cache->size() + builder.GetSize());
-  std::memcpy(&(*serialized_cache)[next_element], builder.GetBufferPointer(),
-              builder.GetSize());
-  return absl::OkStatus();
+  memcpy(&(*serialized_cache)[next_element], builder.GetBufferPointer(),
+         builder.GetSize());
+  return OkStatus();
 }
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/program_cache.h b/tensorflow/lite/delegates/gpu/cl/program_cache.h
index 21f9583a59a..b8d019d3d47 100644
--- a/tensorflow/lite/delegates/gpu/cl/program_cache.h
+++ b/tensorflow/lite/delegates/gpu/cl/program_cache.h
@@ -41,21 +41,20 @@ class ProgramCache {
   ProgramCache(const ProgramCache&) = delete;
   ProgramCache& operator=(const ProgramCache&) = delete;
 
-  absl::Status GetOrCreateCLKernel(
+  Status GetOrCreateCLKernel(
       const std::string& code, const std::string& function_name,
       const std::vector& compiler_options, const CLContext& context,
       const CLDevice& device, CLKernel* result);
 
-  absl::Status GetOrCreateCLKernel(const std::string& code,
-                                   const std::string& function_name,
-                                   const CLContext& context,
-                                   const CLDevice& device, CLKernel* result);
+  Status GetOrCreateCLKernel(const std::string& code,
+                             const std::string& function_name,
+                             const CLContext& context, const CLDevice& device,
+                             CLKernel* result);
 
-  absl::Status AddSerializedCache(const CLContext& context,
-                                  const CLDevice& device,
-                                  absl::Span serialized_cache);
-  absl::Status GetSerializedCache(const CLDevice& device,
-                                  std::vector* serialized_cache) const;
+  Status AddSerializedCache(const CLContext& context, const CLDevice& device,
+                            absl::Span serialized_cache);
+  Status GetSerializedCache(const CLDevice& device,
+                            std::vector* serialized_cache) const;
 
  private:
   struct ProgramDescriptor {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
index d2d775f819f..a420373f50a 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
@@ -29,12 +29,11 @@ namespace gpu {
 namespace cl {
 namespace {
 
-absl::Status SelectConvolutionAdreno(const Convolution2DAttributes& attr,
-                                     const BHWC& dst_shape,
-                                     const CreationContext& creation_context,
-                                     const OperationDef& op_def,
-                                     ModelHints hints,
-                                     std::unique_ptr* ptr) {
+Status SelectConvolutionAdreno(const Convolution2DAttributes& attr,
+                               const BHWC& dst_shape,
+                               const CreationContext& creation_context,
+                               const OperationDef& op_def, ModelHints hints,
+                               std::unique_ptr* ptr) {
   if (IsConvConstantsSupported(*creation_context.device, op_def, attr)) {
     ConvConstants conv;
     RETURN_IF_ERROR(CreateConvConstants(creation_context, op_def, attr, &conv));
@@ -44,24 +43,28 @@ absl::Status SelectConvolutionAdreno(const Convolution2DAttributes& attr,
     RETURN_IF_ERROR(CreateConvTexture(creation_context, op_def, attr, &conv));
     *ptr = absl::make_unique(std::move(conv));
   }
-  return absl::OkStatus();
+
+  return OkStatus();
 }
 
-absl::Status SelectConvolutionWinogradAdreno(
-    const Convolution2DAttributes& attr, const BHWC& dst_shape,
-    const CreationContext& creation_context, const OperationDef& op_def,
-    ModelHints hints, std::unique_ptr* ptr) {
+Status SelectConvolutionWinogradAdreno(const Convolution2DAttributes& attr,
+                                       const BHWC& dst_shape,
+                                       const CreationContext& creation_context,
+                                       const OperationDef& op_def,
+                                       ModelHints hints,
+                                       std::unique_ptr* ptr) {
   ConvTexture conv;
   RETURN_IF_ERROR(
       CreateConvTextureWino4x4To6x6(creation_context, op_def, attr, &conv));
   *ptr = absl::make_unique(std::move(conv));
-  return absl::OkStatus();
+
+  return OkStatus();
 }
 
-absl::Status SelectConvolutionNVidia(const Convolution2DAttributes& attr,
-                                     const CreationContext& creation_context,
-                                     const OperationDef& op_def,
-                                     std::unique_ptr* ptr) {
+Status SelectConvolutionNVidia(const Convolution2DAttributes& attr,
+                               const CreationContext& creation_context,
+                               const OperationDef& op_def,
+                               std::unique_ptr* ptr) {
   if (IsConvConstantsSupported(*creation_context.device, op_def, attr)) {
     ConvConstants conv;
     RETURN_IF_ERROR(CreateConvConstants(creation_context, op_def, attr, &conv));
@@ -71,24 +74,24 @@ absl::Status SelectConvolutionNVidia(const Convolution2DAttributes& attr,
     RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv));
     *ptr = absl::make_unique(std::move(conv));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status SelectConvolutionPowerVR(const Convolution2DAttributes& attr,
-                                      const CreationContext& creation_context,
-                                      const OperationDef& op_def,
-                                      std::unique_ptr* ptr) {
+Status SelectConvolutionPowerVR(const Convolution2DAttributes& attr,
+                                const CreationContext& creation_context,
+                                const OperationDef& op_def,
+                                std::unique_ptr* ptr) {
   ConvPowerVR conv;
   RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv));
   *ptr = absl::make_unique(std::move(conv));
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status SelectConvolutionMali(const Convolution2DAttributes& attr,
-                                   const BHWC& dst_shape,
-                                   const CreationContext& creation_context,
-                                   const OperationDef& op_def,
-                                   std::unique_ptr* ptr) {
+Status SelectConvolutionMali(const Convolution2DAttributes& attr,
+                             const BHWC& dst_shape,
+                             const CreationContext& creation_context,
+                             const OperationDef& op_def,
+                             std::unique_ptr* ptr) {
   if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER &&
       IsConvBuffer1x1Supported(op_def, attr)) {
     ConvBuffer1x1 conv;
@@ -101,13 +104,14 @@ absl::Status SelectConvolutionMali(const Convolution2DAttributes& attr,
         CreateConvPowerVR(creation_context, op_def, attr, &conv, &dst_shape));
     *ptr = absl::make_unique(std::move(conv));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status SelectConvolutionWinogradMali(
-    const Convolution2DAttributes& attr, const BHWC& dst_shape,
-    const CreationContext& creation_context, const OperationDef& op_def,
-    std::unique_ptr* ptr) {
+Status SelectConvolutionWinogradMali(const Convolution2DAttributes& attr,
+                                     const BHWC& dst_shape,
+                                     const CreationContext& creation_context,
+                                     const OperationDef& op_def,
+                                     std::unique_ptr* ptr) {
   if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) {
     ConvBuffer1x1 conv;
     RETURN_IF_ERROR(CreateConvBuffer1x1Wino4x4To6x6(creation_context, op_def,
@@ -119,16 +123,17 @@ absl::Status SelectConvolutionWinogradMali(
                                                     attr, &conv, &dst_shape));
     *ptr = absl::make_unique(std::move(conv));
   }
-  return absl::OkStatus();
+
+  return OkStatus();
 }
 
 }  // namespace
 
-absl::Status SelectConvolution(const Convolution2DAttributes& attr,
-                               const BHWC& dst_shape,
-                               const CreationContext& creation_context,
-                               const OperationDef& op_def, ModelHints hints,
-                               std::unique_ptr* ptr) {
+Status SelectConvolution(const Convolution2DAttributes& attr,
+                         const BHWC& dst_shape,
+                         const CreationContext& creation_context,
+                         const OperationDef& op_def, ModelHints hints,
+                         std::unique_ptr* ptr) {
   switch (creation_context.device->vendor()) {
     case Vendor::QUALCOMM:
       return SelectConvolutionAdreno(attr, dst_shape, creation_context, op_def,
@@ -147,10 +152,12 @@ absl::Status SelectConvolution(const Convolution2DAttributes& attr,
   }
 }
 
-absl::Status SelectConvolutionForWinograd(
-    const Convolution2DAttributes& attr, const BHWC& dst_shape,
-    const CreationContext& creation_context, const OperationDef& op_def,
-    ModelHints hints, std::unique_ptr* ptr) {
+Status SelectConvolutionForWinograd(const Convolution2DAttributes& attr,
+                                    const BHWC& dst_shape,
+                                    const CreationContext& creation_context,
+                                    const OperationDef& op_def,
+                                    ModelHints hints,
+                                    std::unique_ptr* ptr) {
   switch (creation_context.device->vendor()) {
     case Vendor::QUALCOMM:
       return SelectConvolutionWinogradAdreno(attr, dst_shape, creation_context,
@@ -162,7 +169,7 @@ absl::Status SelectConvolutionForWinograd(
       RETURN_IF_ERROR(
           CreateConvPowerVRWino4x4To6x6(creation_context, op_def, attr, &conv));
       *ptr = absl::make_unique(std::move(conv));
-      return absl::OkStatus();
+      return OkStatus();
     }
     case Vendor::MALI:
       return SelectConvolutionWinogradMali(attr, dst_shape, creation_context,
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
index 94723527ad5..dc0657ec47c 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
@@ -28,16 +28,18 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 
-absl::Status SelectConvolution(const Convolution2DAttributes& attr,
-                               const BHWC& dst_shape,
-                               const CreationContext& creation_context,
-                               const OperationDef& op_def, ModelHints hints,
-                               std::unique_ptr* ptr);
+Status SelectConvolution(const Convolution2DAttributes& attr,
+                         const BHWC& dst_shape,
+                         const CreationContext& creation_context,
+                         const OperationDef& op_def, ModelHints hints,
+                         std::unique_ptr* ptr);
 
-absl::Status SelectConvolutionForWinograd(
-    const Convolution2DAttributes& attr, const BHWC& dst_shape,
-    const CreationContext& creation_context, const OperationDef& op_def,
-    ModelHints hints, std::unique_ptr* ptr);
+Status SelectConvolutionForWinograd(const Convolution2DAttributes& attr,
+                                    const BHWC& dst_shape,
+                                    const CreationContext&
creation_context, + const OperationDef& op_def, + ModelHints hints, + std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc index 12e99b57aa7..8dd0ef6b3cb 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc @@ -28,7 +28,7 @@ namespace gpu { namespace cl { namespace { -absl::Status SelectConvolutionTransposedAdreno( +Status SelectConvolutionTransposedAdreno( const ConvolutionTransposedAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, std::unique_ptr* ptr) { @@ -49,10 +49,10 @@ absl::Status SelectConvolutionTransposedAdreno( CreateConvolutionTransposed(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectConvolutionTransposedPowerVR( +Status SelectConvolutionTransposedPowerVR( const ConvolutionTransposedAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, std::unique_ptr* ptr) { @@ -85,10 +85,10 @@ absl::Status SelectConvolutionTransposedPowerVR( CreateConvolutionTransposed(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectConvolutionTransposedMali( +Status SelectConvolutionTransposedMali( const ConvolutionTransposedAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, std::unique_ptr* ptr) { @@ -96,15 +96,14 @@ absl::Status SelectConvolutionTransposedMali( RETURN_IF_ERROR( CreateConvolutionTransposed(creation_context, op_def, attr, &conv)); *ptr = absl::make_unique(std::move(conv)); - return absl::OkStatus(); + return OkStatus(); } - } // namespace -absl::Status SelectConvolutionTransposed( - const ConvolutionTransposedAttributes& attr, - const CreationContext& creation_context, const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectConvolutionTransposed(const ConvolutionTransposedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { switch (creation_context.device->vendor()) { case Vendor::QUALCOMM: return SelectConvolutionTransposedAdreno(attr, creation_context, op_def, diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h index ff37c1024ad..50f5e5baad5 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h @@ -26,10 +26,10 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status SelectConvolutionTransposed( - const ConvolutionTransposedAttributes& attr, - const CreationContext& creation_context, const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectConvolutionTransposed(const ConvolutionTransposedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc index e2a941870db..9fe7aa9732e 100644 --- 
a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc @@ -28,13 +28,12 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status SelectDefault(const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - const std::vector>*>& inputs, - const std::vector>*>& outputs, - const Node& node, - GPUOperationsSubgraph* gpu_subgraph) { - return absl::UnimplementedError( +Status SelectDefault(const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + const std::vector>*>& inputs, + const std::vector>*>& outputs, + const Node& node, GPUOperationsSubgraph* gpu_subgraph) { + return UnimplementedError( absl::StrCat("No selector for ", node.operation.type)); } diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h index 05e33501cd4..b4b996cc4fb 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h @@ -29,12 +29,11 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status SelectDefault(const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - const std::vector>*>& inputs, - const std::vector>*>& outputs, - const Node& node, - GPUOperationsSubgraph* gpu_subgraph); +Status SelectDefault(const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + const std::vector>*>& inputs, + const std::vector>*>& outputs, + const Node& node, GPUOperationsSubgraph* gpu_subgraph); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc index 0098117dea1..85afa3fff43 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc @@ -26,10 +26,10 @@ namespace gpu { namespace cl { namespace { -absl::Status SelectDWConvolutionAdreno( - const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectDWConvolutionAdreno(const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { if (!op_def.IsBatchSupported() && IsDepthWiseConv3x3Supported(attr)) { DepthWiseConv3x3 dw_conv; RETURN_IF_ERROR( @@ -41,13 +41,13 @@ absl::Status SelectDWConvolutionAdreno( CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); *ptr = absl::make_unique(std::move(dw_conv)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectDWConvolutionPowerVR( - const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectDWConvolutionPowerVR(const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { if (!op_def.IsBatchSupported() && IsDepthWiseConv3x3Supported(attr)) { DepthWiseConv3x3 dw_conv; RETURN_IF_ERROR( @@ -59,13 +59,13 @@ absl::Status SelectDWConvolutionPowerVR( CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); *ptr = absl::make_unique(std::move(dw_conv)); } - return absl::OkStatus(); + return OkStatus(); } 
-absl::Status SelectDWConvolutionMali( - const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectDWConvolutionMali(const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { const auto storage_type = op_def.src_tensors[0].storage_type; bool buffer_type = storage_type == TensorStorageType::BUFFER || storage_type == TensorStorageType::IMAGE_BUFFER; @@ -83,14 +83,14 @@ absl::Status SelectDWConvolutionMali( CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); *ptr = absl::make_unique(std::move(dw_conv)); } - return absl::OkStatus(); + return OkStatus(); } } // namespace -absl::Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { switch (creation_context.device->vendor()) { case Vendor::QUALCOMM: return SelectDWConvolutionAdreno(attr, creation_context, op_def, ptr); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h index 7f7cc6da604..c15f2946495 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h @@ -26,10 +26,10 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc index 2a04a04460d..05d28b412ad 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc @@ -27,11 +27,10 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - int batch_size, - std::unique_ptr* ptr) { +Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, int batch_size, + std::unique_ptr* ptr) { if (op_def.IsBatchSupported()) { ConvTexture conv; RETURN_IF_ERROR(CreateConvTexture(creation_context, op_def, attr, &conv)); @@ -42,13 +41,13 @@ absl::Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr, CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique(std::move(fc)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectFullyConnectedPowerVR( - const FullyConnectedAttributes& attr, - const CreationContext& creation_context, const OperationDef& op_def, - int batch_size, std::unique_ptr* ptr) { +Status SelectFullyConnectedPowerVR(const FullyConnectedAttributes& attr, + const CreationContext& 
creation_context, + const OperationDef& op_def, int batch_size, + std::unique_ptr* ptr) { if (op_def.IsBatchSupported()) { ConvPowerVR conv; RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv)); @@ -59,14 +58,13 @@ absl::Status SelectFullyConnectedPowerVR( CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique(std::move(fc)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - int batch_size, - std::unique_ptr* ptr) { +Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, int batch_size, + std::unique_ptr* ptr) { if (op_def.IsBatchSupported()) { if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) { ConvBuffer1x1 conv; @@ -84,13 +82,13 @@ absl::Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr, CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique(std::move(fc)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectFullyConnected(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, int batch_size, - std::unique_ptr* ptr) { +Status SelectFullyConnected(const FullyConnectedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, int batch_size, + std::unique_ptr* ptr) { switch (creation_context.device->vendor()) { case Vendor::QUALCOMM: return SelectFullyConnectedAdreno(attr, creation_context, op_def, diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h index 4ae44490996..023020b6041 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h @@ -26,10 +26,10 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status SelectFullyConnected(const FullyConnectedAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, int batch_size, - std::unique_ptr* ptr); +Status SelectFullyConnected(const FullyConnectedAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, int batch_size, + std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc index b0996aa53ea..2fcb90fc8d1 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc @@ -36,7 +36,6 @@ namespace tflite { namespace gpu { namespace cl { namespace { - bool IsWidthBroadcastedForSecondInput( const std::vector>*>& inputs) { return inputs.size() == 2 && @@ -75,14 +74,14 @@ bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr, return suitable_attributes && recommended_channels && recommended_hw; } -absl::Status WinogradFromNode(const CreationContext& creation_context, - const OperationDef& op_def, ModelHints hints, - const BHWC& input_shape, const BHWC& output_shape, - const Convolution2DAttributes& attr, - GPUOperationsSubgraph* gpu_subgraph) { +Status WinogradFromNode(const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + const 
BHWC& input_shape, const BHWC& output_shape, + const Convolution2DAttributes& attr, + GPUOperationsSubgraph* gpu_subgraph) { if (!IsSuitableForWinograd4x4To6x6(attr, *creation_context.device, output_shape)) { - return absl::UnimplementedError("No implementation for this case."); + return UnimplementedError("No implementation for this case."); } const int tiles_x = IntegralDivideRoundUp(output_shape.w, 4); @@ -141,16 +140,18 @@ absl::Status WinogradFromNode(const CreationContext& creation_context, } RETURN_IF_ERROR(SelectWinograd36To4x4(creation_context, winograd_down_def, bias_copy, &winograd_down.operation)); - return absl::OkStatus(); + + return OkStatus(); } } // namespace -absl::Status GPUOperationFromNode( - const CreationContext& creation_context, const OperationDef& op_def, - ModelHints hints, const std::vector>*>& inputs, - const std::vector>*>& outputs, const Node& node, - GPUOperationsSubgraph* gpu_subgraph) { +Status GPUOperationFromNode(const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + const std::vector>*>& inputs, + const std::vector>*>& outputs, + const Node& node, + GPUOperationsSubgraph* gpu_subgraph) { std::unique_ptr* gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph); auto op_type = OperationTypeFromString(node.operation.type); @@ -182,7 +183,7 @@ absl::Status GPUOperationFromNode( } SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op); } - return absl::OkStatus(); + return OkStatus(); } } case OperationType::CONCAT: { @@ -201,7 +202,7 @@ absl::Status GPUOperationFromNode( if (WinogradFromNode(creation_context, op_def, hints, input_shape, output_shape, attr, gpu_subgraph) .ok()) { - return absl::OkStatus(); + return OkStatus(); } else { gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph); return SelectConvolution(attr, output_shape, creation_context, op_def, @@ -227,13 +228,13 @@ absl::Status GPUOperationFromNode( } case OperationType::LSTM: { SelectLSTM(op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::MAX_UNPOOLING_2D: { auto attr = absl::any_cast(node.operation.attributes); SelectMaxUnpooling(attr, op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::MEAN: { auto attr = absl::any_cast(node.operation.attributes); @@ -255,24 +256,24 @@ absl::Status GPUOperationFromNode( CreateElementwiseTwoInput(op_def, op_type, broadcast); *gpu_op = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } else { - return absl::UnimplementedError( + return UnimplementedError( "No support of multiply with more than 2 inputs"); } - return absl::OkStatus(); + return OkStatus(); } } case OperationType::PAD: { auto attr = absl::any_cast(node.operation.attributes); SelectPadding(attr, op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::POOLING_2D: { auto attr = absl::any_cast(node.operation.attributes); SelectPooling(attr, op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::PRELU: { auto attr = absl::any_cast(node.operation.attributes); @@ -287,13 +288,13 @@ absl::Status GPUOperationFromNode( case OperationType::RELU: { auto attr = absl::any_cast(node.operation.attributes); SelectReLU(creation_context, attr, op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::RESHAPE: { const int src_channels = inputs[0]->tensor.shape.c; auto attr = absl::any_cast(node.operation.attributes); SelectReshape(src_channels, attr.new_shape.c, 
op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::RESIZE: { auto attr = absl::any_cast(node.operation.attributes); @@ -302,23 +303,23 @@ absl::Status GPUOperationFromNode( case OperationType::SLICE: { auto attr = absl::any_cast(node.operation.attributes); SelectStridedSlice(attr, op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::SOFTMAX: { SelectSoftmax(inputs[0]->tensor.shape, op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::SPACE_TO_DEPTH: { auto attr = absl::any_cast(node.operation.attributes); SelectSpaceToDepth(attr, op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::TRANSPOSE: { auto attr = absl::any_cast(node.operation.attributes); SelectTranspose(attr, op_def, gpu_op); - return absl::OkStatus(); + return OkStatus(); } case OperationType::ABS: case OperationType::COS: @@ -334,7 +335,7 @@ absl::Status GPUOperationFromNode( ElementwiseOneInput operation = CreateElementwiseOneInput(op_def, op_type); *gpu_op = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } case OperationType::DIV: case OperationType::MAXIMUM: @@ -351,7 +352,7 @@ absl::Status GPUOperationFromNode( ElementwiseTwoInput operation = CreateElementwiseTwoInput( creation_context, op_def, op_type, broadcast, attr); *gpu_op = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } default: return SelectDefault(creation_context, op_def, hints, inputs, outputs, diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h index dd09c16dad0..bcb46c1e0c4 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h @@ -29,11 +29,12 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status GPUOperationFromNode( - const CreationContext& creation_context, const OperationDef& op_def, - ModelHints hints, const std::vector>*>& inputs, - const std::vector>*>& outputs, const Node& node, - GPUOperationsSubgraph* gpu_subgraph); +Status GPUOperationFromNode(const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + const std::vector>*>& inputs, + const std::vector>*>& outputs, + const Node& node, + GPUOperationsSubgraph* gpu_subgraph); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc index 44a88165e4c..ff26a3be601 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc @@ -59,14 +59,14 @@ void SelectReLU(const CreationContext& creation_context, *ptr = absl::make_unique(std::move(relu)); } -absl::Status SelectPReLU(const PReLUAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectPReLU(const PReLUAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { PReLU operation; RETURN_IF_ERROR(CreatePReLU(creation_context, op_def, attr, &operation)); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def, @@ -88,32 +88,31 @@ void SelectAdd(const OperationDef& 
op_def, const std::vector& channels, *ptr = absl::make_unique(std::move(operation)); } -absl::Status SelectResize(const Resize2DAttributes& attr, - const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectResize(const Resize2DAttributes& attr, const OperationDef& op_def, + std::unique_ptr* ptr) { Resize operation = CreateResize(op_def, attr); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectConcat(const ConcatAttributes& attr, - const std::vector& channels, - const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectConcat(const ConcatAttributes& attr, + const std::vector& channels, + const OperationDef& op_def, + std::unique_ptr* ptr) { switch (attr.axis) { case Axis::CHANNELS: { ConcatZ operation = CreateConcatZ(op_def, channels); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } case Axis::WIDTH: case Axis::HEIGHT: { ConcatXY operation = CreateConcatXY(op_def, attr, channels.size()); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } default: - return absl::UnimplementedError("No concat for this axis."); + return UnimplementedError("No concat for this axis."); } } @@ -148,36 +147,36 @@ void SelectStridedSlice(const SliceAttributes& attr, const OperationDef& op_def, *ptr = absl::make_unique(std::move(operation)); } -absl::Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def, + std::unique_ptr* ptr) { if (attr.dims != std::set({Axis::HEIGHT, Axis::WIDTH})) { - return absl::UnimplementedError("Mean operation supports only HW plane"); + return UnimplementedError("Mean operation supports only HW plane"); } Mean operation = CreateMean(op_def); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectMultiplyScalar(const MultiplyAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectMultiplyScalar(const MultiplyAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { MultiplyAdd operation; RETURN_IF_ERROR( CreateMultiplyAdd(creation_context, op_def, attr, &operation)); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectBroadcastAdd(const AddAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectBroadcastAdd(const AddAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { MultiplyAdd operation; RETURN_IF_ERROR( CreateMultiplyAdd(creation_context, op_def, attr, &operation)); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } void SelectSoftmax(const BHWC& shape, const OperationDef& op_def, @@ -198,18 +197,18 @@ void SelectTranspose(const TransposeAttributes& attr, *ptr = absl::make_unique(std::move(operation)); } -absl::Status SelectWinograd4x4To36(const CreationContext& creation_context, - const Padding2D& padding, - const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectWinograd4x4To36(const CreationContext& creation_context, + const Padding2D& padding, + const OperationDef& op_def, + std::unique_ptr* ptr) { Winograd4x4To36 operation; 
RETURN_IF_ERROR( CreateWinograd4x4To36(creation_context, op_def, padding, &operation)); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectWinograd36To4x4( +Status SelectWinograd36To4x4( const CreationContext& creation_context, const OperationDef& op_def, const ::tflite::gpu::Tensor& biases, std::unique_ptr* ptr) { @@ -217,18 +216,18 @@ absl::Status SelectWinograd36To4x4( RETURN_IF_ERROR( CreateWinograd36To4x4(creation_context, op_def, biases, &operation)); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } -absl::Status SelectQuantizeAndDequantize( - const QuantizeAndDequantizeAttributes& attr, - const CreationContext& creation_context, const OperationDef& op_def, - std::unique_ptr* ptr) { +Status SelectQuantizeAndDequantize(const QuantizeAndDequantizeAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr) { QuantizeAndDequantize operation; RETURN_IF_ERROR( CreateQuantizeAndDequantize(creation_context, op_def, attr, &operation)); *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h index 118701fe9b0..d9a5365fc9e 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h @@ -33,10 +33,10 @@ void SelectReLU(const CreationContext& creation_context, const ReLUAttributes& attr, const OperationDef& op_def, std::unique_ptr* ptr); -absl::Status SelectPReLU(const PReLUAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectPReLU(const PReLUAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr); void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def, std::unique_ptr* ptr); @@ -48,14 +48,13 @@ void SelectMaxUnpooling(const MaxUnpooling2DAttributes& attr, void SelectAdd(const OperationDef& op_def, const std::vector& channels, int dst_channels, std::unique_ptr* ptr); -absl::Status SelectResize(const Resize2DAttributes& attr, - const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectResize(const Resize2DAttributes& attr, const OperationDef& op_def, + std::unique_ptr* ptr); -absl::Status SelectConcat(const ConcatAttributes& attr, - const std::vector& channels, - const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectConcat(const ConcatAttributes& attr, + const std::vector& channels, + const OperationDef& op_def, + std::unique_ptr* ptr); void SelectReshape(int src_channels, int dst_channels, const OperationDef& op_def, @@ -67,18 +66,18 @@ void SelectPadding(const PadAttributes& attr, const OperationDef& op_def, void SelectStridedSlice(const SliceAttributes& attr, const OperationDef& op_def, std::unique_ptr* ptr); -absl::Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def, + std::unique_ptr* ptr); -absl::Status SelectMultiplyScalar(const MultiplyAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectMultiplyScalar(const MultiplyAttributes& attr, + const CreationContext& 
creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr); -absl::Status SelectBroadcastAdd(const AddAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectBroadcastAdd(const AddAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr); void SelectSoftmax(const BHWC& shape, const OperationDef& op_def, std::unique_ptr* ptr); @@ -91,20 +90,20 @@ void SelectTranspose(const TransposeAttributes& attr, const OperationDef& op_def, std::unique_ptr* ptr); -absl::Status SelectWinograd4x4To36(const CreationContext& creation_context, - const Padding2D& padding, - const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectWinograd4x4To36(const CreationContext& creation_context, + const Padding2D& padding, + const OperationDef& op_def, + std::unique_ptr* ptr); -absl::Status SelectWinograd36To4x4( +Status SelectWinograd36To4x4( const CreationContext& creation_context, const OperationDef& op_def, const ::tflite::gpu::Tensor& biases, std::unique_ptr* ptr); -absl::Status SelectQuantizeAndDequantize( - const QuantizeAndDequantizeAttributes& attr, - const CreationContext& creation_context, const OperationDef& op_def, - std::unique_ptr* ptr); +Status SelectQuantizeAndDequantize(const QuantizeAndDequantizeAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr* ptr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc index f6201fa92ca..26eb3ad3538 100644 --- a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc +++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc @@ -24,7 +24,6 @@ limitations under the License. 
namespace tflite { namespace gpu { namespace cl { - bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, const BHWDC& shape, const TensorDescriptor& descriptor) { diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc index 308e1b69205..e9de22c6dc0 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc @@ -27,10 +27,9 @@ namespace tflite { namespace gpu { namespace cl { namespace { - -absl::Status CreateImageBufferFromBuffer(const CLContext& context, - cl_mem memory, enum DataType data_type, - int width, cl_mem* result) { +Status CreateImageBufferFromBuffer(const CLContext& context, cl_mem memory, + enum DataType data_type, int width, + cl_mem* result) { cl_image_format format; cl_image_desc desc; std::memset(&desc, 0, sizeof(desc)); @@ -45,17 +44,16 @@ absl::Status CreateImageBufferFromBuffer(const CLContext& context, *result = clCreateImage(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error); if (error != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to create Texture2D (clCreateImage)", CLErrorCodeToString(error))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWDC& shape, - const TensorDescriptor& descriptor, cl_mem memory, - Tensor* result) { +Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWDC& shape, const TensorDescriptor& descriptor, + cl_mem memory, Tensor* result) { const bool memory_owner = memory == nullptr; if (memory_owner) { CLMemory mem; @@ -74,9 +72,8 @@ absl::Status CreateTensor(const CLContext& context, const CLDevice& device, } else { *result = Tensor(memory, memory_owner, shape, descriptor); } - return absl::OkStatus(); + return OkStatus(); } - } // namespace Tensor::Tensor(cl_mem memory, bool memory_owner, const BHWC& shape, @@ -159,48 +156,41 @@ int3 Tensor::GetFullTensorRegion() const { } } -absl::Status Tensor::IsValid(const BHWC& shape) const { +Status Tensor::IsValid(const BHWC& shape) const { if (shape.b != shape_.b) { - return absl::InvalidArgumentError( - "Shape batch does not match tensor batch"); + return InvalidArgumentError("Shape batch does not match tensor batch"); } if (shape.w != shape_.w) { - return absl::InvalidArgumentError( - "Shape width does not match tensor width"); + return InvalidArgumentError("Shape width does not match tensor width"); } if (shape.h != shape_.h) { - return absl::InvalidArgumentError( - "Shape height does not match tensor height"); + return InvalidArgumentError("Shape height does not match tensor height"); } if (shape.c != shape_.c) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Shape channels does not match tensor channels"); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Tensor::IsValid(const BHWDC& shape) const { +Status Tensor::IsValid(const BHWDC& shape) const { if (shape.b != shape_.b) { - return absl::InvalidArgumentError( - "Shape batch does not match tensor batch"); + return InvalidArgumentError("Shape batch does not match tensor batch"); } if (shape.w != shape_.w) { - return absl::InvalidArgumentError( - "Shape width does not match tensor width"); + return InvalidArgumentError("Shape width does not match tensor width"); } if (shape.h != shape_.h) { - return absl::InvalidArgumentError( - "Shape height does not match tensor height"); + return InvalidArgumentError("Shape 
height does not match tensor height"); } if (shape.d != shape_.d) { - return absl::InvalidArgumentError( - "Shape depth does not match tensor depth"); + return InvalidArgumentError("Shape depth does not match tensor depth"); } if (shape.c != shape_.c) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Shape channels does not match tensor channels"); } - return absl::OkStatus(); + return OkStatus(); } int Tensor::GetChannelsAlignment() const { @@ -240,8 +230,8 @@ cl_mem Tensor::GetMemoryPtr() const { cl_mem Tensor::GetMemoryPtrForWriting() const { return memory_; } -absl::Status Tensor::WriteDataBHWDC(absl::Span in, - CLCommandQueue* queue) { +Status Tensor::WriteDataBHWDC(absl::Span in, + CLCommandQueue* queue) { void* data_ptr = nullptr; const int aligned_channels = GetAlignedChannels(); const int elements_count = @@ -273,26 +263,24 @@ absl::Status Tensor::WriteDataBHWDC(absl::Span in, queue->EnqueueWriteImage(memory_, GetFullTensorRegion(), data_ptr)); break; default: - return absl::InternalError("Unsupported tensor storage type"); + return InternalError("Unsupported tensor storage type"); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Tensor::WriteData(CLCommandQueue* queue, - const TensorFloat32& src) { +Status Tensor::WriteData(CLCommandQueue* queue, const TensorFloat32& src) { RETURN_IF_ERROR(IsValid(src.shape)); return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue); } -absl::Status Tensor::WriteData(CLCommandQueue* queue, - const Tensor5DFloat32& src) { +Status Tensor::WriteData(CLCommandQueue* queue, const Tensor5DFloat32& src) { RETURN_IF_ERROR(IsValid(src.shape)); return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue); } -absl::Status Tensor::ReadDataBHWDC(absl::Span out, - CLCommandQueue* queue) const { +Status Tensor::ReadDataBHWDC(absl::Span out, + CLCommandQueue* queue) const { void* data_ptr = nullptr; const int aligned_channels = GetAlignedChannels(); const int elements_count = @@ -321,7 +309,7 @@ absl::Status Tensor::ReadDataBHWDC(absl::Span out, queue->EnqueueReadImage(memory_, GetFullTensorRegion(), data_ptr)); break; default: - return absl::InternalError("Unsupported tensor storage type"); + return InternalError("Unsupported tensor storage type"); } if (descriptor_.data_type == DataType::FLOAT32) { @@ -330,62 +318,57 @@ absl::Status Tensor::ReadDataBHWDC(absl::Span out, DataToBHWDC(absl::MakeConstSpan(data_h.data(), data_h.size()), out); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Tensor::ReadData(CLCommandQueue* queue, TensorFloat32* dst) const { +Status Tensor::ReadData(CLCommandQueue* queue, TensorFloat32* dst) const { RETURN_IF_ERROR(IsValid(dst->shape)); return ReadDataBHWDC(absl::MakeSpan(dst->data), queue); } -absl::Status Tensor::ReadData(CLCommandQueue* queue, - Tensor5DFloat32* dst) const { +Status Tensor::ReadData(CLCommandQueue* queue, Tensor5DFloat32* dst) const { RETURN_IF_ERROR(IsValid(dst->shape)); return ReadDataBHWDC(absl::MakeSpan(dst->data), queue); } -absl::Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWC& shape, const TensorDescriptor& descriptor, - Tensor* result) { +Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWC& shape, const TensorDescriptor& descriptor, + Tensor* result) { const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c); return CreateTensor(context, device, shape5D, descriptor, nullptr, result); } -absl::Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWDC& shape, 
- const TensorDescriptor& descriptor, Tensor* result) { +Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWDC& shape, const TensorDescriptor& descriptor, + Tensor* result) { return CreateTensor(context, device, shape, descriptor, nullptr, result); } -absl::Status CreateSharedTensor(const CLContext& context, - const CLDevice& device, cl_mem memory, - const BHWC& shape, - const TensorDescriptor& descriptor, - Tensor* result) { +Status CreateSharedTensor(const CLContext& context, const CLDevice& device, + cl_mem memory, const BHWC& shape, + const TensorDescriptor& descriptor, Tensor* result) { const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c); return CreateTensor(context, device, shape5D, descriptor, memory, result); } -absl::Status CreateSharedTensor(const CLContext& context, - const CLDevice& device, cl_mem memory, - const BHWDC& shape, - const TensorDescriptor& descriptor, - Tensor* result) { +Status CreateSharedTensor(const CLContext& context, const CLDevice& device, + cl_mem memory, const BHWDC& shape, + const TensorDescriptor& descriptor, Tensor* result) { return CreateTensor(context, device, shape, descriptor, memory, result); } -absl::Status AllocateTensorMemory(const CLContext& context, - const CLDevice& device, const BHWC& shape, - const TensorDescriptor& descriptor, - CLMemory* result) { +Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, + const BHWC& shape, + const TensorDescriptor& descriptor, + CLMemory* result) { const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c); return AllocateTensorMemory(context, device, shape5D, descriptor, result); } -absl::Status AllocateTensorMemory(const CLContext& context, - const CLDevice& device, const BHWDC& shape, - const TensorDescriptor& descriptor, - CLMemory* result) { +Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, + const BHWDC& shape, + const TensorDescriptor& descriptor, + CLMemory* result) { const int slices = IntegralDivideRoundUp(shape.c, 4); switch (descriptor.storage_type) { case TensorStorageType::BUFFER: @@ -396,12 +379,12 @@ absl::Status AllocateTensorMemory(const CLContext& context, cl_mem memory = clCreateBuffer(context.context(), CL_MEM_READ_WRITE, data_size, nullptr, &error_code); if (!memory) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to allocate device memory with clCreateBuffer", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return absl::OkStatus(); + return OkStatus(); } case TensorStorageType::TEXTURE_2D: { cl_image_desc desc; @@ -423,13 +406,13 @@ absl::Status AllocateTensorMemory(const CLContext& context, cl_mem memory = CreateImage2DLegacy(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to create Texture2D (clCreateImage)", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return absl::OkStatus(); + return OkStatus(); } case TensorStorageType::TEXTURE_3D: { cl_image_desc desc; @@ -451,13 +434,13 @@ absl::Status AllocateTensorMemory(const CLContext& context, cl_mem memory = CreateImage3DLegacy(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to create Texture3D (clCreateImage)", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return 
absl::OkStatus(); + return OkStatus(); } case TensorStorageType::TEXTURE_ARRAY: { cl_image_desc desc; @@ -480,18 +463,18 @@ absl::Status AllocateTensorMemory(const CLContext& context, cl_mem memory = clCreateImage(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to create TextureArray (clCreateImage)", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return absl::OkStatus(); + return OkStatus(); } case TensorStorageType::SINGLE_TEXTURE_2D: { if (slices != 1) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "SINGLE_TEXTURE_2D support only channels in range [1-4], but ", shape.c, "was provided")); } @@ -512,7 +495,7 @@ absl::Status AllocateTensorMemory(const CLContext& context, format.image_channel_data_type = ToImageChannelType(descriptor.data_type); } else { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "This device doesn't support ", shape.c, "-channel textures.")); } @@ -520,17 +503,17 @@ absl::Status AllocateTensorMemory(const CLContext& context, cl_mem memory = CreateImage2DLegacy(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to create Texture2D (clCreateImage)", CLErrorCodeToString(error_code))); } *result = CLMemory(memory, true); - return absl::OkStatus(); + return OkStatus(); } default: - return absl::InternalError("Unsupported tensor storage type"); + return InternalError("Unsupported tensor storage type"); } } diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h index a27c54a74e5..34a45436386 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.h +++ b/tensorflow/lite/delegates/gpu/cl/tensor.h @@ -87,22 +87,20 @@ class Tensor { // memory ptr. 
cl_mem GetMemoryPtrForWriting() const; - absl::Status WriteData(CLCommandQueue* queue, const TensorFloat32& src); - absl::Status WriteData(CLCommandQueue* queue, const Tensor5DFloat32& src); - absl::Status ReadData(CLCommandQueue* queue, TensorFloat32* dst) const; - absl::Status ReadData(CLCommandQueue* queue, Tensor5DFloat32* dst) const; + Status WriteData(CLCommandQueue* queue, const TensorFloat32& src); + Status WriteData(CLCommandQueue* queue, const Tensor5DFloat32& src); + Status ReadData(CLCommandQueue* queue, TensorFloat32* dst) const; + Status ReadData(CLCommandQueue* queue, Tensor5DFloat32* dst) const; private: - absl::Status IsValid(const BHWC& shape) const; - absl::Status IsValid(const BHWDC& shape) const; + Status IsValid(const BHWC& shape) const; + Status IsValid(const BHWDC& shape) const; int GetChannelsAlignment() const; int GetAlignedChannels() const; - absl::Status WriteDataBHWDC(absl::Span in, - CLCommandQueue* queue); - absl::Status ReadDataBHWDC(absl::Span out, - CLCommandQueue* queue) const; + Status WriteDataBHWDC(absl::Span in, CLCommandQueue* queue); + Status ReadDataBHWDC(absl::Span out, CLCommandQueue* queue) const; template void DataFromBHWDC(absl::Span src, absl::Span dst) const; @@ -147,35 +145,31 @@ class Tensor { using TensorPtr = std::shared_ptr; -absl::Status AllocateTensorMemory(const CLContext& context, - const CLDevice& device, const BHWC& shape, - const TensorDescriptor& descriptor, - CLMemory* result); +Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, + const BHWC& shape, + const TensorDescriptor& descriptor, + CLMemory* result); -absl::Status AllocateTensorMemory(const CLContext& context, - const CLDevice& device, const BHWDC& shape, - const TensorDescriptor& descriptor, - CLMemory* result); +Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, + const BHWDC& shape, + const TensorDescriptor& descriptor, + CLMemory* result); -absl::Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWC& shape, const TensorDescriptor& descriptor, - Tensor* result); +Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWC& shape, const TensorDescriptor& descriptor, + Tensor* result); -absl::Status CreateTensor(const CLContext& context, const CLDevice& device, - const BHWDC& shape, +Status CreateTensor(const CLContext& context, const CLDevice& device, + const BHWDC& shape, const TensorDescriptor& descriptor, + Tensor* result); + +Status CreateSharedTensor(const CLContext& context, const CLDevice& device, + cl_mem memory, const BHWC& shape, const TensorDescriptor& descriptor, Tensor* result); -absl::Status CreateSharedTensor(const CLContext& context, - const CLDevice& device, cl_mem memory, - const BHWC& shape, - const TensorDescriptor& descriptor, - Tensor* result); - -absl::Status CreateSharedTensor(const CLContext& context, - const CLDevice& device, cl_mem memory, - const BHWDC& shape, - const TensorDescriptor& descriptor, - Tensor* result); +Status CreateSharedTensor(const CLContext& context, const CLDevice& device, + cl_mem memory, const BHWDC& shape, + const TensorDescriptor& descriptor, Tensor* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_test.cc b/tensorflow/lite/delegates/gpu/cl/tensor_test.cc index 99ba269cf60..7c859c43e6e 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor_test.cc @@ -30,9 +30,8 @@ namespace gpu { namespace cl { namespace { 
-absl::Status TensorGenericTest(const BHWC& shape, - const TensorDescriptor& descriptor, - Environment* env) { +Status TensorGenericTest(const BHWC& shape, const TensorDescriptor& descriptor, + Environment* env) { TensorFloat32 tensor_cpu; tensor_cpu.shape = shape; tensor_cpu.data.resize(shape.DimensionsProduct()); @@ -54,15 +53,15 @@ absl::Status TensorGenericTest(const BHWC& shape, for (int i = 0; i < tensor_gpu.data.size(); ++i) { if (tensor_gpu.data[i] != tensor_cpu.data[i]) { - return absl::InternalError("Wrong value."); + return InternalError("Wrong value."); } } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Tensor5DGenericTest(const BHWDC& shape, - const TensorDescriptor& descriptor, - Environment* env) { +Status Tensor5DGenericTest(const BHWDC& shape, + const TensorDescriptor& descriptor, + Environment* env) { Tensor5DFloat32 tensor_cpu; tensor_cpu.shape = shape; tensor_cpu.data.resize(shape.DimensionsProduct()); @@ -84,14 +83,14 @@ absl::Status Tensor5DGenericTest(const BHWDC& shape, for (int i = 0; i < tensor_gpu.data.size(); ++i) { if (tensor_gpu.data[i] != tensor_cpu.data[i]) { - return absl::InternalError("Wrong value."); + return InternalError("Wrong value."); } } - return absl::OkStatus(); + return OkStatus(); } -absl::Status TensorTests(DataType data_type, TensorStorageType storage_type, - Environment* env) { +Status TensorTests(DataType data_type, TensorStorageType storage_type, + Environment* env) { RETURN_IF_ERROR(TensorGenericTest( BHWC(1, 6, 7, 3), {data_type, storage_type, Layout::HWC}, env)); RETURN_IF_ERROR(TensorGenericTest( @@ -126,7 +125,7 @@ absl::Status TensorTests(DataType data_type, TensorStorageType storage_type, BHWDC(7, 6, 1, 3, 7), {data_type, storage_type, Layout::BHWDC}, env)); RETURN_IF_ERROR(Tensor5DGenericTest( BHWDC(13, 7, 3, 4, 3), {data_type, storage_type, Layout::BHWDC}, env)); - return absl::OkStatus(); + return OkStatus(); } TEST_F(OpenCLTest, BufferF32) { diff --git a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc index 151924197c2..f231cf3143a 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc +++ b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc @@ -45,11 +45,10 @@ class DelegateContext { const TfLiteDelegateParams* delegate_params) { auto denormalized_graph = reinterpret_cast(delegate_params->delegate->data_); - absl::Status status = - BuildModel(context, delegate_params, denormalized_graph); + Status status = BuildModel(context, delegate_params, denormalized_graph); if (!status.ok()) { context->ReportError(context, "Failed to convert a model: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); } return status.ok(); } @@ -83,14 +82,14 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { return status; } -absl::Status FlatBufferToGPUGraph( +Status FlatBufferToGPUGraph( const std::unique_ptr& flatbuffer, GraphFloat32* graph) { tflite::ops::builtin::BuiltinOpResolver op_resolver; std::unique_ptr interpreter; tflite::InterpreterBuilder interpreter_builder(*flatbuffer, op_resolver); if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) { - return absl::InternalError("Unable to prepare TfLite interpreter."); + return InternalError("Unable to prepare TfLite interpreter."); } interpreter->UseNNAPI(false); TfLiteDelegate delegate; @@ -102,20 +101,20 @@ absl::Status FlatBufferToGPUGraph( delegate.FreeBufferHandle = nullptr; if 
(interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) { - return absl::InternalError("Conversion from TfLite model failed."); + return InternalError("Conversion from TfLite model failed."); } NullTransformationReporter reporter; ModelTransformer transformer(graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + return InternalError("Graph general transformations failed"); } - return absl::OkStatus(); + return OkStatus(); } } // namespace -absl::Status RunModelSample(const std::string& model_name) { +Status RunModelSample(const std::string& model_name) { auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str()); GraphFloat32 graph_cl; RETURN_IF_ERROR(FlatBufferToGPUGraph(flatbuffer, &graph_cl)); @@ -161,7 +160,7 @@ absl::Status RunModelSample(const std::string& model_name) { std::cout << "Total time - " << average_inference_time << "ms" << std::endl; } - return absl::OkStatus(); + return OkStatus(); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/texture2d.cc b/tensorflow/lite/delegates/gpu/cl/texture2d.cc index 022c15660ce..907721dad8c 100644 --- a/tensorflow/lite/delegates/gpu/cl/texture2d.cc +++ b/tensorflow/lite/delegates/gpu/cl/texture2d.cc @@ -21,9 +21,8 @@ namespace cl { namespace { // Creates new 4-channel 2D texture with cl_channel_type elements -absl::Status CreateTexture2D(int width, int height, cl_channel_type type, - void* data, CLContext* context, - Texture2D* result) { +Status CreateTexture2D(int width, int height, cl_channel_type type, void* data, + CLContext* context, Texture2D* result) { cl_image_desc desc; desc.image_type = CL_MEM_OBJECT_IMAGE2D; desc.image_width = width; @@ -48,14 +47,14 @@ absl::Status CreateTexture2D(int width, int height, cl_channel_type type, cl_mem texture = CreateImage2DLegacy(context->context(), flags, &format, &desc, data, &error_code); if (error_code != CL_SUCCESS) { - return absl::UnknownError( + return UnknownError( absl::StrCat("Failed to create Texture2D (clCreateImage)", CLErrorCodeToString(error_code))); } *result = Texture2D(texture, width, height, type); - return absl::OkStatus(); + return OkStatus(); } } // namespace @@ -96,20 +95,20 @@ void Texture2D::Release() { } // Creates new 4-channel 2D texture with f32 elements -absl::Status CreateTexture2DRGBA32F(int width, int height, CLContext* context, - Texture2D* result) { +Status CreateTexture2DRGBA32F(int width, int height, CLContext* context, + Texture2D* result) { return CreateTexture2D(width, height, CL_FLOAT, nullptr, context, result); } // Creates new 4-channel 2D texture with f16 elements -absl::Status CreateTexture2DRGBA16F(int width, int height, CLContext* context, - Texture2D* result) { +Status CreateTexture2DRGBA16F(int width, int height, CLContext* context, + Texture2D* result) { return CreateTexture2D(width, height, CL_HALF_FLOAT, nullptr, context, result); } -absl::Status CreateTexture2DRGBA(DataType type, int width, int height, - CLContext* context, Texture2D* result) { +Status CreateTexture2DRGBA(DataType type, int width, int height, + CLContext* context, Texture2D* result) { if (type == DataType::FLOAT32) { return CreateTexture2D(width, height, CL_FLOAT, nullptr, context, result); } else { @@ -118,9 +117,8 @@ absl::Status CreateTexture2DRGBA(DataType type, int width, int height, } } -absl::Status CreateTexture2DRGBA(DataType type, int width, int height, - void* data, CLContext* context, - Texture2D* result) { +Status 
CreateTexture2DRGBA(DataType type, int width, int height, void* data, + CLContext* context, Texture2D* result) { if (type == DataType::FLOAT32) { return CreateTexture2D(width, height, CL_FLOAT, data, context, result); } else { diff --git a/tensorflow/lite/delegates/gpu/cl/texture2d.h b/tensorflow/lite/delegates/gpu/cl/texture2d.h index c12d8a2836c..bdac984a2db 100644 --- a/tensorflow/lite/delegates/gpu/cl/texture2d.h +++ b/tensorflow/lite/delegates/gpu/cl/texture2d.h @@ -50,11 +50,11 @@ class Texture2D { // Writes data to a texture. Data should point to a region that // has exact width * height * sizeof(pixel) bytes. template - absl::Status WriteData(CLCommandQueue* queue, const absl::Span data); + Status WriteData(CLCommandQueue* queue, const absl::Span data); // Reads data from Texture2D into CPU memory. template - absl::Status ReadData(CLCommandQueue* queue, std::vector* result) const; + Status ReadData(CLCommandQueue* queue, std::vector* result) const; private: void Release(); @@ -68,45 +68,43 @@ class Texture2D { using Texture2DPtr = std::shared_ptr; // Creates new 4-channel 2D texture with f32 elements -absl::Status CreateTexture2DRGBA32F(int width, int height, CLContext* context, - Texture2D* result); +Status CreateTexture2DRGBA32F(int width, int height, CLContext* context, + Texture2D* result); // Creates new 4-channel 2D texture with f16 elements -absl::Status CreateTexture2DRGBA16F(int width, int height, CLContext* context, - Texture2D* result); +Status CreateTexture2DRGBA16F(int width, int height, CLContext* context, + Texture2D* result); -absl::Status CreateTexture2DRGBA(DataType type, int width, int height, - CLContext* context, Texture2D* result); +Status CreateTexture2DRGBA(DataType type, int width, int height, + CLContext* context, Texture2D* result); -absl::Status CreateTexture2DRGBA(DataType type, int width, int height, - void* data, CLContext* context, - Texture2D* result); +Status CreateTexture2DRGBA(DataType type, int width, int height, void* data, + CLContext* context, Texture2D* result); template -absl::Status Texture2D::WriteData(CLCommandQueue* queue, - const absl::Span data) { +Status Texture2D::WriteData(CLCommandQueue* queue, const absl::Span data) { const int element_size = ChannelTypeToSizeInBytes(channel_type_); if (sizeof(T) % element_size != 0) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Template type T has not suitable element type for created texture."); } if (4 * width_ * height_ * element_size != data.size() * sizeof(T)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "absl::Span data size is different from texture allocated size."); } RETURN_IF_ERROR(queue->EnqueueWriteImage(texture_, int3(width_, height_, 1), data.data())); - return absl::OkStatus(); + return OkStatus(); } template -absl::Status Texture2D::ReadData(CLCommandQueue* queue, - std::vector* result) const { +Status Texture2D::ReadData(CLCommandQueue* queue, + std::vector* result) const { const int element_size = ChannelTypeToSizeInBytes(channel_type_); if (sizeof(T) != element_size) { - return absl::InvalidArgumentError("Pixel format is different."); + return InvalidArgumentError("Pixel format is different."); } const int elements_count = width_ * height_ * 4; diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 30ac016ff83..08612e37b3e 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -24,8 +24,8 @@ cc_library( srcs = 
["custom_parsers.cc"], hdrs = ["custom_parsers.h"], deps = [ - ":shape", - ":status", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", "@flatbuffers", @@ -193,7 +193,6 @@ cc_test( cc_library( name = "status", hdrs = ["status.h"], - deps = ["@com_google_absl//absl/status"], ) cc_library( diff --git a/tensorflow/lite/delegates/gpu/common/convert.cc b/tensorflow/lite/delegates/gpu/common/convert.cc index cee2e8f0e60..81d09b2797e 100644 --- a/tensorflow/lite/delegates/gpu/common/convert.cc +++ b/tensorflow/lite/delegates/gpu/common/convert.cc @@ -30,15 +30,15 @@ constexpr int kPhwo4i4ChannelsInPlane = 4; constexpr int kPiohw4ChannelsInPlane = 4; // Layout is Po,H,W,OI4x4. -absl::Status ConvertToPHWO4I4(absl::Span in, const OHWI& shape, - absl::Span out, bool reverse_space) { +Status ConvertToPHWO4I4(absl::Span in, const OHWI& shape, + absl::Span out, bool reverse_space) { if (in.size() != shape.DimensionsProduct()) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertToPHWO4I4: Input data size does not match expected size: ", in.size(), " != ", shape.DimensionsProduct())); } if (out.size() != GetElementsSizeForPHWO4I4(shape)) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertToPHWO4I4: Output data size does not match expected size: ", out.size(), " != ", GetElementsSizeForPHWO4I4(shape))); } @@ -69,7 +69,7 @@ absl::Status ConvertToPHWO4I4(absl::Span in, const OHWI& shape, } } } - return absl::OkStatus(); + return OkStatus(); } } // namespace @@ -110,15 +110,15 @@ uint3 Get3DSizeForPHWO4I4(const OHWI& shape) { } // Layout is Po,H,W,OI4x4. 
-absl::Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, - absl::Span out) { +Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, + absl::Span out) { if (in.size() != shape.DimensionsProduct()) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertToPHWO4I4: Input data size does not match expected size: ", in.size(), " != ", shape.DimensionsProduct())); } if (out.size() != GetElementsSizeForPHWO4I4(shape)) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertToPHWO4I4: Output data size does not match expected size: ", out.size(), " != ", GetElementsSizeForPHWO4I4(shape))); } @@ -147,7 +147,7 @@ absl::Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, } } } - return absl::OkStatus(); + return OkStatus(); } std::vector ConvertToPHWO4I4( @@ -164,15 +164,15 @@ uint32_t GetElementsSizeForPIOHW4(const OHWI& shape) { shape.w; } -absl::Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, - absl::Span out) { +Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, + absl::Span out) { if (in.size() != shape.DimensionsProduct()) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertToPIOHW4: Input data size does not match expected size: ", in.size(), " != ", shape.DimensionsProduct())); } if (out.size() != GetElementsSizeForPIOHW4(shape)) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertToPIOHW4: Output data size does not match expected size: ", out.size(), " != ", GetElementsSizeForPIOHW4(shape))); } @@ -194,7 +194,7 @@ absl::Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, } } } - return absl::OkStatus(); + return OkStatus(); } std::vector ConvertToPIOHW4( @@ -207,29 +207,29 @@ std::vector ConvertToPIOHW4( } template -absl::Status ValidateConvertToPHWC4(absl::Span in, - const BHWC& shape, absl::Span out) { +Status ValidateConvertToPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { if (in.size() != shape.DimensionsProduct()) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertToPHWC4: Input data size does not match expected size: ", in.size(), " != ", shape.DimensionsProduct())); } if (out.size() != GetElementsSizeForPHWC4(shape)) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertToPHWC4: Output data size does not match expected size: ", out.size(), " != ", GetElementsSizeForPHWC4(shape))); } - return absl::OkStatus(); + return OkStatus(); } // Layout is Pc,H,W,C4 where P - is a plane based on channels. -absl::Status ConvertToPHWC4(absl::Span in, const BHWC& shape, - absl::Span out) { +Status ConvertToPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { RETURN_IF_ERROR(ValidateConvertToPHWC4(in, shape, out)); if (shape.c == 4) { std::memcpy(out.data(), in.data(), shape.DimensionsProduct() * sizeof(float)); - return absl::OkStatus(); + return OkStatus(); } // Layout is Pc,H,W,C4 where P - is a plane based on channels. 
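  // Editorial aside (illustration, not part of this patch): PHWC4 pads the
  // channel dimension up to a multiple of four and stores each group of four
  // channels as one H*W*4 plane per batch element. For example, a BHWC
  // tensor of shape 1x6x7x3 holds 1*6*7*3 = 126 elements, while its PHWC4
  // buffer holds ceil(3/4) = 1 plane of 6*7*4 = 168 elements, with the
  // unused fourth channel slot acting as padding.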
int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); @@ -256,7 +256,7 @@ absl::Status ConvertToPHWC4(absl::Span in, const BHWC& shape, const int remaining_channels = shape.c - num_full_planes * kPhwc4ChannelsInPlane; if (remaining_channels == 0) { - return absl::OkStatus(); + return OkStatus(); } for (int b = 0; b < shape.b; b++) { const float* src = @@ -272,12 +272,12 @@ absl::Status ConvertToPHWC4(absl::Span in, const BHWC& shape, dest += kPhwc4ChannelsInPlane; } } - return absl::OkStatus(); + return OkStatus(); } // Layout is Pc,H,W,C4 where P - is a plane based on channels. -absl::Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, - absl::Span out) { +Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, + absl::Span out) { RETURN_IF_ERROR(ValidateConvertToPHWC4(in, shape, out)); // Layout is Pc,H,W,C4 where P - is a plane based on channels. @@ -308,7 +308,7 @@ absl::Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, const int remaining_channels = shape.c - num_full_planes * kPhwc4ChannelsInPlane; if (remaining_channels == 0) { - return absl::OkStatus(); + return OkStatus(); } for (int b = 0; b < shape.b; b++) { @@ -349,11 +349,11 @@ absl::Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, } break; default: - return absl::UnimplementedError( + return UnimplementedError( "ConvertToPHWC4Half: Unsupported channels per planes count."); } } - return absl::OkStatus(); + return OkStatus(); } std::vector ConvertToPHWC4( @@ -383,28 +383,28 @@ uint32_t GetElementsSizeForPHWC4(const BHWC& shape) { } template -absl::Status ValidateConvertFromPHWC4(absl::Span in, const BHWC& shape, - absl::Span out) { +Status ValidateConvertFromPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { if (in.size() != GetElementsSizeForPHWC4(shape)) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertFromPHWC4: Input data size does not match expected size: ", in.size(), " != ", GetElementsSizeForPHWC4(shape))); } if (out.size() != shape.DimensionsProduct()) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "ConvertFromPHWC4: Output data size does not match expected size: ", out.size(), " != ", shape.DimensionsProduct())); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, - absl::Span out) { +Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { RETURN_IF_ERROR(ValidateConvertFromPHWC4(in, shape, out)); if (shape.c == 4) { std::memcpy(out.data(), in.data(), shape.DimensionsProduct() * sizeof(float)); - return absl::OkStatus(); + return OkStatus(); } int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); @@ -429,7 +429,7 @@ absl::Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, const int remaining_channels = shape.c - num_full_planes * kPhwc4ChannelsInPlane; if (remaining_channels == 0) { - return absl::OkStatus(); + return OkStatus(); } for (int b = 0; b < shape.b; b++) { const float* src = in.data() + b * padded_size + @@ -443,11 +443,11 @@ absl::Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, dest += shape.c; } } - return absl::OkStatus(); + return OkStatus(); } -absl::Status ConvertFromPHWC4Half(absl::Span in, - const BHWC& shape, absl::Span out) { +Status ConvertFromPHWC4Half(absl::Span in, const BHWC& shape, + absl::Span out) { RETURN_IF_ERROR(ValidateConvertFromPHWC4(in, shape, out)); int num_planes = IntegralDivideRoundUp(shape.c, 
kPhwc4ChannelsInPlane); const int num_pixels = shape.h * shape.w; @@ -474,7 +474,7 @@ absl::Status ConvertFromPHWC4Half(absl::Span in, const int remaining_channels = shape.c - num_full_planes * kPhwc4ChannelsInPlane; if (remaining_channels == 0) { - return absl::OkStatus(); + return OkStatus(); } for (int b = 0; b < shape.b; b++) { const HalfBits* src = in.data() + b * padded_size + @@ -508,11 +508,11 @@ absl::Status ConvertFromPHWC4Half(absl::Span in, } break; default: - return absl::UnimplementedError( + return UnimplementedError( "ConvertToPHWC4Half: Unsupported channels per planes count."); } } - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/convert.h b/tensorflow/lite/delegates/gpu/common/convert.h index 3aba9c913c5..30a0a5f3183 100644 --- a/tensorflow/lite/delegates/gpu/common/convert.h +++ b/tensorflow/lite/delegates/gpu/common/convert.h @@ -29,19 +29,19 @@ namespace gpu { // PHWC4 layout is where channels are grouped by 4 in a row and P stands for // a plane that was derived by dividing channels by 4. -absl::Status ConvertToPHWC4(absl::Span in, const BHWC& shape, - absl::Span out); -absl::Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, - absl::Span out); +Status ConvertToPHWC4(absl::Span in, const BHWC& shape, + absl::Span out); +Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, + absl::Span out); // @return number of elements when shape is converted into PHWC4. uint32_t GetElementsSizeForPHWC4(const BHWC& shape); // Operation is opposite to ConvertToPHWC4. -absl::Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, - absl::Span out); -absl::Status ConvertFromPHWC4Half(absl::Span in, - const BHWC& shape, absl::Span out); +Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, + absl::Span out); +Status ConvertFromPHWC4Half(absl::Span in, const BHWC& shape, + absl::Span out); // Convenience wrapper around a method above. std::vector ConvertToPHWC4( @@ -53,8 +53,8 @@ uint32_t GetElementsSizeForPIOHW4(const OHWI& shape); // PIOHW4 layout re-arranges weights in groups by 4, where outer dimension is // P which is OxI/4. -absl::Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, - absl::Span out); +Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, + absl::Span out); // Convenience wrapper around a method above. std::vector ConvertToPIOHW4( @@ -79,8 +79,8 @@ uint3 Get3DSizeForPHWO4I4(const OHWI& shape); uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape); // Layout is Po,H,W,OI4x4. -absl::Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, - absl::Span out); +Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, + absl::Span out); // Convenience wrapper around a method above. std::vector ConvertToPHWO4I4( diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.cc b/tensorflow/lite/delegates/gpu/common/custom_parsers.cc index e43cba05525..d46a9247c81 100644 --- a/tensorflow/lite/delegates/gpu/common/custom_parsers.cc +++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.cc @@ -25,10 +25,10 @@ limitations under the License. 
namespace tflite { namespace gpu { -absl::Status ParseCustomAttributes(absl::string_view op_name, const void* data, - uint32_t data_size, absl::any* attr, - BHWC* output_shape) { - return absl::UnimplementedError(absl::StrCat( +Status ParseCustomAttributes(absl::string_view op_name, const void* data, + uint32_t data_size, absl::any* attr, + BHWC* output_shape) { + return UnimplementedError(absl::StrCat( "Attributes parsing is not enabled for ", op_name, " operation")); } diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.h b/tensorflow/lite/delegates/gpu/common/custom_parsers.h index 707087e6fdb..e9a191d46cb 100644 --- a/tensorflow/lite/delegates/gpu/common/custom_parsers.h +++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.h @@ -27,9 +27,9 @@ namespace gpu { // Matches the custom operation by the string name and parses attributes stored // as flexbuffers. -absl::Status ParseCustomAttributes(absl::string_view op_name, const void* data, - uint32_t data_size, absl::any* attr, - BHWC* output_shape); +Status ParseCustomAttributes(absl::string_view op_name, const void* data, + uint32_t data_size, absl::any* attr, + BHWC* output_shape); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.cc b/tensorflow/lite/delegates/gpu/common/memory_management.cc index d7e6a060eb2..5cfd26b1832 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management.cc @@ -55,9 +55,8 @@ OffsetsAssignment ObjectsToOffsets( return result; } -absl::Status BestGreedy( - const std::vector>& usage_records, - ObjectsAssignment* assignment) { +Status BestGreedy(const std::vector>& usage_records, + ObjectsAssignment* assignment) { RETURN_IF_ERROR( GreedyBySizeDistPriorityAssignment(usage_records, assignment)); ObjectsAssignment assignment_by_breadth; @@ -65,11 +64,11 @@ absl::Status BestGreedy( TotalSize(assignment_by_breadth) < TotalSize(*assignment)) { std::swap(*assignment, assignment_by_breadth); } - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -90,14 +89,14 @@ absl::Status AssignObjectsToTensors( case MemoryStrategy::MINCOSTFLOW: return MinCostFlowAssignment(usage_records, assignment); default: - return absl::InternalError( + return InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -107,14 +106,14 @@ absl::Status AssignObjectsToTensors( case MemoryStrategy::EQUALITY: return EqualityAssignmentWithHash(usage_records, assignment); default: - return absl::InternalError( + return InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -126,14 +125,14 @@ absl::Status AssignObjectsToTensors( case MemoryStrategy::GREEDY_IN_ORDER: return GreedyInOrderAssignmentMultidimensional(usage_records, 
assignment); default: - return absl::InternalError( + return InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -145,13 +144,13 @@ absl::Status AssignObjectsToTensors( case MemoryStrategy::GREEDY_IN_ORDER: return GreedyInOrderAssignmentMultidimensional(usage_records, assignment); default: - return absl::InternalError( + return InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status AssignOffsetsToTensors( +Status AssignOffsetsToTensors( const std::vector>& usage_records, const MemoryStrategy& strategy, OffsetsAssignment* assignment, const UsageGraph* reallocation_graph) { @@ -162,7 +161,7 @@ absl::Status AssignOffsetsToTensors( RETURN_IF_ERROR(AssignObjectsToTensors( usage_records, strategy, &objects_assignment, reallocation_graph)); *assignment = ObjectsToOffsets(objects_assignment); - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.h b/tensorflow/lite/delegates/gpu/common/memory_management.h index 7df4947ee3d..e45c361d955 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management.h @@ -79,9 +79,8 @@ enum class MemoryStrategy { // Chooses greedy algorithm with the lowest memory consumption for given usage // records and returns corresponding shared objects assignment. -absl::Status BestGreedy( - const std::vector>& usage_records, - ObjectsAssignment* assignment); +Status BestGreedy(const std::vector>& usage_records, + ObjectsAssignment* assignment); // Calculates the assignment of shared objects to given tensors, including // objects' sizes. Below there are specializations for different types, that @@ -91,7 +90,7 @@ absl::Status BestGreedy( // can be larger. Currently only GREEDY_IN_ORDER strategy can use this // reallocation_graph. 
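// Editorial sketch (illustration, not part of this patch) of how the
// function declared below is typically driven, assuming the
// TensorUsageRecord and ObjectsAssignment types from this header and the
// field order {tensor_size, first_task, last_task}:
//
//   std::vector<TensorUsageRecord<size_t>> records = {
//       {/*tensor_size=*/1024, /*first_task=*/0, /*last_task=*/2},
//       {/*tensor_size=*/512, /*first_task=*/1, /*last_task=*/3},
//       {/*tensor_size=*/2048, /*first_task=*/3, /*last_task=*/4}};
//   ObjectsAssignment<size_t> assignment;
//   RETURN_IF_ERROR(AssignObjectsToTensors(
//       records, MemoryStrategy::GREEDY_IN_ORDER, &assignment));
//   // assignment.object_ids[i] names the shared object backing tensor i;
//   // assignment.object_sizes gives the size of each shared object.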
template -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph = nullptr) { @@ -101,39 +100,39 @@ absl::Status AssignObjectsToTensors( case MemoryStrategy::EQUALITY: return EqualityAssignment(usage_records, assignment); default: - return absl::InternalError( + return InternalError( "MemoryStrategy is not supported with current tensor size type."); } - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); template <> -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); template <> -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); template <> -absl::Status AssignObjectsToTensors( +Status AssignObjectsToTensors( const std::vector>& usage_records, MemoryStrategy strategy, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph); // Calculates the assignment of tensors to offsets, considering those tensors // are going to be allocated in one continuous memory block. -absl::Status AssignOffsetsToTensors( +Status AssignOffsetsToTensors( const std::vector>& usage_records, const MemoryStrategy& strategy, OffsetsAssignment* assignment, const UsageGraph* reallocation_graph = nullptr); diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h index fdccce5159f..0955393e00c 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h @@ -29,7 +29,7 @@ namespace gpu { // Fast version of Equality Assignments for hashable types. template -absl::Status EqualityAssignmentWithHash( +Status EqualityAssignmentWithHash( const std::vector>& usage_records, ObjectsAssignment* assignment) { size_t num_records = usage_records.size(); @@ -69,12 +69,12 @@ absl::Status EqualityAssignmentWithHash( {usage_records[i].last_task, assignment->object_ids[i]}); } } - return absl::OkStatus(); + return OkStatus(); } // Slower version of Equality Assignments for unhashable types. 
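// Editorial aside (illustration, not part of this patch): as the strategy's
// name suggests, both equality variants reuse a shared object only for a
// tensor whose size matches exactly and whose lifetime starts after the
// previous user's last task. For example, tensors of sizes 64, 32 and 64
// with non-overlapping lifetimes end up in two objects: the first and third
// tensor share one 64-element object, while the second gets its own
// 32-element object.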
template -absl::Status EqualityAssignment( +Status EqualityAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { size_t num_records = usage_records.size(); @@ -109,7 +109,7 @@ absl::Status EqualityAssignment( dealloc_task[best_obj] = usage_records[i].last_task; } } - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc index 2c138b4c14c..5d0f6b620b0 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc @@ -46,7 +46,7 @@ struct TaskBreadthWithId { } // namespace -absl::Status GreedyByBreadthAssignment( +Status GreedyByBreadthAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { std::vector task_profiles = CalculateTaskProfiles(usage_records); @@ -133,10 +133,10 @@ absl::Status GreedyByBreadthAssignment( // In the end all tensors must be assigned to some objects. for (const auto& obj_id : assignment->object_ids) { if (obj_id == kNotAssigned) { - return absl::InternalError("Error while calculating the assignment."); + return InternalError("Error while calculating the assignment."); } } - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h index 47035229920..c139ba0fe0f 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h @@ -44,7 +44,7 @@ namespace gpu { // tensor’s size, assign current tensor to the smallest of them; // - If there are suitable objects only with size less than current tensor’s // size, assign current tensor to the largest of them and increase its size. 
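// Editorial aside (illustration, not part of this patch): the selection rule
// described above means that a 40-element tensor offered free objects of
// sizes 48 and 64 is placed into the 48-element one, while the same tensor
// offered only objects of sizes 24 and 32 is placed into the 32-element one,
// which is then grown to 40. The "breadth" that orders the tasks is, as the
// TaskProfile helpers above suggest, the total size of the tensors alive
// during that task, so the most memory-hungry tasks are resolved first.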
-absl::Status GreedyByBreadthAssignment( +Status GreedyByBreadthAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment); diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc index 76309ce8f1b..bf56c6d92dd 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc @@ -60,7 +60,7 @@ struct SizeDistPriorityInfo { } // namespace -absl::Status GreedyBySizeAssignment( +Status GreedyBySizeAssignment( const std::vector>& usage_records, OffsetsAssignment* assignment) { const size_t num_tensors = usage_records.size(); @@ -104,7 +104,7 @@ absl::Status GreedyBySizeAssignment( prev_offset, cur_offset + usage_records[allocated_id].tensor_size); } if (assignment->total_size < prev_offset) { - return absl::InternalError("Total size is wrong."); + return InternalError("Total size is wrong."); } // If no suitable gap found, we should allocate current tensor after the @@ -125,7 +125,7 @@ absl::Status GreedyBySizeAssignment( assignment->total_size = std::max(assignment->total_size, best_offset + rec->tensor_size); } - return absl::OkStatus(); + return OkStatus(); } // Assigns given tensors to shared objects, using the following greedy @@ -152,7 +152,7 @@ absl::Status GreedyBySizeAssignment( // object with size equal to current tensor's size; // - Modify SizeDistPriority records of tensors, that haven't been assigned yet, // to reflect distance changes after that assignment. -absl::Status GreedyBySizeDistPriorityAssignment( +Status GreedyBySizeDistPriorityAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { std::vector positional_max = @@ -175,7 +175,7 @@ absl::Status GreedyBySizeDistPriorityAssignment( ++pos; } if (pos == 0) { - return absl::InternalError("Variable pos must be positive."); + return InternalError("Variable pos must be positive."); } priority_info[rec_id].position = pos - 1; } @@ -198,7 +198,7 @@ absl::Status GreedyBySizeDistPriorityAssignment( if (best_info_id == kNotAssigned) { // During each iteration we assign exactly one of the tensors, so some not // yet assigned tensors must exist. - return absl::InternalError("Invalid value for variable best_info_id."); + return InternalError("Invalid value for variable best_info_id."); } size_t best_rec_id = priority_info[best_info_id].tensor_usage_id; @@ -271,7 +271,7 @@ absl::Status GreedyBySizeDistPriorityAssignment( } } } - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h index b0ad9d18911..fb875fd0920 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h @@ -38,7 +38,7 @@ namespace gpu { // gap. Otherwise we can allocate it after the rightmost tensor, which usage // interval intersects with usage interval of current tensor. So we assign // corresponding offset to current tensor and the tensor becomes assigned. 
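// Editorial aside (illustration, not part of this patch), assuming records
// are visited in non-increasing size order as the function name suggests:
// if tensor A (size 100, tasks 0..1), tensor B (size 50, tasks 1..2) and
// tensor C (size 30, tasks 2..3) are processed in that order, A and B get
// disjoint offset ranges because both are alive during task 1, while C can
// be placed back into the range A occupied, since A is no longer alive
// during tasks 2..3. The total allocation stays at 150 elements instead of
// the 180 a per-tensor allocation would need.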
-absl::Status GreedyBySizeAssignment( +Status GreedyBySizeAssignment( const std::vector>& usage_records, OffsetsAssignment* assignment); @@ -66,7 +66,7 @@ absl::Status GreedyBySizeAssignment( // object with size equal to current tensor's size; // - Modify SizeDistPriority records of tensors, that haven't been assigned yet, // to reflect distance changes after that assignment. -absl::Status GreedyBySizeDistPriorityAssignment( +Status GreedyBySizeDistPriorityAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment); diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h index 8c3719e4a8b..b454920ffcb 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h @@ -46,7 +46,7 @@ namespace gpu { // // 3. Shared object size may increase when tensor requests larger size. template -absl::Status GreedyInOrderAssignment( +Status GreedyInOrderAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment, const UsageGraph* reallocation_graph = nullptr) { @@ -111,7 +111,7 @@ absl::Status GreedyInOrderAssignment( } // best_it can't be equal to pool.end(), because pool is not empty if (best_it == pool.end()) { - return absl::InternalError( + return InternalError( "No shared object is found in non-empty pool in " "GreedyInOrderAssignment."); } @@ -135,14 +135,14 @@ absl::Status GreedyInOrderAssignment( {usage_records[i].last_task, assignment->object_ids[i]}); } } - return absl::OkStatus(); + return OkStatus(); } // The same algorithm as above, but for multidimensional case. The only // difference is that shared object dimensions can't be increased to be reused // for tensor, that is larger (at least by one dimension). template -absl::Status GreedyInOrderAssignmentMultidimensional( +Status GreedyInOrderAssignmentMultidimensional( const std::vector>& usage_records, ObjectsAssignment* assignment) { size_t num_records = usage_records.size(); @@ -198,7 +198,7 @@ absl::Status GreedyInOrderAssignmentMultidimensional( {usage_records[i].last_task, assignment->object_ids[i]}); } } - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc index 059c23fab33..ab15af88429 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc @@ -211,14 +211,14 @@ class MinCostFlowSolver { // auxiliary flow graph, find minimum-cost flow in it and calculates the // assignment of shared objects to tensors, using the result of the flow // algorithm. 
-absl::Status MinCostFlowAssignment( +Status MinCostFlowAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { MinCostFlowSolver solver; solver.Build(usage_records); solver.Solve(); solver.CalculateAssignment(assignment); - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h index 1284c12c5c2..7e45f83c79e 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h @@ -30,7 +30,7 @@ namespace gpu { // auxiliary flow graph, find minimum-cost flow in it and calculates the // assignment of shared objects to tensors, using the result of the flow // algorithm. -absl::Status MinCostFlowAssignment( +Status MinCostFlowAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment); diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h index 8a00c67d853..94cd41ed9a5 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h @@ -30,7 +30,7 @@ namespace gpu { // The problem of memory management is NP-complete. This implements a // naive algorithm that assigns each tensor to a separate object in memory. template -absl::Status NaiveAssignment( +Status NaiveAssignment( const std::vector>& usage_records, ObjectsAssignment* assignment) { assignment->object_sizes.resize(usage_records.size()); @@ -40,7 +40,7 @@ absl::Status NaiveAssignment( assignment->object_ids[i] = i; assignment->object_sizes[i] = record.tensor_size; } - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/model.h b/tensorflow/lite/delegates/gpu/common/model.h index 2e38bcc5f3f..6989584a24c 100644 --- a/tensorflow/lite/delegates/gpu/common/model.h +++ b/tensorflow/lite/delegates/gpu/common/model.h @@ -136,33 +136,33 @@ class Graph { // for a value. If a value had another producer, it will reassign producer // appropriately. If a value didn't have a producer, it will be removed // from a graph's input. - virtual absl::Status SetProducer(NodeId producer, ValueId value) = 0; + virtual Status SetProducer(NodeId producer, ValueId value) = 0; // Removes a producer for the given value. Value becomes producer-less and // therefore becomes graph's input. - virtual absl::Status RemoveProducer(ValueId value) = 0; + virtual Status RemoveProducer(ValueId value) = 0; // Sets a consumer for the given value. There could be multiple consumers // for a value. - virtual absl::Status AddConsumer(NodeId consumer, ValueId value) = 0; + virtual Status AddConsumer(NodeId consumer, ValueId value) = 0; // Replace input value for given node. - virtual absl::Status ReplaceInput(NodeId node, ValueId old_value, - ValueId new_value) = 0; + virtual Status ReplaceInput(NodeId node, ValueId old_value, + ValueId new_value) = 0; // Removes a consumer for the given value. If value does not have any // consumers it becomes graph's output. - virtual absl::Status RemoveConsumer(NodeId consumer, ValueId value) = 0; + virtual Status RemoveConsumer(NodeId consumer, ValueId value) = 0; // Removes node from this graph. 
For all input values this node will be // removed from consumers and for all output values a producer will be // removed. - virtual absl::Status DeleteNode(NodeId id) = 0; + virtual Status DeleteNode(NodeId id) = 0; // Removes value from this graph. It will be removed from inputs for all // dependent nodes. A node that was a producer of this value will loose its // output. - virtual absl::Status DeleteValue(ValueId id) = 0; + virtual Status DeleteValue(ValueId id) = 0; }; // Implementation of a Graph interface. It keeps values and nodes referenced by @@ -268,7 +268,7 @@ class Model : public Graph { return values_[id].consumers; } - absl::Status SetProducer(NodeId producer, ValueId value) final { + Status SetProducer(NodeId producer, ValueId value) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(value, &v)); Value* value_ptr = v->value.get(); @@ -278,13 +278,12 @@ class Model : public Graph { // check if this value has the same producer already if (node_ptr == v->producer) { - return absl::InvalidArgumentError( - "Node is already a producer of the value"); + return InvalidArgumentError("Node is already a producer of the value"); } // Check if the node is a consumer of this value. if (IsInput(producer, value)) { - return absl::InvalidArgumentError("Node is a consumer of the value"); + return InvalidArgumentError("Node is a consumer of the value"); } // TODO(akulik): detect circular dependency? @@ -294,23 +293,22 @@ class Model : public Graph { } v->producer = node_ptr; n->outputs.push_back(value_ptr); - return absl::OkStatus(); + return OkStatus(); } - absl::Status RemoveProducer(ValueId value) final { + Status RemoveProducer(ValueId value) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(value, &v)); Value* value_ptr = v->value.get(); if (v->producer == nullptr) { - return absl::InvalidArgumentError("Value does not have a producer"); + return InvalidArgumentError("Value does not have a producer"); } Erase(&nodes_[v->producer->id].outputs, value_ptr); v->producer = nullptr; - return absl::OkStatus(); + return OkStatus(); } - absl::Status ReplaceInput(NodeId node, ValueId old_value, - ValueId new_value) final { + Status ReplaceInput(NodeId node, ValueId old_value, ValueId new_value) final { ValueDef* v_old; RETURN_IF_ERROR(LookupValue(old_value, &v_old)); Value* value_old_ptr = v_old->value.get(); @@ -323,17 +321,17 @@ class Model : public Graph { // Check if the node is a consumer of old_value. if (!IsInput(node, old_value)) { - return absl::InvalidArgumentError("old_value must be input of node."); + return InvalidArgumentError("old_value must be input of node."); } // Check if the node is not a consumer of new_value. 
if (IsInput(node, new_value)) { - return absl::InvalidArgumentError("new_value can not be input of node."); + return InvalidArgumentError("new_value can not be input of node."); } // Check if this value has the same producer already if (node_ptr == v_new->producer) { - return absl::InvalidArgumentError("new_value can not be output of node."); + return InvalidArgumentError("new_value can not be output of node."); } for (int i = 0; i < n->inputs.size(); ++i) { @@ -344,10 +342,10 @@ class Model : public Graph { } v_new->consumers.push_back(node_ptr); Erase(&v_old->consumers, node_ptr); - return absl::OkStatus(); + return OkStatus(); } - absl::Status AddConsumer(NodeId consumer, ValueId value) final { + Status AddConsumer(NodeId consumer, ValueId value) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(value, &v)); Value* value_ptr = v->value.get(); @@ -357,21 +355,20 @@ class Model : public Graph { // check if this value has the same producer already if (node_ptr == v->producer) { - return absl::InvalidArgumentError("Node is a producer of the value"); + return InvalidArgumentError("Node is a producer of the value"); } // check if this value has the same consumer already if (IsInput(consumer, value)) { - return absl::InvalidArgumentError( - "Node is already a consumer of the value"); + return InvalidArgumentError("Node is already a consumer of the value"); } n->inputs.push_back(value_ptr); v->consumers.push_back(node_ptr); - return absl::OkStatus(); + return OkStatus(); } - absl::Status RemoveConsumer(NodeId consumer, ValueId value) final { + Status RemoveConsumer(NodeId consumer, ValueId value) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(value, &v)); Value* value_ptr = v->value.get(); @@ -379,14 +376,14 @@ class Model : public Graph { RETURN_IF_ERROR(LookupNode(consumer, &n)); Node* node_ptr = n->node.get(); if (!IsInput(consumer, value)) { - return absl::InvalidArgumentError("Node is not a consumer of the value"); + return InvalidArgumentError("Node is not a consumer of the value"); } Erase(&n->inputs, value_ptr); Erase(&v->consumers, node_ptr); - return absl::OkStatus(); + return OkStatus(); } - absl::Status DeleteNode(NodeId id) final { + Status DeleteNode(NodeId id) final { NodeDef* n; RETURN_IF_ERROR(LookupNode(id, &n)); Node* node_ptr = n->node.get(); @@ -399,10 +396,10 @@ class Model : public Graph { n->inputs.clear(); n->outputs.clear(); n->node.reset(); - return absl::OkStatus(); + return OkStatus(); } - absl::Status DeleteValue(ValueId id) final { + Status DeleteValue(ValueId id) final { ValueDef* v; RETURN_IF_ERROR(LookupValue(id, &v)); Value* value_ptr = v->value.get(); @@ -417,10 +414,10 @@ class Model : public Graph { v->producer = nullptr; v->consumers.clear(); v->value.reset(); - return absl::OkStatus(); + return OkStatus(); } - absl::Status MakeExactCopy(Model* model) const { + Status MakeExactCopy(Model* model) const { model->nodes_.clear(); model->values_.clear(); model->name_ = name_; @@ -443,7 +440,7 @@ class Model : public Graph { } } } - return absl::OkStatus(); + return OkStatus(); } private: @@ -478,29 +475,29 @@ class Model : public Graph { } // @return non-nullptr NodeDef that has valid Node or an error - absl::Status LookupNode(NodeId id, NodeDef** node_def) { + Status LookupNode(NodeId id, NodeDef** node_def) { if (id >= nodes_.size()) { - return absl::OutOfRangeError("NodeId is out of range"); + return OutOfRangeError("NodeId is out of range"); } auto& n = nodes_[id]; if (!n.node) { - return absl::OutOfRangeError("Node is already deleted"); + return 
OutOfRangeError("Node is already deleted"); } *node_def = &n; - return absl::OkStatus(); + return OkStatus(); } // @return non-nullptr ValueDef that has valid Value or an error - absl::Status LookupValue(ValueId id, ValueDef** value_def) { + Status LookupValue(ValueId id, ValueDef** value_def) { if (id >= values_.size()) { - return absl::OutOfRangeError("ValueId is out of range"); + return OutOfRangeError("ValueId is out of range"); } auto& v = values_[id]; if (!v.value) { - return absl::OutOfRangeError("Value is already deleted"); + return OutOfRangeError("Value is already deleted"); } *value_def = &v; - return absl::OkStatus(); + return OkStatus(); } template @@ -540,14 +537,14 @@ class Model : public Graph { // outputs that are consumed only by to_keep. In such case to_keep inherits all // to_remove inputs. template -absl::Status RemovePrecedingNode(Graph* graph, const Node* to_remove, - const Node* to_keep) { +Status RemovePrecedingNode(Graph* graph, const Node* to_remove, + const Node* to_keep) { // Make sure all outputs from to_remove are consumed by to_keep. for (auto output : graph->FindOutputs(to_remove->id)) { auto consumers = graph->FindConsumers(output->id); if (consumers.size() > 1 || (consumers.size() == 1 && consumers[0] != to_keep)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Output from to_remove node has other consumers"); } } @@ -565,13 +562,13 @@ absl::Status RemovePrecedingNode(Graph* graph, const Node* to_remove, // Removes to_remove node that follows to_keep node only if to_remove has inputs // that are produced by to_keep. to_keep inherits all to_remove inputs. template -absl::Status RemoveFollowingNode(Graph* graph, const Node* to_remove, - const Node* to_keep) { +Status RemoveFollowingNode(Graph* graph, const Node* to_remove, + const Node* to_keep) { // Make sure all inputs to to_remove are produced by to_keep. for (auto input : graph->FindInputs(to_remove->id)) { Node* producer = graph->FindProducer(input->id); if (producer->id != to_keep->id) { - return absl::InvalidArgumentError("To_remove node has other inputs"); + return InvalidArgumentError("To_remove node has other inputs"); } } @@ -587,12 +584,12 @@ absl::Status RemoveFollowingNode(Graph* graph, const Node* to_remove, // Removes to_remove node. 
// Requires that node has one input and one output; template -absl::Status RemoveOneInputOneOutputNode(Graph* graph, - const Node* to_remove) { +Status RemoveOneInputOneOutputNode(Graph* graph, + const Node* to_remove) { auto inputs = graph->FindInputs(to_remove->id); auto outputs = graph->FindOutputs(to_remove->id); if (inputs.size() != 1 || outputs.size() != 1) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "To_remove node must have 1 input and 1 output"); } auto input_id = inputs[0]->id; @@ -607,26 +604,26 @@ absl::Status RemoveOneInputOneOutputNode(Graph* graph, if (!producer && consumers.empty()) { RETURN_IF_ERROR(graph->DeleteValue(input_id)); } - return absl::OkStatus(); + return OkStatus(); } template -absl::Status AddOutput(Graph* graph, const Node* from_node, - Value** output) { +Status AddOutput(Graph* graph, const Node* from_node, + Value** output) { auto link = graph->NewValue(); RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id)); *output = link; - return absl::OkStatus(); + return OkStatus(); } template -absl::Status ConnectTwoNodes(Graph* graph, const Node* from_node, - const Node* to_node, Value** output) { +Status ConnectTwoNodes(Graph* graph, const Node* from_node, + const Node* to_node, Value** output) { Value* link; RETURN_IF_ERROR(AddOutput(graph, from_node, &link)); RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id)); *output = link; - return absl::OkStatus(); + return OkStatus(); } using GraphFloat32 = Model>; diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 94899efe91e..b37c3542413 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -65,9 +65,9 @@ namespace { // node(output) // will turn into: // node(copy(output)) <- passthrough_node(output) -absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node, - const Value>* output, - Node** passthru_node) { +Status NewPassthroughNode(GraphFloat32* graph, Node* node, + const Value>* output, + Node** passthru_node) { *passthru_node = graph->NewNode(); // Make copies for every output in the original node. 
RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id)); @@ -76,18 +76,18 @@ absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node, RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id)); copy_output->tensor = output->tensor; copy_output->tensor.ref = -1; - return absl::OkStatus(); + return OkStatus(); } template -absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) { +Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) { if (tensor.bytes % sizeof(T) != 0) { - return absl::InvalidArgumentError( + return InvalidArgumentError( absl::StrCat("Input data size ", tensor.bytes, " is not aligned to expected type: ", sizeof(T))); } std::memcpy(tensor_data, tensor.data.uint8, tensor.bytes); - return absl::OkStatus(); + return OkStatus(); } void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src, @@ -98,8 +98,8 @@ void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src, } template <> -absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, - float* tensor_data) { +Status CreateVectorCopyData(const TfLiteTensor& tensor, + float* tensor_data) { switch (tensor.type) { case kTfLiteFloat32: std::memcpy(tensor_data, tensor.data.f, tensor.bytes); @@ -110,97 +110,104 @@ absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, reinterpret_cast(tensor.data.f16), tensor_data); break; default: - return absl::InvalidArgumentError( - "Unsupported data type for float32 tensor"); + return InvalidArgumentError("Unsupported data type for float32 tensor"); } - return absl::OkStatus(); + return OkStatus(); } template -absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, ShapeT* shape); +Status SetAllDimensions(const TfLiteIntArray* dimensions, ShapeT* shape); template <> -absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, - Scalar* shape) { +Status SetAllDimensions(const TfLiteIntArray* dimensions, + Scalar* shape) { if (dimensions->size < 0) { - return absl::InvalidArgumentError("Invalid Scalar dimensions"); + return InvalidArgumentError("Invalid Scalar dimensions"); } for (int i = 0; i < dimensions->size; ++i) { if (dimensions->data[i] != 1) { - return absl::InvalidArgumentError( - "Dimension can not be reduced to scalar."); + return InvalidArgumentError("Dimension can not be reduced to scalar."); } } shape->v = 1; - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, - Linear* shape) { +Status SetAllDimensions(const TfLiteIntArray* dimensions, + Linear* shape) { if (dimensions->size <= 0) { - return absl::InvalidArgumentError("Dimension is empty."); + return InvalidArgumentError("Dimension is empty."); } for (int i = 0; i < dimensions->size - 1; ++i) { if (dimensions->data[i] != 1) { - return absl::InvalidArgumentError( - "Dimension can not be reduced to linear."); + return InvalidArgumentError("Dimension can not be reduced to linear."); } } shape->v = dimensions->data[dimensions->size - 1]; - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, - HWC* shape) { +Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) { if (dimensions->size != 4) { - return absl::InvalidArgumentError("Dimensions are not HWC"); + return InvalidArgumentError("Dimensions are not HWC"); } if (dimensions->data[0] != 1) { - return absl::UnimplementedError("Batch size is not equal to 1."); + return UnimplementedError("Batch size is not 
equal to 1."); } shape->h = dimensions->data[1]; shape->w = dimensions->data[2]; shape->c = dimensions->data[3]; - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) { +Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) { if (dimensions->size != 2) { - return absl::InvalidArgumentError("Dimensions are not HW"); + return InvalidArgumentError("Dimensions are not HW"); } shape->h = dimensions->data[0]; shape->w = dimensions->data[1]; - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, - OHWI* shape) { +Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) { if (dimensions->size != 4) { - return absl::InvalidArgumentError( + return InvalidArgumentError( absl::StrCat("Dimensions are not OHWI: ", dimensions->size)); } shape->o = dimensions->data[0]; shape->h = dimensions->data[1]; shape->w = dimensions->data[2]; shape->i = dimensions->data[3]; - return absl::OkStatus(); + return OkStatus(); } template <> -absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, - BHWC* shape) { +Status SetAllDimensions(const TfLiteIntArray* dimensions, IHWO* shape) { if (dimensions->size != 4) { - return absl::InvalidArgumentError("Dimensions are not BHWC"); + return InvalidArgumentError( + absl::StrCat("Dimensions are not IHWO: ", dimensions->size)); + } + shape->i = dimensions->data[0]; + shape->h = dimensions->data[1]; + shape->w = dimensions->data[2]; + shape->o = dimensions->data[3]; + return OkStatus(); +} + +template <> +Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) { + if (dimensions->size != 4) { + return InvalidArgumentError("Dimensions are not BHWC"); } shape->b = dimensions->data[0]; shape->h = dimensions->data[1]; shape->w = dimensions->data[2]; shape->c = dimensions->data[3]; - return absl::OkStatus(); + return OkStatus(); } DataType ToDataType(TfLiteType type) { @@ -246,46 +253,46 @@ int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context, return number_of_runtime_outputs; } -absl::Status CheckTensorIsAvailable(const TfLiteContext* context, - const TfLiteNode* tflite_node, int idx) { +Status CheckTensorIsAvailable(const TfLiteContext* context, + const TfLiteNode* tflite_node, int idx) { // If tensor id is in range, it's guaranteed that it'll be available. 
if (idx >= tflite_node->inputs->size) { - return absl::OutOfRangeError( + return OutOfRangeError( absl::StrFormat("Requested index goes beyond array size (%d vs %d).", idx, tflite_node->inputs->data[idx])); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CheckInputsOutputs(const TfLiteContext* context, - const TfLiteNode* tflite_node, - int runtime_inputs, int outputs) { +Status CheckInputsOutputs(const TfLiteContext* context, + const TfLiteNode* tflite_node, int runtime_inputs, + int outputs) { int runtime_inputs_from_model = GetNumberOfRuntimeInputsForNode(context, tflite_node); if (runtime_inputs_from_model != runtime_inputs) { - return absl::InternalError(absl::StrFormat( + return InternalError(absl::StrFormat( "Expected %d runtime input tensor(s), but node has %d runtime " "input(s).", runtime_inputs, runtime_inputs_from_model)); } int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node); if (runtime_outputs != outputs) { - return absl::InternalError( + return InternalError( absl::StrFormat("Expected %d output tensor(s), but node has %d " "output(s).", outputs, runtime_outputs)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CheckInputsConstsOutputs(const TfLiteContext* context, - const TfLiteNode* tflite_node, - int runtime_inputs, int const_inputs, - int outputs) { +Status CheckInputsConstsOutputs(const TfLiteContext* context, + const TfLiteNode* tflite_node, + int runtime_inputs, int const_inputs, + int outputs) { int const_inputs_from_model = GetNumberOfConstInputsForNode(context, tflite_node); if (const_inputs_from_model != const_inputs) { - return absl::InternalError(absl::StrFormat( + return InternalError(absl::StrFormat( "Expected %d const input tensor(s), but node has %d const " "input(s).", const_inputs, const_inputs_from_model)); @@ -303,9 +310,9 @@ class ObjectReader { tflite_node_(tflite_node), tensor_to_value_(tensor_to_value) {} - absl::Status ReadValue(uint32_t idx, Value>** value) const { + Status ReadValue(uint32_t idx, Value>** value) const { if (idx >= tflite_node_->inputs->size) { - return absl::OutOfRangeError( + return OutOfRangeError( absl::StrCat("ReadValue: input tensor index: ", idx)); } return ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value); @@ -315,21 +322,21 @@ class ObjectReader { return GetNumberOfRuntimeInputsForNode(context_, tflite_node_); } - absl::Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const { + Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const { if (idx >= tflite_node_->inputs->size) { - return absl::OutOfRangeError(absl::StrCat("Input tensor index: ", idx)); + return OutOfRangeError(absl::StrCat("Input tensor index: ", idx)); } const int tensor_idx = tflite_node_->inputs->data[idx]; if (tensor_idx < 0 || tensor_idx > context_->tensors_size) { - return absl::OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx)); + return OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx)); } const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx]; *dimensions = *tflite_tensor.dims; - return absl::OkStatus(); + return OkStatus(); } template - absl::Status ReadTensor(uint32_t idx, TensorT* t) const { + Status ReadTensor(uint32_t idx, TensorT* t) const { RETURN_IF_ERROR(CheckTensorIsAvailable(context_, tflite_node_, idx)); const int32_t tensor_idx = tflite_node_->inputs->data[idx]; const TfLiteTensor* tflite_tensor = context_->tensors + tensor_idx; @@ -342,9 +349,9 @@ class ObjectReader { return 
  }
 
-  absl::Status AddOutput(const Node* node, int id) {
+  Status AddOutput(const Node* node, int id) {
     if (tflite_node_->outputs->size <= id) {
-      return absl::InvalidArgumentError(absl::StrCat(
+      return InvalidArgumentError(absl::StrCat(
           "Data id ", id, " must be less than tflite node outputs size ",
           tflite_node_->outputs->size));
     }
@@ -352,32 +359,32 @@ class ObjectReader {
     Value<TensorRef<BHWC>>* value;
     RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
     RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
-    return absl::OkStatus();
+    return OkStatus();
   }
 
-  absl::Status AddOutputs(const Node* node) {
+  Status AddOutputs(const Node* node) {
     for (int i = 0; i < tflite_node_->outputs->size; ++i) {
       RETURN_IF_ERROR(AddOutput(node, i));
     }
-    return absl::OkStatus();
+    return OkStatus();
   }
 
-  absl::Status AddInput(const Node* node, uint32_t idx) {
+  Status AddInput(const Node* node, uint32_t idx) {
     Value<TensorRef<BHWC>>* input;
     RETURN_IF_ERROR(ReadValue(idx, &input));
     return graph_->AddConsumer(node->id, input->id);
   }
 
-  absl::Status ReadValueByTensorIdx(uint32_t tensor_idx,
-                                    Value<TensorRef<BHWC>>** value) const {
+  Status ReadValueByTensorIdx(uint32_t tensor_idx,
+                              Value<TensorRef<BHWC>>** value) const {
     if (tensor_idx >= tensor_to_value_->size()) {
-      return absl::OutOfRangeError(
+      return OutOfRangeError(
           absl::StrCat("ReadValue: input tensor index: ", tensor_idx));
     }
     if ((*tensor_to_value_)[tensor_idx] == nullptr) {
       const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
       if (tflite::IsConstantTensor(&tflite_tensor)) {
-        return absl::NotFoundError(absl::StrCat(
+        return NotFoundError(absl::StrCat(
             "ReadValue: value is a constant tensor: ", tensor_idx));
       }
       Value<TensorRef<BHWC>>* value = graph_->NewValue();
@@ -387,7 +394,7 @@ class ObjectReader {
       (*tensor_to_value_)[tensor_idx] = value;
     }
     *value = (*tensor_to_value_)[tensor_idx];
-    return absl::OkStatus();
+    return OkStatus();
   }
 
   TfLiteTensor* GetInputTensor(int index) const {
@@ -402,9 +409,9 @@ class ObjectReader {
               : nullptr;
   }
 
-  absl::Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node,
-                                         int runtime_inputs, int const_inputs,
-                                         int outputs) {
+  Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node,
+                                   int runtime_inputs, int const_inputs,
+                                   int outputs) {
     return CheckInputsConstsOutputs(context_, tflite_node, runtime_inputs,
                                     const_inputs, outputs);
   }
@@ -423,30 +430,28 @@ class TFLiteOperationParser {
   // Parses TFLite operation. This method allows expanding fused operations
   // into more than one node.
-  virtual absl::Status Parse(const TfLiteNode* tflite_node,
-                             const TfLiteRegistration* registration,
-                             GraphFloat32* graph, ObjectReader* reader) = 0;
+  virtual Status Parse(const TfLiteNode* tflite_node,
+                       const TfLiteRegistration* registration,
+                       GraphFloat32* graph, ObjectReader* reader) = 0;
 
   // Verifies whether passed tflite node may be built by GPU delegate or not.
-  virtual absl::Status IsSupported(const TfLiteContext* context,
-                                   const TfLiteNode* tflite_node,
-                                   const TfLiteRegistration* registration) = 0;
+  virtual Status IsSupported(const TfLiteContext* context,
+                             const TfLiteNode* tflite_node,
+                             const TfLiteRegistration* registration) = 0;
 };
 
-absl::Status IsActivationSupported(TfLiteFusedActivation fused_activation) {
+Status IsActivationSupported(TfLiteFusedActivation fused_activation) {
   switch (fused_activation) {
     case kTfLiteActNone:
     case kTfLiteActRelu:
     case kTfLiteActRelu1:
    case kTfLiteActRelu6:
     case kTfLiteActTanh:
-      return absl::OkStatus();
+      return OkStatus();
     case kTfLiteActSignBit:
-      return absl::UnimplementedError(
-          "TfLiteFusedActivation.kTfLiteActSignBit");
+      return UnimplementedError("TfLiteFusedActivation.kTfLiteActSignBit");
     case kTfLiteActSigmoid:
-      return absl::UnimplementedError(
-          "TfLiteFusedActivation.kTfLiteActSigmoid");
+      return UnimplementedError("TfLiteFusedActivation.kTfLiteActSigmoid");
       // Do not add default; we want compilation error rather than run-time
       // error.
@@ -456,15 +461,15 @@ absl::Status IsActivationSupported(TfLiteFusedActivation fused_activation) {
 // If there is fused activation present, then there will be another node created
 // that will have identical output as the given node. New operation node will
 // depend on the given node output.
-absl::Status MaybeFuseActivation(TfLiteFusedActivation fused_activation,
-                                 const std::vector<uint32_t>& output_indices,
-                                 GraphFloat32* graph, Node* node) {
+Status MaybeFuseActivation(TfLiteFusedActivation fused_activation,
+                           const std::vector<uint32_t>& output_indices,
+                           GraphFloat32* graph, Node* node) {
   if (fused_activation == kTfLiteActNone) {
-    return absl::OkStatus();
+    return OkStatus();
   }
   const auto& outputs = graph->FindOutputs(node->id);
   if (outputs.empty()) {
-    return absl::InternalError("Empty outputs in fused node");
+    return InternalError("Empty outputs in fused node");
   }
   switch (fused_activation) {
     case kTfLiteActRelu:
@@ -492,16 +497,16 @@ absl::Status MaybeFuseActivation(TfLiteFusedActivation fused_activation,
       }
       break;
     default:
-      return absl::NotFoundError(
+      return NotFoundError(
           absl::StrCat("Unsupported fused activation: ", fused_activation));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status MaybeFuseActivationToTheSingleOutput(
+Status MaybeFuseActivationToTheSingleOutput(
     TfLiteFusedActivation fused_activation, GraphFloat32* graph, Node* node) {
   if (graph->FindOutputs(node->id).size() != 1) {
-    return absl::InternalError("Number of outputs exceeds 1");
+    return InternalError("Number of outputs exceeds 1");
   }
   return MaybeFuseActivation(fused_activation, {0}, graph, node);
 }
@@ -519,10 +524,9 @@ void UpdatePadding(const TfLitePadding& padding, const BHWC& input_shape,
   }
 }
 
-absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
-                                         int bias_tensor_id,
-                                         ObjectReader* reader,
-                                         FullyConnectedAttributes* attr) {
+Status GetFullyConnectedAttributes(int weights_tensor_id, int bias_tensor_id,
+                                   ObjectReader* reader,
+                                   FullyConnectedAttributes* attr) {
   Tensor<HW, DataType::FLOAT32> weights;
   RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
   attr->weights.data = std::move(weights.data);
@@ -533,100 +537,100 @@ absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
   attr->weights.shape.i = weights.shape.w;
   reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError();  // optional
-  return absl::OkStatus();
+  return OkStatus();
 }
 
 template <typename ParamsT>
-absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
-                                 ParamsT** tf_options) {
+Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
+                           ParamsT** tf_options) {
   const auto* params =
       reinterpret_cast<const ParamsT*>(tflite_node->builtin_data);
   if (!params) {
-    return absl::InternalError("Unable to retrieve builtin_data.");
+    return InternalError("Unable to retrieve builtin_data.");
   }
   *tf_options = const_cast<ParamsT*>(params);
-  return absl::OkStatus();
+  return OkStatus();
 }
 
 template <typename ParamsType>
-absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
-                                       ParamsType** tf_options) {
+Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
+                                 ParamsType** tf_options) {
   const auto* params =
       reinterpret_cast<const ParamsType*>(tflite_node->custom_initial_data);
   if (!params) {
-    return absl::InternalError("Unable to retrieve custom_initial_data.");
+    return InternalError("Unable to retrieve custom_initial_data.");
   }
   *tf_options = const_cast<ParamsType*>(params);
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration,
-                                        int max_version) {
+Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration,
+                                  int max_version) {
   const int op_version = registration->version;
   if (op_version > max_version) {
-    return absl::UnimplementedError(
+    return UnimplementedError(
         absl::StrFormat("Max version supported: %d. Requested version %d.",
                         max_version, op_version));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status CheckExactSupportedOpVersion(
-    const TfLiteRegistration* registration, int expected_version) {
+Status CheckExactSupportedOpVersion(const TfLiteRegistration* registration,
+                                    int expected_version) {
   int op_version = registration->version;
   if (op_version != expected_version) {
-    return absl::UnimplementedError(
+    return UnimplementedError(
         absl::StrFormat("Only version %d is supported. Requested version %d.",
                         expected_version, op_version));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status CheckKernels(int kernel_h, int kernel_w) {
+Status CheckKernels(int kernel_h, int kernel_w) {
   if (kernel_h <= 0 || kernel_w <= 0) {
-    return absl::InvalidArgumentError(absl::StrFormat(
+    return InvalidArgumentError(absl::StrFormat(
         "Incorrect kernel values: kernel_height = %d, kernel_width = %d.",
         kernel_h, kernel_w));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status CheckStrides(int strides_h, int strides_w) {
+Status CheckStrides(int strides_h, int strides_w) {
   if (strides_h <= 0 || strides_w <= 0) {
-    return absl::InvalidArgumentError(absl::StrFormat(
+    return InvalidArgumentError(absl::StrFormat(
         "Incorrect stride values: stride_height = %d, stride_width = %d.",
         strides_h, strides_w));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status CheckDilation(int dilation_h, int dilation_w) {
+Status CheckDilation(int dilation_h, int dilation_w) {
   if (dilation_h <= 0 || dilation_w <= 0) {
-    return absl::InvalidArgumentError(
+    return InvalidArgumentError(
         absl::StrFormat("Incorrect dilation values: dilation_factor = %d, "
                         "dilation_factor = %d.",
                         dilation_h, dilation_w));
   }
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status CheckStridesAndDilation(int strides_h, int strides_w,
-                                     int dilation_h, int dilation_w) {
+Status CheckStridesAndDilation(int strides_h, int strides_w, int dilation_h,
+                               int dilation_w) {
   RETURN_IF_ERROR(CheckStrides(strides_h, strides_w));
   RETURN_IF_ERROR(CheckDilation(dilation_h, dilation_w));
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h,
-                                    int strides_w) {
+Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h,
+                              int strides_w) {
   RETURN_IF_ERROR(CheckKernels(kernel_h, kernel_w));
   RETURN_IF_ERROR(CheckStrides(strides_h, strides_w));
-  return absl::OkStatus();
+  return OkStatus();
 }
 
 // Creates a simple node that holds tensor value.
-absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
-                          Value<TensorRef<BHWC>>** value) {
+Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
+                    Value<TensorRef<BHWC>>** value) {
   ConstTensorAttributes attr;
   attr.tensor = std::move(t);
   Node* node = graph->NewNode();
@@ -638,59 +642,59 @@ absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
   (*value)->tensor.ref = attr.tensor.id;
   (*value)->tensor.type = attr.tensor.kType;
   (*value)->tensor.shape = attr.tensor.shape;
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status ParsePoolingAttributes(const TfLitePoolParams* tf_options,
-                                    const BHWC& input_shape,
-                                    Pooling2DAttributes* attr) {
+Status ParsePoolingAttributes(const TfLitePoolParams* tf_options,
+                              const BHWC& input_shape,
+                              Pooling2DAttributes* attr) {
   attr->kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
   attr->strides = ToHW(tf_options->stride_height, tf_options->stride_width);
   UpdatePadding(tf_options->padding, input_shape, attr);
-  return absl::OkStatus();
+  return OkStatus();
 }
 
-absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
+Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
   const TfLiteIntArray* dims = tflite_tensor.dims;
   switch (dims->size) {
     case 1:
       *bhwc = BHWC(dims->data[0], 1, 1, 1);
-      return absl::OkStatus();
+      return OkStatus();
     case 2:
       *bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
-      return absl::OkStatus();
+      return OkStatus();
     case 3:
       *bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
-      return absl::OkStatus();
+      return OkStatus();
     case 4:
       *bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
-      return absl::OkStatus();
+      return OkStatus();
     default:
-      return absl::InvalidArgumentError(absl::StrCat(
+      return InvalidArgumentError(absl::StrCat(
           "Tensor \"", tflite_tensor.name ? tflite_tensor.name : "nullptr",
           "\" has bad input dims size: ", dims->size, "."));
   }
 }
 
-absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
-                                        TensorOrScalar* tensor_or_scalar) {
+Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
+                                  TensorOrScalar* tensor_or_scalar) {
   const std::string& opname = node->operation.type;
 
   // Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0); if (!input0) { - return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " + - opname); + return InvalidArgumentError("Couldn't get the 1st input tensor for " + + opname); } const TfLiteTensor* input1 = reader->GetInputTensor(1); if (!input1) { - return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " + - opname); + return InvalidArgumentError("Couldn't get the 2nd input tensor for " + + opname); } const bool constant_tensor0 = IsConstantTensor(input0); const bool constant_tensor1 = IsConstantTensor(input1); if (constant_tensor0 && constant_tensor1) { - return absl::InvalidArgumentError("No runtime input tensors for " + opname); + return InvalidArgumentError("No runtime input tensors for " + opname); } const bool runtime_tensor0 = !constant_tensor0; const bool runtime_tensor1 = !constant_tensor1; @@ -718,26 +722,26 @@ absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader, *tensor_or_scalar = std::move(tensor); } } - return absl::OkStatus(); + return OkStatus(); } class AddOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); if (tflite_node->inputs->size != 2) { - return absl::UnimplementedError("ADD requires two input tensors."); + return UnimplementedError("ADD requires two input tensors."); } // TODO(eignasheva): Add shapes check. TfLiteAddParams* tf_options = nullptr; return RetrieveBuiltinData(tflite_node, &tf_options); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { // TFLite currently only supports 2 input ADDs. Thus, the logic below only // considers 2 input cases. The underlying GPU shader programs can accept // more inputs, but the logic below would have to be expanded. @@ -751,7 +755,7 @@ class AddOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return absl::InternalError("Missing tflite params"); + return InternalError("Missing tflite params"); } return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node); @@ -760,9 +764,9 @@ class AddOperationParser : public TFLiteOperationParser { class ConcatenationOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); // TODO(eignasheva): add proper tensor availability checking @@ -772,12 +776,12 @@ class ConcatenationOperationParser : public TFLiteOperationParser { // TODO(eignasheva): add axis checking. 
TfLiteConcatenationParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { ConcatAttributes attr; // Read inputs first to make sure const node is added to a graph before // concat node to ensure topological order. @@ -828,16 +832,16 @@ class ConcatenationOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return absl::InternalError("Missing tflite params"); + return InternalError("Missing tflite params"); } RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node)); node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } private: - absl::Status SetAxis(const std::vector& input_shapes, Axis* axis) { + Status SetAxis(const std::vector& input_shapes, Axis* axis) { *axis = Axis::BATCH; for (int i = 1; i < input_shapes.size(); i++) { if (input_shapes[0].h != input_shapes[i].h && @@ -847,7 +851,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser { break; } } - if (*axis == Axis::BATCH) return absl::OkStatus(); + if (*axis == Axis::BATCH) return OkStatus(); for (int i = 1; i < input_shapes.size(); i++) { if (input_shapes[0].b != input_shapes[i].b && input_shapes[0].w != input_shapes[i].w && @@ -856,7 +860,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser { break; } } - if (*axis == Axis::HEIGHT) return absl::OkStatus(); + if (*axis == Axis::HEIGHT) return OkStatus(); for (int i = 1; i < input_shapes.size(); i++) { if (input_shapes[0].b != input_shapes[i].b && input_shapes[0].h != input_shapes[i].h && @@ -865,25 +869,25 @@ class ConcatenationOperationParser : public TFLiteOperationParser { break; } } - if (*axis == Axis::WIDTH) return absl::OkStatus(); + if (*axis == Axis::WIDTH) return OkStatus(); for (int i = 1; i < input_shapes.size(); i++) { if (input_shapes[0].b != input_shapes[i].b && input_shapes[0].w != input_shapes[i].w && input_shapes[0].h != input_shapes[i].h) { - return absl::UnimplementedError( + return UnimplementedError( "Can concatenate tensors only by batch, height, width, or " "channels."); } } - return absl::OkStatus(); + return OkStatus(); } }; class Conv2DOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -896,9 +900,9 @@ class Conv2DOperationParser : public TFLiteOperationParser { return IsActivationSupported(tf_options->activation); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = 
ToString(OperationType::CONVOLUTION_2D); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -911,7 +915,7 @@ class Conv2DOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return absl::InternalError("Missing tflite params"); + return InternalError("Missing tflite params"); } attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); attr.dilations = HW(tf_options->dilation_height_factor, @@ -921,26 +925,26 @@ class Conv2DOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node)); node->operation.attributes = std::move(attr); - return absl::OkStatus(); + return OkStatus(); } }; class Convolution2DTransposeBiasParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); TfLiteTransposeConvParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); RETURN_IF_ERROR( CheckStrides(tf_options->stride_height, tf_options->stride_width)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -959,15 +963,15 @@ class Convolution2DTransposeBiasParser : public TFLiteOperationParser { &attr); node->operation.attributes = std::move(attr); - return absl::OkStatus(); + return OkStatus(); } }; class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -987,38 +991,37 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { : nullptr; const auto* output = context->tensors + tflite_node->outputs->data[0]; if (!input->dims || input->dims->size != 4) { - return absl::InvalidArgumentError("input.dims.size != 4"); + return InvalidArgumentError("input.dims.size != 4"); } if (!filter->dims || filter->dims->size != 4) { - return absl::InvalidArgumentError("filter.dims.size != 4"); + return InvalidArgumentError("filter.dims.size != 4"); } if (!output->dims || output->dims->size != 4) { - return absl::InvalidArgumentError("output.dims.size != 4"); + return InvalidArgumentError("output.dims.size != 4"); } if (input->dims->data[0] != output->dims->data[0]) { - return absl::InvalidArgumentError("input.b != output.b"); + return InvalidArgumentError("input.b != output.b"); } const int input_depth = input->dims->data[3]; const int output_depth = output->dims->data[3]; if 
(filter->dims->data[3] != output_depth) { - return absl::InvalidArgumentError("filter.i != output.c"); + return InvalidArgumentError("filter.i != output.c"); } if (output_depth != input_depth * depth_multiplier) { - return absl::InvalidArgumentError( - "output.c != input.c * depth_multiplier"); + return InvalidArgumentError("output.c != input.c * depth_multiplier"); } if (bias && NumElements(bias) != output_depth) { - return absl::InvalidArgumentError("bias.size != output.c"); + return InvalidArgumentError("bias.size != output.c"); } if (depth_multiplier != 1 && input_depth != 1) { - return absl::UnimplementedError("depth_multiplier != 1 && input.c != 1"); + return UnimplementedError("depth_multiplier != 1 && input.c != 1"); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1044,7 +1047,7 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { TransposeWeights(input, filter, output, depth_multiplier, &attr); } node->operation.attributes = std::move(attr); - return absl::OkStatus(); + return OkStatus(); } private: @@ -1083,9 +1086,9 @@ class ElementwiseOperationParser : public TFLiteOperationParser { explicit ElementwiseOperationParser(OperationType operation_type) : operation_type_(operation_type) {} - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); if (IsOneArgumentOperation()) { RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, @@ -1103,17 +1106,16 @@ class ElementwiseOperationParser : public TFLiteOperationParser { /*const_inputs=*/1, /*outputs=*/1)); } else { - return absl::InvalidArgumentError( - "Op can only handle 1 or 2 operand(s)."); + return InvalidArgumentError("Op can only handle 1 or 2 operand(s)."); } TfLiteFusedActivation activation; RETURN_IF_ERROR(GetActivation(tflite_node, &activation)); return IsActivationSupported(activation); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(operation_type_); @@ -1130,7 +1132,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser { /*const_inputs=*/0, /*outputs=*/1)); if (tflite_node->inputs->size != 2) { - return absl::InvalidArgumentError("Applies only two input tensors"); + return InvalidArgumentError("Applies only two input tensors"); } RETURN_IF_ERROR(reader->AddInput(node, 0)); RETURN_IF_ERROR(reader->AddInput(node, 1)); @@ -1171,32 +1173,32 @@ class ElementwiseOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); node->operation.attributes = std::move(attr); } else { - return 
absl::InvalidArgumentError("Incorrect operation type passed"); + return InvalidArgumentError("Incorrect operation type passed"); } return reader->AddOutputs(node); } private: - absl::Status GetActivation(const TfLiteNode* tflite_node, - TfLiteFusedActivation* activation) const { + Status GetActivation(const TfLiteNode* tflite_node, + TfLiteFusedActivation* activation) const { if (operation_type_ == OperationType::DIV) { TfLiteDivParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); *activation = tf_options ? tf_options->activation : kTfLiteActNone; - return absl::OkStatus(); + return OkStatus(); } if (operation_type_ == OperationType::SUB) { TfLiteSubParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); *activation = tf_options ? tf_options->activation : kTfLiteActNone; - return absl::OkStatus(); + return OkStatus(); } // Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or // TfLiteXxxParams.activation. *activation = kTfLiteActNone; - return absl::OkStatus(); + return OkStatus(); } bool IsOneArgumentOperation() const { @@ -1245,24 +1247,23 @@ class ElementwiseOperationParser : public TFLiteOperationParser { class FullyConnectedOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); TfLiteFullyConnectedParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) { - return absl::UnimplementedError( - "Unsupported FullyConnected weights format."); + return UnimplementedError("Unsupported FullyConnected weights format."); } // TODO(eignasheva): check input shape - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1271,8 +1272,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { tflite_node->builtin_data); if (tf_options->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) { - return absl::UnimplementedError( - "Unsupported FullyConnected weights format."); + return UnimplementedError("Unsupported FullyConnected weights format."); } FullyConnectedAttributes attr; @@ -1284,7 +1284,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { int batch_size = input->tensor.shape.b; if (input->tensor.shape.DimensionsProduct() / batch_size != weights.shape.w) { - return absl::UnimplementedError( + return UnimplementedError( "Amount of input data should match weights width"); } @@ -1306,7 +1306,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { conv->operation.type = ToString(OperationType::FULLY_CONNECTED); conv->operation.attributes = std::move(attr); - absl::Status result = reader->AddOutputs(conv); + Status result = reader->AddOutputs(conv); RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, conv)); @@ -1316,15 +1316,15 @@ 
class FullyConnectedOperationParser : public TFLiteOperationParser { class HardSwishOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration*) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration*) final { return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } - absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode*, const TfLiteRegistration*, + GraphFloat32* graph, ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::HARD_SWISH); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1353,9 +1353,9 @@ class HardSwishOperationParser : public TFLiteOperationParser { // class LSTMOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2)); // TODO(eignasheva): Fix bad check. // RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, @@ -1364,23 +1364,23 @@ class LSTMOperationParser : public TFLiteOperationParser { TfLiteLSTMParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckParameters(tf_options)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { if (tflite_node->inputs->size != 5) { - return absl::InvalidArgumentError("LSTM should have 5 input tensors"); + return InvalidArgumentError("LSTM should have 5 input tensors"); } if (tflite_node->outputs->size != 4) { - return absl::InvalidArgumentError("LSTM should have 4 output tensors"); + return InvalidArgumentError("LSTM should have 4 output tensors"); } const auto* params = reinterpret_cast(tflite_node->builtin_data); if (!params) { - return absl::InternalError("Missing tflite params"); + return InternalError("Missing tflite params"); } RETURN_IF_ERROR(CheckParameters(params)); @@ -1423,61 +1423,58 @@ class LSTMOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation - return absl::OkStatus(); + return OkStatus(); } private: - absl::Status CheckParameters(const TfLiteLSTMParams* tf_options) { + Status CheckParameters(const TfLiteLSTMParams* tf_options) { if (tf_options->kernel_type != TfLiteLSTMKernelType::kTfLiteLSTMBasicKernel) { - return absl::UnimplementedError( - "Only kTfLiteLSTMBasicKernel is supported."); + return UnimplementedError("Only kTfLiteLSTMBasicKernel is supported."); } if (tf_options->activation != kTfLiteActTanh) { - return absl::UnimplementedError("Only TANH activation is supported."); + return UnimplementedError("Only TANH activation is supported."); } if (tf_options->cell_clip != 0.0f) { - return absl::UnimplementedError("cell_clip is not supported."); + return 
UnimplementedError("cell_clip is not supported."); } if (tf_options->proj_clip != 0.0f) { - return absl::UnimplementedError("proj_clip is not supported."); + return UnimplementedError("proj_clip is not supported."); } - return absl::OkStatus(); + return OkStatus(); } }; class MulOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); if (tflite_node->inputs->size != 2) { - return absl::UnimplementedError("MUL requires two input tensors."); + return UnimplementedError("MUL requires two input tensors."); } TfLiteMulParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); return IsActivationSupported(tf_options->activation); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { // Determine runtime/constant tensors. const TfLiteTensor* input0 = reader->GetInputTensor(0); if (!input0) { - return absl::InvalidArgumentError( - "Couldn't get the 1st input tensor for MUL."); + return InvalidArgumentError("Couldn't get the 1st input tensor for MUL."); } const TfLiteTensor* input1 = reader->GetInputTensor(1); if (!input1) { - return absl::InvalidArgumentError( - "Couldn't get the 2nd input tensor for MUL."); + return InvalidArgumentError("Couldn't get the 2nd input tensor for MUL."); } const bool constant_tensor0 = IsConstantTensor(input0); const bool constant_tensor1 = IsConstantTensor(input1); if (constant_tensor0 && constant_tensor1) { - return absl::InvalidArgumentError("No runtime input tensors for MUL."); + return InvalidArgumentError("No runtime input tensors for MUL."); } const bool runtime_tensor0 = !constant_tensor0; const bool runtime_tensor1 = !constant_tensor1; @@ -1519,24 +1516,24 @@ class MulOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return absl::InternalError("Missing TfLiteMulParams"); + return InternalError("Missing TfLiteMulParams"); } return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node); } private: - absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1, - GraphFloat32* graph, ObjectReader* reader) { + Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1, + GraphFloat32* graph, ObjectReader* reader) { RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); return reader->AddOutputs(node); } - absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor, - int constant_tensor, - const TfLiteIntArray* constant_dims, - GraphFloat32* graph, ObjectReader* reader) { + Status ParseMultiplyScalar(Node* node, int runtime_tensor, + int constant_tensor, + const TfLiteIntArray* constant_dims, + GraphFloat32* graph, ObjectReader* reader) { RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor)); MultiplyAttributes attr; if (constant_dims->size <= 0) { @@ -1555,16 +1552,16 @@ class MulOperationParser : public TFLiteOperationParser { class PReLUOperationParser : public 
TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); // TODO(eignasheva): add params check - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::PRELU); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1572,10 +1569,10 @@ class PReLUOperationParser : public TFLiteOperationParser { PReLUAttributes attr; Tensor linear_alpha; - absl::Status status = reader->ReadTensor(1, &linear_alpha); + Status status = reader->ReadTensor(1, &linear_alpha); if (status.ok()) { if (linear_alpha.shape.v != input_shape.c) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Linear alpha shape does not match the number of input channels."); } attr.alpha = std::move(linear_alpha); @@ -1585,8 +1582,7 @@ class PReLUOperationParser : public TFLiteOperationParser { if (hwc_alpha.shape.h != input_shape.h || hwc_alpha.shape.w != input_shape.w || hwc_alpha.shape.c != input_shape.c) { - return absl::InvalidArgumentError( - "Alpha shape does not match input shape."); + return InvalidArgumentError("Alpha shape does not match input shape."); } attr.alpha = std::move(hwc_alpha); } @@ -1599,15 +1595,15 @@ class PadOperationParser : public TFLiteOperationParser { public: explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {} - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { if (mirror_pad_) { auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (tf_options->mode != TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Only Reflective padding is supported for Mirror Pad operation."); } } @@ -1615,12 +1611,12 @@ class PadOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::PAD); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1638,15 +1634,14 @@ class PadOperationParser : public TFLiteOperationParser { // 4x2 tensor with paddings. 
if (paddings.shape.h != 4 || paddings.shape.w != 2) { - return absl::InvalidArgumentError( - "Paddings tensor has unexpected shape."); + return InvalidArgumentError("Paddings tensor has unexpected shape."); } attr.prepended = BHWC(paddings.data[0], paddings.data[2], paddings.data[4], paddings.data[6]); attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5], paddings.data[7]); node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } private: @@ -1655,9 +1650,9 @@ class PadOperationParser : public TFLiteOperationParser { class Pooling2DOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); TfLitePoolParams* tf_options = nullptr; auto status = RetrieveCustomInitialData(tflite_node, &tf_options); @@ -1680,9 +1675,9 @@ class Pooling2DOperationParser : public TFLiteOperationParser { public: explicit Pooling2DOperationParser(PoolingType type) : type_(type) {} - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::POOLING_2D); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1704,7 +1699,7 @@ class Pooling2DOperationParser : public TFLiteOperationParser { reinterpret_cast(tflite_node->builtin_data); } if (!tf_options) { - return absl::InternalError("Missing tflite params"); + return InternalError("Missing tflite params"); } std::vector max_tensor_id{0}; @@ -1724,7 +1719,7 @@ class Pooling2DOperationParser : public TFLiteOperationParser { } RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr)); node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } private: @@ -1735,16 +1730,16 @@ class ReLUOperationParser : public TFLiteOperationParser { public: explicit ReLUOperationParser(int clip) : clip_(clip) {} - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::RELU); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1764,19 +1759,19 @@ class ReLUOperationParser : public TFLiteOperationParser { class ReshapeOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const 
TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); // TODO(eignasheva): add shape checking - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::RESHAPE); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1789,7 +1784,7 @@ class ReshapeOperationParser : public TFLiteOperationParser { ReshapeAttributes attr; attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape; node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } }; @@ -1798,9 +1793,9 @@ class Resize2DOperationParser : public TFLiteOperationParser { explicit Resize2DOperationParser(SamplingType sampling_type) : sampling_type_(sampling_type) {} - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -1810,12 +1805,12 @@ class Resize2DOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners)); bool half_pixel_centers; RETURN_IF_ERROR(GetHalfPixelCentersValue(tflite_node, &half_pixel_centers)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::RESIZE); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -1831,12 +1826,12 @@ class Resize2DOperationParser : public TFLiteOperationParser { attr.new_shape.CopyAllDefinedAxis( graph->FindOutputs(node->id)[0]->tensor.shape); node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } private: - absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node, - bool* align_corners) { + Status GetAlignCornersValue(const TfLiteNode* tflite_node, + bool* align_corners) { switch (sampling_type_) { case SamplingType::BILINEAR: return GetAlignCornersValueForType( @@ -1845,62 +1840,61 @@ class Resize2DOperationParser : public TFLiteOperationParser { return GetAlignCornersValueForType( tflite_node, align_corners); case SamplingType::UNKNOWN: - return absl::InternalError("Sampling type is not specified"); + return InternalError("Sampling type is not specified"); } - return absl::OkStatus(); + return OkStatus(); } template - absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node, - bool* align_corners) { + Status GetAlignCornersValueForType(const TfLiteNode* tflite_node, + bool* align_corners) { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return absl::InternalError("Missing tflite params"); + return 
InternalError("Missing tflite params"); } *align_corners = tf_options->align_corners; - return absl::OkStatus(); + return OkStatus(); } - absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node, - bool* half_pixel_centers) { + Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node, + bool* half_pixel_centers) { if (sampling_type_ == SamplingType::BILINEAR) { const auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return absl::InternalError( - "Missing tflite params for ResizeBilinear op"); + return InternalError("Missing tflite params for ResizeBilinear op"); } if (tf_options->align_corners && tf_options->half_pixel_centers) { - return absl::InternalError( + return InternalError( "If half_pixel_centers is True, align_corners must be False."); } *half_pixel_centers = tf_options->half_pixel_centers; } else { *half_pixel_centers = false; } - return absl::OkStatus(); + return OkStatus(); } - absl::Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node) { + Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node) { const auto* input = context->tensors + tflite_node->inputs->data[0]; const auto* output = context->tensors + tflite_node->outputs->data[0]; if (!input->dims || input->dims->size != 4) { - return absl::InvalidArgumentError("input.dims.size != 4"); + return InvalidArgumentError("input.dims.size != 4"); } if (!output->dims || output->dims->size != 4) { - return absl::InvalidArgumentError("output.dims.size != 4"); + return InvalidArgumentError("output.dims.size != 4"); } if (output->dims->data[1] < input->dims->data[1] || output->dims->data[2] < input->dims->data[2]) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "Only upsampling is supported, received output h,w = ", output->dims->data[1], ",", output->dims->data[2], " input h,w = ", input->dims->data[1], ",", input->dims->data[2])); } - return absl::OkStatus(); + return OkStatus(); } SamplingType sampling_type_ = SamplingType::UNKNOWN; @@ -1908,16 +1902,16 @@ class Resize2DOperationParser : public TFLiteOperationParser { class SliceOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::SLICE); RETURN_IF_ERROR(reader->AddOutputs(node)); @@ -1931,7 +1925,7 @@ class SliceOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->ReadTensor(1, &starts)); RETURN_IF_ERROR(reader->ReadTensor(2, &sizes)); if (starts.data.size() != sizes.data.size()) { - return absl::InvalidArgumentError("Starts amount != sizes amount."); + return InvalidArgumentError("Starts amount != sizes amount."); } if (starts.data.size() == 4) { attr.starts = @@ -1945,31 +1939,30 @@ class SliceOperationParser : public 
TFLiteOperationParser { BHWC(input->tensor.shape.b, starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]); } else { - return absl::UnimplementedError( + return UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); } RETURN_IF_ERROR(UpdateIfNegative(input->tensor.shape, &attr)); auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; if ((attr.ends.b - attr.starts.b) != out_shape.b) { - return absl::UnimplementedError("Output batch don't match"); + return UnimplementedError("Output batch don't match"); } if ((attr.ends.h - attr.starts.h) != out_shape.h) { - return absl::UnimplementedError("Output height doesn't match"); + return UnimplementedError("Output height doesn't match"); } if ((attr.ends.w - attr.starts.w) != out_shape.w) { - return absl::UnimplementedError("Output width doesn't match"); + return UnimplementedError("Output width doesn't match"); } if ((attr.ends.c - attr.starts.c) != out_shape.c) { - return absl::UnimplementedError("Output channels don't match"); + return UnimplementedError("Output channels don't match"); } node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } private: - absl::Status UpdateIfNegative(const BHWC& input_shape, - SliceAttributes* attr) { + Status UpdateIfNegative(const BHWC& input_shape, SliceAttributes* attr) { if (attr->ends.h < 0) { attr->ends.h = input_shape.h + attr->ends.h; } @@ -1982,15 +1975,15 @@ class SliceOperationParser : public TFLiteOperationParser { if (attr->ends.b < 0) { attr->ends.b = input_shape.b + attr->ends.b; } - return absl::OkStatus(); + return OkStatus(); } }; class SoftmaxOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -1998,14 +1991,14 @@ class SoftmaxOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->beta != 1) { // TODO(eignasheva): figure out, what's wrong with softmax. - return absl::UnimplementedError("Softmax.beta != 1 is not supported."); + return UnimplementedError("Softmax.beta != 1 is not supported."); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::SOFTMAX); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2014,27 +2007,27 @@ class SoftmaxOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast(tflite_node->builtin_data); if (!tf_options) { - return absl::InternalError("Missing tflite params"); + return InternalError("Missing tflite params"); } if (tf_options->beta != 1) { // there is multiply by scalar operation fused in softmax. Make a layer // out of it before softmax. 
- return absl::UnimplementedError("Softmax.beta != 1 is not supported."); + return UnimplementedError("Softmax.beta != 1 is not supported."); // auto mul_node = reader->NewPassthroughNode(node); // mul_node->operation.type = ToString(OperationType::MUL); } SoftmaxAttributes attr; attr.axis = Axis::CHANNELS; // always by channels node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } }; class SpaceToDepthOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); @@ -2042,19 +2035,17 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser { TfLiteSpaceToDepthParams* s2d_params = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params)); if (s2d_params->block_size == 1) { - return absl::InvalidArgumentError( - "SPACE_TO_DEPTH block_size = 1 is a no-op."); + return InvalidArgumentError("SPACE_TO_DEPTH block_size = 1 is a no-op."); } if (s2d_params->block_size < 1) { - return absl::InvalidArgumentError( - "SPACE_TO_DEPTH block_size must be > 1."); + return InvalidArgumentError("SPACE_TO_DEPTH block_size must be > 1."); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::SPACE_TO_DEPTH); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2064,25 +2055,25 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser { SpaceToDepthAttributes attr; attr.block_size = tf_options->block_size; node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } }; class StridedSliceOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); TfLiteStridedSliceParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::SLICE); RETURN_IF_ERROR(reader->AddOutputs(node)); @@ -2096,7 +2087,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser { bool read_without_batch = tmp.data.size() == 3; bool read_with_batch = tmp.data.size() == 4; if (!read_without_batch && !read_with_batch) { - return absl::UnimplementedError( + 
return UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); } @@ -2104,7 +2095,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser { tflite_node->builtin_data); auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; if (!tf_options) { - return absl::InternalError("Missing tflite params"); + return InternalError("Missing tflite params"); } RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); @@ -2119,37 +2110,36 @@ class StridedSliceOperationParser : public TFLiteOperationParser { } if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 || attr.strides.c == 0) { - return absl::InvalidArgumentError("stride values must be non-zero"); + return InvalidArgumentError("stride values must be non-zero"); } if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 || attr.strides.c < 0) { - return absl::UnimplementedError("Reverse slices are not supported."); + return UnimplementedError("Reverse slices are not supported."); } if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b != out_shape.b) { - return absl::UnimplementedError("Output batch don't match"); + return UnimplementedError("Output batch don't match"); } if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h != out_shape.h) { - return absl::UnimplementedError("Output height doesn't match"); + return UnimplementedError("Output height doesn't match"); } if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w != out_shape.w) { - return absl::UnimplementedError("Output width doesn't match"); + return UnimplementedError("Output width doesn't match"); } if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c != out_shape.c) { - return absl::UnimplementedError("Output channels don't match"); + return UnimplementedError("Output channels don't match"); } node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } private: - absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options, - const BHWC& input_shape, int ignore_b, - int ignore_h, int ignore_w, int ignore_c, - SliceAttributes* attr) { + Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, int ignore_b, int ignore_h, + int ignore_w, int ignore_c, SliceAttributes* attr) { if (tf_options->begin_mask & ignore_h) { attr->starts.h = 0; } @@ -2175,11 +2165,10 @@ class StridedSliceOperationParser : public TFLiteOperationParser { if (tf_options->end_mask & ignore_b) { attr->ends.b = input_shape.b; } - return absl::OkStatus(); + return OkStatus(); } - absl::Status UpdateIfNegative(const BHWC& input_shape, - SliceAttributes* attr) { + Status UpdateIfNegative(const BHWC& input_shape, SliceAttributes* attr) { if (attr->ends.h < 0) { attr->ends.h = input_shape.h + attr->ends.h; } @@ -2192,18 +2181,17 @@ class StridedSliceOperationParser : public TFLiteOperationParser { if (attr->ends.b < 0) { attr->ends.b = input_shape.b + attr->ends.b; } - return absl::OkStatus(); + return OkStatus(); } - absl::Status ReadAttribsWithBatch(const ObjectReader* reader, - const TfLiteStridedSliceParams* tf_options, - const BHWC& input_shape, - SliceAttributes* attr) { - auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status { + Status ReadAttribsWithBatch(const ObjectReader* reader, + const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, SliceAttributes* attr) { + auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> Status { Tensor t; 
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]); - return absl::OkStatus(); + return OkStatus(); }; RETURN_IF_ERROR(read_bhwc(1, &attr->starts)); @@ -2211,17 +2199,18 @@ class StridedSliceOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(read_bhwc(3, &attr->strides)); RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr)); RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status ReadAttribsWithoutBatch( - const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options, - const BHWC& input_shape, SliceAttributes* attr) { - auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status { + Status ReadAttribsWithoutBatch(const ObjectReader* reader, + const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, + SliceAttributes* attr) { + auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> Status { Tensor t; RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]); - return absl::OkStatus(); + return OkStatus(); }; RETURN_IF_ERROR(read_hwc(1, &attr->starts)); @@ -2232,43 +2221,43 @@ class StridedSliceOperationParser : public TFLiteOperationParser { attr->starts.b = 0; attr->ends.b = input_shape.b; attr->strides.b = 1; - return absl::OkStatus(); + return OkStatus(); } - absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) { + Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) { if (tf_options->ellipsis_mask) { - return absl::UnimplementedError("Slice does not support ellipsis_mask."); + return UnimplementedError("Slice does not support ellipsis_mask."); } if (tf_options->new_axis_mask) { - return absl::UnimplementedError("Slice does not support new_axis_mask."); + return UnimplementedError("Slice does not support new_axis_mask."); } if (tf_options->shrink_axis_mask) { - return absl::UnimplementedError( + return UnimplementedError( "Slice does not support shrink_axis_mask parameter. "); } - return absl::OkStatus(); + return OkStatus(); } }; class TransposeConvOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); TfLiteTransposeConvParams* tf_options = nullptr; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR( CheckStrides(tf_options->stride_height, tf_options->stride_width)); - return absl::OkStatus(); + return OkStatus(); } // TFLite's TRANSPOSE_CONV expects 3 input (output shape, weights, and input) // and allows configurable padding & stride. // TODO(impjdi): Translate output_shape to attr.adjacent. 
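An illustrative sketch (not the delegate's code) of how begin_mask/end_mask bits reset a strided-slice bound to cover the whole axis. The bit values 1/2/4/8 correspond to the ignore_b/ignore_h/ignore_w/ignore_c arguments that ReadAttribsWithBatch passes to UpdateWithMask above.

#include <cstdio>

struct Bounds { int start, end; };

Bounds ApplyMasks(int begin_mask, int end_mask, int axis_bit, int axis_size,
                  Bounds bounds) {
  if (begin_mask & axis_bit) bounds.start = 0;      // ignore the provided start
  if (end_mask & axis_bit) bounds.end = axis_size;  // ignore the provided end
  return bounds;
}

int main() {
  constexpr int kIgnoreH = 2;  // same bit ReadAttribsWithBatch uses for height
  // end_mask has the height bit set, so the provided end (100) is replaced by
  // the full height of the input tensor.
  Bounds h = ApplyMasks(/*begin_mask=*/0, /*end_mask=*/kIgnoreH, kIgnoreH,
                        /*axis_size=*/224, {10, 100});
  std::printf("height slice: [%d, %d)\n", h.start, h.end);  // prints [10, 224)
  return 0;
}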
- absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); Value>* input; @@ -2279,7 +2268,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->builtin_data); if (!tf_options) { - return absl::InternalError("Missing tflite options."); + return InternalError("Missing tflite options."); } ConvolutionTransposedAttributes attr; attr.stride = tf_options @@ -2292,24 +2281,24 @@ class TransposeConvOperationParser : public TFLiteOperationParser { UpdatePadding(tf_options->padding, graph->FindInputs(node->id)[0]->tensor.shape, &attr); node->operation.attributes = std::move(attr); - return absl::OkStatus(); + return OkStatus(); } }; class TransposeOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::TRANSPOSE); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2325,20 +2314,19 @@ class TransposeOperationParser : public TFLiteOperationParser { } else if (perm.data.size() == 2) { attr.perm = BHWC(0, 1, perm.data[0] + 2, perm.data[1] + 2); } else { - return absl::InvalidArgumentError( - "Permutation for transpose is invalid."); + return InvalidArgumentError("Permutation for transpose is invalid."); } node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } }; class Unpooling2DOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { TfLitePoolParams* tf_options = nullptr; RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2, /*outputs=*/1)); @@ -2346,12 +2334,12 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckKernelsAndStrides( tf_options->filter_height, tf_options->filter_width, tf_options->stride_height, tf_options->stride_width)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); 
node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2362,7 +2350,7 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { const auto* tf_options = reinterpret_cast( tflite_node->custom_initial_data); if (!tf_options) { - return absl::InternalError("Missing tflite params"); + return InternalError("Missing tflite params"); } attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width); attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); @@ -2372,22 +2360,22 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { auto output_value = graph->FindOutputs(node->id)[0]; output_value->tensor.shape = CalculateOutputShape(input_shape, attr); - return absl::OkStatus(); + return OkStatus(); } }; // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported. class BatchToSpaceOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return absl::OkStatus(); + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::BATCH_TO_SPACE); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2397,7 +2385,7 @@ class BatchToSpaceOperationParser : public TFLiteOperationParser { Tensor block; RETURN_IF_ERROR(reader->ReadTensor(1, &block)); if (block.shape.v != 2) { - return absl::InternalError("Space has to be HxW."); + return InternalError("Space has to be HxW."); } bs_attr.block.h = block.data[0]; bs_attr.block.w = block.data[1]; @@ -2406,7 +2394,7 @@ class BatchToSpaceOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->ReadTensor(2, &crop)); auto crop_shape = crop.shape; if (crop_shape.h != 2 && crop_shape.w != 2) { - return absl::InternalError("Space has to be HxW."); + return InternalError("Space has to be HxW."); } bs_attr.crop.prepended.h = crop.data[0]; @@ -2416,21 +2404,21 @@ class BatchToSpaceOperationParser : public TFLiteOperationParser { bs_attr.crop.appended.w = crop.data[3]; node->operation.attributes = std::move(bs_attr); - return absl::OkStatus(); + return OkStatus(); } }; class SpaceToBatchOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return absl::OkStatus(); + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::SPACE_TO_BATCH); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2439,7 +2427,7 @@ class SpaceToBatchOperationParser : public 
TFLiteOperationParser { Tensor block; RETURN_IF_ERROR(reader->ReadTensor(1, &block)); if (block.shape.v != 2) { - return absl::InternalError("Space has to be HxW."); + return InternalError("Space has to be HxW."); } sb_attr.block.h = block.data[0]; sb_attr.block.w = block.data[1]; @@ -2449,7 +2437,7 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser { auto padding_shape = padding.shape; if (padding_shape.h != 2 && padding_shape.w != 2) { - return absl::InternalError("Space has to be HxW."); + return InternalError("Space has to be HxW."); } sb_attr.padding.prepended.h = padding.data[0]; @@ -2459,23 +2447,23 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser { sb_attr.padding.appended.w = padding.data[3]; node->operation.attributes = std::move(sb_attr); - return absl::OkStatus(); + return OkStatus(); } }; class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox RETURN_IF_ERROR(reader->AddOutputs(node)); @@ -2490,7 +2478,7 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { auto output_value = graph->FindOutputs(node->id)[0]; output_value->tensor.shape = output_shape; - return absl::OkStatus(); + return OkStatus(); } private: @@ -2498,17 +2486,17 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { class TransformTensorOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2, /*outputs=*/1)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); // data RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox @@ -2527,7 +2515,7 @@ class TransformTensorOperationParser : public TFLiteOperationParser { output_value->tensor.shape = BHWC(1, output_shape.h, output_shape.w, graph->FindInputs(node->id)[0]->tensor.shape.c); - return absl::OkStatus(); + return OkStatus(); } private: @@ -2535,17 +2523,17 @@ class TransformTensorOperationParser : public TFLiteOperationParser { class TransformLandmarksOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const 
TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2, /*outputs=*/1)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); // data RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox @@ -2561,7 +2549,7 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser { auto output_value = graph->FindOutputs(node->id)[0]; output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape; - return absl::OkStatus(); + return OkStatus(); } private: @@ -2569,16 +2557,16 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser { class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix @@ -2593,7 +2581,7 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { auto output_value = graph->FindOutputs(node->id)[0]; output_value->tensor.shape = output_shape; - return absl::OkStatus(); + return OkStatus(); } private: @@ -2601,16 +2589,16 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { class MeanOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { auto* node = graph->NewNode(); node->operation.type = ToString(OperationType::MEAN); RETURN_IF_ERROR(reader->AddInput(node, 0)); @@ -2635,27 +2623,27 @@ class MeanOperationParser : public TFLiteOperationParser { unsupported = unsupported.empty() ? 
"channels" : unsupported; ABSL_FALLTHROUGH_INTENDED; default: - return absl::UnimplementedError( + return UnimplementedError( absl::StrCat("Unsupported mean dimension: ", unsupported)); } } node->operation.attributes = attr; - return absl::OkStatus(); + return OkStatus(); } }; class UnsupportedOperationParser : public TFLiteOperationParser { public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return absl::UnimplementedError("Operation is not supported."); + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return UnimplementedError("Operation is not supported."); } - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - return absl::UnimplementedError("Operation is not supported."); + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + return UnimplementedError("Operation is not supported."); } }; @@ -2784,15 +2772,15 @@ std::unique_ptr NewOperationParser( return absl::make_unique(); } -absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id, - TfLiteNode** tflite_node, - TfLiteRegistration** registration) { +Status GetNodeAndRegistration(TfLiteContext* context, int node_id, + TfLiteNode** tflite_node, + TfLiteRegistration** registration) { if (context->GetNodeAndRegistration(context, node_id, tflite_node, registration) != kTfLiteOk) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "Couldn't get node and registration info for op: ", node_id)); } - return absl::OkStatus(); + return OkStatus(); } using IsNodeSupportedFn = tflite::delegates::IsNodeSupportedFn; @@ -2975,8 +2963,8 @@ class GraphWithDequantPartitionHelper std::set dequant_nodes_to_save_; }; -absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node, - const TfLiteRegistration* registration) { +Status IsSupported(const TfLiteContext* context, TfLiteNode* node, + const TfLiteRegistration* registration) { return NewOperationParser(registration) ->IsSupported(context, node, registration); } @@ -2995,8 +2983,8 @@ bool IsAllFloatTensors(const TfLiteContext* context, } } // namespace -absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, - TensorRef* tensor_ref) { +Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, + TensorRef* tensor_ref) { tensor_ref->type = ToDataType(tflite_tensor.type); return ExtractTensorShape(tflite_tensor, &tensor_ref->shape); } @@ -3010,9 +2998,7 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { std::string* unsupported_details) -> bool { const auto status = IsSupported(context, node, registration); if (!status.ok()) { - if (unsupported_details) { - *unsupported_details = std::string(status.message()); - } + if (unsupported_details) *unsupported_details = status.error_message(); return false; } @@ -3062,9 +3048,9 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { return ConvertVectorToTfLiteIntArray(ops_to_replace); } -absl::Status BuildModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph) { +Status BuildModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph) { std::vector> operations; std::vector tflite_nodes; for (int i 
= 0; i < delegate_params->nodes_to_replace->size; ++i) { @@ -3079,7 +3065,7 @@ absl::Status BuildModel(TfLiteContext* context, } auto op_parser = NewOperationParser(registration); if (!op_parser) { - return absl::UnimplementedError( + return UnimplementedError( absl::StrCat("Operation ", registration->builtin_code, "(", registration->custom_name, ") is not supported by TFLite GPU Delegate.")); @@ -3099,25 +3085,25 @@ absl::Status BuildModel(TfLiteContext* context, const auto status = operations[i]->Parse(tflite_node, registration, graph, &reader); if (!status.ok()) { - return absl::InternalError(absl::StrCat( - GetOpNameByRegistration(*registration), ": ", status.message())); + return InternalError(absl::StrCat(GetOpNameByRegistration(*registration), + ": ", status.error_message())); } } - return absl::OkStatus(); + return OkStatus(); } -absl::Status BuildFinalModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph) { +Status BuildFinalModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph) { RETURN_IF_ERROR(BuildModel(context, delegate_params, graph)); // Apply general transformations on the graph. NullTransformationReporter reporter; ModelTransformer transformer(graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + return InternalError("Graph general transformations failed"); } - return absl::OkStatus(); + return OkStatus(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h index b8fcab0c5c8..f81dd90933c 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -32,19 +32,19 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context); // Extracts TFLite delegate execution plan from the input TFLite context and // converts it into generic graph format. -absl::Status BuildModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph); +Status BuildModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph); // Same as above but also apply all transformations on the final graph. // Prefer using this method instead of BuildModel. -absl::Status BuildFinalModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph); +Status BuildFinalModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph); // Module-internal converter, exposed for unit testing purpose only. 
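A hedged usage sketch of how a delegate backend would drive the model-builder entry points changed above. BuildFinalModel, Status, ok() and error_message() are as defined in this patch; the wrapper function itself and its name are illustrative only.

#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

TfLiteStatus PrepareGpuGraph(TfLiteContext* context,
                             const TfLiteDelegateParams* delegate_params,
                             GraphFloat32* graph) {
  // Prefer BuildFinalModel over BuildModel: it additionally runs the general
  // graph transformations before the backend sees the graph.
  Status status = BuildFinalModel(context, delegate_params, graph);
  if (!status.ok()) {
    context->ReportError(context, "TfLiteGpuDelegate: %s",
                         status.error_message().c_str());
    return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace gpu
}  // namespace tflite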
-absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, - TensorRef* tensor_ref); +Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, + TensorRef* tensor_ref); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc index 771ed7378b9..b20b24d28c3 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.cc +++ b/tensorflow/lite/delegates/gpu/common/operations.cc @@ -519,15 +519,14 @@ BHWC CalculateOutputShape(const BHWC& input, const MeanAttributes& attr) { return BHWC(b, h, w, c); } -absl::Status CalculateOutputShape(const std::vector& input, - const ConcatAttributes& attr, - BHWC* output_shape) { +Status CalculateOutputShape(const std::vector& input, + const ConcatAttributes& attr, BHWC* output_shape) { BHWC new_shape = input[0]; switch (attr.axis) { case Axis::CHANNELS: for (int i = 1; i < input.size(); i++) { if (input[i].h != new_shape.h || input[i].w != new_shape.w) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Height and Width must be the same when concatenating " "by channels axis"); } @@ -537,7 +536,7 @@ absl::Status CalculateOutputShape(const std::vector& input, case Axis::HEIGHT: for (int i = 1; i < input.size(); i++) { if (input[i].w != new_shape.w || input[i].c != new_shape.c) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Channels and Width must be the same when concatenating " "by height axis"); } @@ -547,7 +546,7 @@ absl::Status CalculateOutputShape(const std::vector& input, case Axis::WIDTH: for (int i = 1; i < input.size(); i++) { if (input[i].h != new_shape.h || input[i].c != new_shape.c) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Height and Channels must be the same when concatenating " "by width axis"); } @@ -555,11 +554,11 @@ absl::Status CalculateOutputShape(const std::vector& input, } break; default: - return absl::InvalidArgumentError("Invalid axis"); + return InvalidArgumentError("Invalid axis"); break; } *output_shape = new_shape; - return absl::OkStatus(); + return OkStatus(); } Padding2D CalculateSamePadding(const BHWC& input, diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h index 4eb41dfe1a3..16016d334cf 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.h +++ b/tensorflow/lite/delegates/gpu/common/operations.h @@ -202,9 +202,8 @@ BHWDC CalculateOutputShape(const BHWDC& input, const Pooling3DAttributes& attr); // @return shape of a tensor after Concat operation is applied to the given // input. -absl::Status CalculateOutputShape(const std::vector& input, - const ConcatAttributes& attr, - BHWC* output_shape); +Status CalculateOutputShape(const std::vector& input, + const ConcatAttributes& attr, BHWC* output_shape); // @return padding for pooling operation to make sure output keep the same shape // as the given input. diff --git a/tensorflow/lite/delegates/gpu/common/status.h b/tensorflow/lite/delegates/gpu/common/status.h index d6b5dd8a94a..250a3b5e3eb 100644 --- a/tensorflow/lite/delegates/gpu/common/status.h +++ b/tensorflow/lite/delegates/gpu/common/status.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
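A standalone sketch of the concatenation shape rule enforced by CalculateOutputShape above: all non-concat dimensions must match, and the concat dimension is the sum over inputs. Shape3 is a simplified stand-in for the delegate's BHWC type.

#include <cassert>
#include <vector>

struct Shape3 { int h, w, c; };

// Concatenation along channels, mirroring the Axis::CHANNELS branch above.
bool ConcatByChannels(const std::vector<Shape3>& inputs, Shape3* out) {
  Shape3 result = inputs[0];
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (inputs[i].h != result.h || inputs[i].w != result.w) {
      return false;  // the "Height and Width must be the same" error case
    }
    result.c += inputs[i].c;
  }
  *out = result;
  return true;
}

int main() {
  Shape3 out;
  assert(ConcatByChannels({{8, 8, 3}, {8, 8, 5}}, &out));
  assert(out.c == 8);                                       // 3 + 5 channels
  assert(!ConcatByChannels({{8, 8, 3}, {4, 8, 5}}, &out));  // height mismatch
  return 0;
}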
@@ -16,7 +16,109 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ -#include "absl/status/status.h" -#define RETURN_IF_ERROR(s) {auto c=(s);if(!c.ok())return c;} +#include + +namespace tflite { +namespace gpu { + +enum class StatusCode { + kOk = 0, + kCancelled = 1, + kUnknown = 2, + kInvalidArgument = 3, + kDeadlineExceeded = 4, + kNotFound = 5, + kAlreadyExists = 6, + kPermissionDenied = 7, + kResourceExhausted = 8, + kFailedPrecondition = 9, + kAborted = 10, + kOutOfRange = 11, + kUnimplemented = 12, + kInternal = 13, + kUnavailable = 14, + kDataLoss = 15, + kUnauthenticated = 16, + kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 +}; + +// Lite version of Status without dependency on protobuf. +// TODO(b/128867901): Migrate to absl::Status. +class Status { + public: + Status() = default; + Status(StatusCode code) : code_(code) {} + Status(StatusCode code, const std::string& error_message) + : code_(code), error_message_(error_message) {} + + const std::string& error_message() const { return error_message_; } + StatusCode code() const { return code_; } + bool ok() const { return code_ == StatusCode::kOk; } + + void IgnoreError() const {} + + private: + StatusCode code_ = StatusCode::kOk; + std::string error_message_; +}; + +#define RETURN_IF_ERROR(status) \ + { \ + const auto status2 = (status); \ + if (!status2.ok()) return status2; \ + } + +inline Status OkStatus() { return Status(); } + +inline Status AlreadyExistsError(const std::string& message) { + return Status(StatusCode::kAlreadyExists, message); +} + +inline Status DeadlineExceededError(const std::string& message) { + return Status(StatusCode::kDeadlineExceeded, message); +} + +inline Status FailedPreconditionError(const std::string& message) { + return Status(StatusCode::kFailedPrecondition, message); +} + +inline Status InternalError(const std::string& message) { + return Status(StatusCode::kInternal, message); +} + +inline Status InvalidArgumentError(const std::string& message) { + return Status(StatusCode::kInvalidArgument, message); +} + +inline Status NotFoundError(const std::string& message) { + return Status(StatusCode::kNotFound, message); +} + +inline Status OutOfRangeError(const std::string& message) { + return Status(StatusCode::kOutOfRange, message); +} + +inline Status PermissionDeniedError(const std::string& message) { + return Status(StatusCode::kPermissionDenied, message); +} + +inline Status ResourceExhaustedError(const std::string& message) { + return Status(StatusCode::kResourceExhausted, message); +} + +inline Status UnavailableError(const std::string& message) { + return Status(StatusCode::kUnavailable, message); +} + +inline Status UnimplementedError(const std::string& message) { + return Status(StatusCode::kUnimplemented, message); +} + +inline Status UnknownError(const std::string& message) { + return Status(StatusCode::kUnknown, message); +} + +} // namespace gpu +} // namespace tflite #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc index 08d9448f7e5..cbd62fa6853 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc @@ -30,21 +30,21 @@ namespace tflite { namespace gpu { namespace testing { -absl::Status InterpreterInvokeWithOpResolver( - const 
::tflite::Model* model, TfLiteDelegate* delegate, - const OpResolver& op_resolver, const std::vector& inputs, - std::vector* outputs) { +Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, + TfLiteDelegate* delegate, + const OpResolver& op_resolver, + const std::vector& inputs, + std::vector* outputs) { auto interpreter = absl::make_unique(); if (InterpreterBuilder(model, op_resolver)(&interpreter) != kTfLiteOk) { - return absl::InternalError("Unable to create TfLite InterpreterBuilder"); + return InternalError("Unable to create TfLite InterpreterBuilder"); } if (delegate && interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) { - return absl::InternalError( - "Unable to modify TfLite graph with the delegate"); + return InternalError("Unable to modify TfLite graph with the delegate"); } interpreter->SetNumThreads(1); if (interpreter->AllocateTensors() != kTfLiteOk) { - return absl::InternalError("Unable to allocate TfLite tensors"); + return InternalError("Unable to allocate TfLite tensors"); } for (int i = 0; i < inputs.size(); ++i) { DCHECK_EQ(interpreter->tensor(interpreter->inputs()[i])->type, @@ -57,10 +57,10 @@ absl::Status InterpreterInvokeWithOpResolver( inputs[i].data.size() * sizeof(float)); } if (interpreter->Invoke() != kTfLiteOk) { - return absl::InternalError("Unable to invoke TfLite interpreter"); + return InternalError("Unable to invoke TfLite interpreter"); } if (!outputs || !outputs->empty()) { - return absl::InternalError("Invalid outputs pointer"); + return InternalError("Invalid outputs pointer"); } outputs->reserve(interpreter->outputs().size()); for (auto t : interpreter->outputs()) { @@ -69,7 +69,7 @@ absl::Status InterpreterInvokeWithOpResolver( bhwc.id = t; // TODO(impjdi) Relax this condition to arbitrary batch size. if (out_tensor->dims->data[0] != 1) { - return absl::InternalError("Batch dimension is expected to be 1"); + return InternalError("Batch dimension is expected to be 1"); } bhwc.shape.b = out_tensor->dims->data[0]; switch (out_tensor->dims->size) { @@ -89,21 +89,20 @@ absl::Status InterpreterInvokeWithOpResolver( bhwc.shape.c = out_tensor->dims->data[3]; break; default: - return absl::InternalError("Unsupported dimensions size " + - std::to_string(out_tensor->dims->size)); + return InternalError("Unsupported dimensions size " + + std::to_string(out_tensor->dims->size)); } bhwc.data = std::vector( out_tensor->data.f, out_tensor->data.f + out_tensor->bytes / sizeof(float)); outputs->push_back(bhwc); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status InterpreterInvoke(const ::tflite::Model* model, - TfLiteDelegate* delegate, - const std::vector& inputs, - std::vector* outputs) { +Status InterpreterInvoke(const ::tflite::Model* model, TfLiteDelegate* delegate, + const std::vector& inputs, + std::vector* outputs) { ops::builtin::BuiltinOpResolver builtin_op_resolver; return InterpreterInvokeWithOpResolver(model, delegate, builtin_op_resolver, inputs, outputs); diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h index ca2825b7563..a38a5d1363a 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h +++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h @@ -31,18 +31,18 @@ namespace testing { // Runs Tensorflow Lite model using Tensorflow Lite with a delegate and // an appropriate operations resolver. If delegate is nullptr, inference will // be done only on CPU. 
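A hedged usage sketch for the lite Status type reintroduced in status.h above: the error constructors build a non-ok Status, RETURN_IF_ERROR propagates it out of the calling function, and callers branch on ok() and read error_message(). The two functions here are illustrative, not part of the patch.

#include <cstdio>

#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

Status CheckPositive(int value) {
  if (value <= 0) {
    return InvalidArgumentError("value must be positive");
  }
  return OkStatus();
}

Status Process(int value) {
  RETURN_IF_ERROR(CheckPositive(value));  // early-returns the failed Status
  // ... actual work would go here ...
  return OkStatus();
}

}  // namespace gpu
}  // namespace tflite

int main() {
  tflite::gpu::Status status = tflite::gpu::Process(-1);
  if (!status.ok()) {
    std::printf("error (code %d): %s\n", static_cast<int>(status.code()),
                status.error_message().c_str());
  }
  return 0;
}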
-absl::Status InterpreterInvokeWithOpResolver( - const ::tflite::Model* model, TfLiteDelegate* delegate, - const OpResolver& op_resolver, const std::vector& inputs, - std::vector* outputs); +Status InterpreterInvokeWithOpResolver(const ::tflite::Model* model, + TfLiteDelegate* delegate, + const OpResolver& op_resolver, + const std::vector& inputs, + std::vector* outputs); // Runs Tensorflow Lite model using Tensorflow Lite with a delegate and // builtin operations resolver. If delegate is nullptr, inference will // be done only on CPU. -absl::Status InterpreterInvoke(const ::tflite::Model* model, - TfLiteDelegate* delegate, - const std::vector& inputs, - std::vector* outputs); +Status InterpreterInvoke(const ::tflite::Model* model, TfLiteDelegate* delegate, + const std::vector& inputs, + std::vector* outputs); } // namespace testing } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc index 0011cc24dfa..872c4bcd903 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc @@ -61,7 +61,7 @@ class AddQuantAdjustments : public NodeTransformation { // The tensor information should rename the same. Value>* adjusted_value = graph->NewValue(); adjusted_value->tensor = output_value->tensor; - absl::Status status = + Status status = graph->SetProducer(quant_and_dequant_node->id, adjusted_value->id); if (!status.ok()) { return {TransformStatus::INVALID, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc index 4efb98a6847..586c7a34a37 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -81,11 +81,11 @@ class MergeConvolutionWithAdd : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } - absl::Status status = RemoveFollowingNode(graph, &add_node, &conv_node); + Status status = RemoveFollowingNode(graph, &add_node, &conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove add node after convolution: " + - std::string(status.message())}; + status.error_message()}; } return {TransformStatus::APPLIED, ""}; } @@ -131,11 +131,11 @@ class MergeAddWithConvolution : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } - absl::Status status = RemovePrecedingNode(graph, &add_node, &conv_node); + Status status = RemovePrecedingNode(graph, &add_node, &conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove add node after convolution: " + - std::string(status.message())}; + status.error_message()}; } return {TransformStatus::APPLIED, ""}; } diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc index 055327d3534..6b106a4be62 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc @@ -74,11 +74,11 @@ class MergeConvolutionWithMul : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } - absl::Status status = RemoveFollowingNode(graph, &mul_node, &conv_node); + Status status = RemoveFollowingNode(graph, &mul_node, 
&conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove mul node after convolution: " + - std::string(status.message())}; + status.error_message()}; } return {TransformStatus::APPLIED, ""}; } @@ -134,11 +134,11 @@ class MergeMulWithConvolution : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } - absl::Status status = RemovePrecedingNode(graph, &mul_node, &conv_node); + Status status = RemovePrecedingNode(graph, &mul_node, &conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove mul node after convolution: " + - std::string(status.message())}; + status.error_message()}; } return {TransformStatus::APPLIED, ""}; } diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc index 17aac83baf7..5e98edac943 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc @@ -76,10 +76,10 @@ class MakePaddingFromZerosConcat : public NodeTransformation { "Padding for concat axis is unsupported: " + ToString(concat_attr.axis)}; } - absl::Status status = RemovePrecedingNode(graph, dep, node); + Status status = RemovePrecedingNode(graph, dep, node); if (!status.ok()) { - return {TransformStatus::INVALID, "Unable to remove const node: " + - std::string(status.message())}; + return {TransformStatus::INVALID, + "Unable to remove const node: " + status.error_message()}; } node->operation.attributes = pad_attr; node->operation.type = ToString(OperationType::PAD); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc index f1c56477834..5257ba44f0e 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc @@ -72,7 +72,7 @@ class MatchDilatedConvolution : public SequenceTransformation { conv_node.operation.attributes = std::move(conv2d_attr); } - absl::Status status = RemoveFollowingNode(graph, &bs_node, &conv_node); + Status status = RemoveFollowingNode(graph, &bs_node, &conv_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove batch_to_space node after convolution."}; diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc index 23e99bc3305..5e2f1e17f54 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc @@ -62,11 +62,11 @@ class MergePaddingWith2DOperation : public SequenceTransformation { } Attr* node_attr = absl::any_cast(&op_node->operation.attributes); - absl::Status status = RemovePrecedingNode(graph, pad_node, op_node); + Status status = RemovePrecedingNode(graph, pad_node, op_node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove Pad node with Operation node: " + - std::string(status.message())}; + status.error_message()}; } node_attr->padding.appended.h += pad_attr.appended.h; @@ -154,10 +154,10 @@ class MergePaddingWithAddOperation : public NodeTransformation { "Cannot remove padding when this broadcast/scalar ADD"}; } - absl::Status status = RemovePrecedingNode(graph, node, add_node); + Status status = 
RemovePrecedingNode(graph, node, add_node); if (!status.ok()) { return {TransformStatus::INVALID, - "Unable to remove Pad node " + std::string(status.message())}; + "Unable to remove Pad node " + status.error_message()}; } return {TransformStatus::APPLIED, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc index e80b244b34f..64779990178 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc @@ -44,10 +44,10 @@ class RemoveOperation : public SequenceTransformation { if (!remove_predicate_(graph, op_node)) { return {TransformStatus::SKIPPED, ""}; } - absl::Status status = RemoveFollowingNode(graph, op_node, prev_op_node); + Status status = RemoveFollowingNode(graph, op_node, prev_op_node); if (!status.ok()) { return {TransformStatus::INVALID, - "Unable to remove a node: " + std::string(status.message())}; + "Unable to remove a node: " + status.error_message()}; } return {TransformStatus::APPLIED, ""}; } @@ -116,10 +116,10 @@ class RemoveIdentityReshape : public NodeTransformation { return {TransformStatus::SKIPPED, "Can not apply transformation when node output is graph output"}; } - absl::Status status = RemoveOneInputOneOutputNode(graph, node); + Status status = RemoveOneInputOneOutputNode(graph, node); if (!status.ok()) { return {TransformStatus::INVALID, - "Unable to remove a node: " + std::string(status.message())}; + "Unable to remove a node: " + status.error_message()}; } return {TransformStatus::APPLIED, "Removed reshape with input_shape == output_shape."}; diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc index d18e3726a1c..d6d22aa6a62 100644 --- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc @@ -184,9 +184,10 @@ template std::vector GenerateWorkGroupSizes( WorkGroupSizeAlignment z_alignment); template -absl::Status GenerateWorkGroupSizesAlignedToGrid( - const T& grid, const T& max_work_group_size, - const int max_work_group_invocations, std::vector* work_groups) { +Status GenerateWorkGroupSizesAlignedToGrid(const T& grid, + const T& max_work_group_size, + const int max_work_group_invocations, + std::vector* work_groups) { auto alignment = WorkGroupSizeAlignment::PRECISE; *work_groups = GenerateWorkGroupSizes( grid, /*min_work_group_total_size = */ 32, max_work_group_invocations, @@ -196,16 +197,16 @@ absl::Status GenerateWorkGroupSizesAlignedToGrid( AddCornerCases(grid, max_work_group_invocations, max_work_group_size, alignment, alignment, alignment, work_groups); } - return absl::OkStatus(); + return OkStatus(); } // Specializations of GenerateWorkGroupSizesAlignedToGrid for int3 and uint3 -template absl::Status GenerateWorkGroupSizesAlignedToGrid( +template Status GenerateWorkGroupSizesAlignedToGrid( const int3& grid, const int3& max_work_group_size, const int max_work_group_invocations, std::vector* work_groups); -template absl::Status GenerateWorkGroupSizesAlignedToGrid( +template Status GenerateWorkGroupSizesAlignedToGrid( const uint3& grid, const uint3& max_work_group_size, const int max_work_group_invocations, std::vector* work_groups); diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h index 75967cb04df..80915ff5c95 100644 --- 
a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h @@ -42,9 +42,10 @@ std::vector GenerateWorkGroupSizes( WorkGroupSizeAlignment y_alignment, WorkGroupSizeAlignment z_alignment); template -absl::Status GenerateWorkGroupSizesAlignedToGrid( - const T& grid, const T& max_work_group_size, - const int max_work_group_invocations, std::vector* work_groups); +Status GenerateWorkGroupSizesAlignedToGrid(const T& grid, + const T& max_work_group_size, + const int max_work_group_invocations, + std::vector* work_groups); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 3451119c71d..452f81f536d 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -70,8 +70,8 @@ class Delegate { options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default(); } - absl::Status Prepare(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params) { + Status Prepare(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { thread_id_prepare_ = std::this_thread::get_id(); // Extract TFLite delegate execution plan from the context and convert it @@ -98,10 +98,9 @@ class Delegate { std::unique_ptr builder; bool graph_is_destroyed; - absl::Status status = - InitializeOpenClApi(&graph, &builder, &graph_is_destroyed); + Status status = InitializeOpenClApi(&graph, &builder, &graph_is_destroyed); if (!status.ok()) { - TF_LITE_KERNEL_LOG(context, std::string(status.message()).c_str()); + context->ReportError(context, "%s", status.error_message().c_str()); context->ReportError(context, "Falling back to OpenGL"); // Graph need to be re-created because it is moved above. 
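A hedged sketch of the backend-selection pattern used by Delegate::Prepare above: try the OpenCL path first and, on failure, rebuild the graph (it was moved into the failed attempt) before falling back to OpenGL. All *Stub names are illustrative placeholders, not the delegate's real members.

#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

struct GraphStub {};

Status InitializeOpenClStub(GraphStub* /*graph*/) {
  return InternalError("no OpenCL device available");
}
Status InitializeOpenGlStub(GraphStub* /*graph*/) { return OkStatus(); }
Status RebuildGraphStub(GraphStub* /*graph*/) { return OkStatus(); }

Status PrepareWithFallback(GraphStub* graph) {
  Status status = InitializeOpenClStub(graph);
  if (!status.ok()) {
    // The real code reports status.error_message() through TfLiteContext
    // here before trying the OpenGL path.
    RETURN_IF_ERROR(RebuildGraphStub(graph));
    RETURN_IF_ERROR(InitializeOpenGlStub(graph));
  }
  return OkStatus();
}

}  // namespace gpu
}  // namespace tflite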
@@ -133,7 +132,7 @@ class Delegate { return builder->Build(&runner_); } - absl::Status SetInputsAndOutputs(TfLiteContext* context) { + Status SetInputsAndOutputs(TfLiteContext* context) { int i = 0; for (auto index : input_indices_) { RETURN_IF_ERROR( @@ -144,15 +143,15 @@ class Delegate { RETURN_IF_ERROR( runner_->SetOutputObject(i++, GetTensorObject(index, context))); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status Invoke(TfLiteContext* context) { + Status Invoke(TfLiteContext* context) { if (thread_id_prepare_ != std::this_thread::get_id()) { TFLITE_LOG(tflite::TFLITE_LOG_WARNING, "GpuDelegate invoke thread != prepare thread"); if (enforce_same_thread_) { - return absl::FailedPreconditionError( + return FailedPreconditionError( "GpuDelegate must run on the same thread where it was " "initialized."); } @@ -179,9 +178,9 @@ class Delegate { TfLiteDelegate* tflite_delegate() { return &delegate_; } private: - absl::Status InitializeOpenClApi(GraphFloat32* graph, - std::unique_ptr* builder, - bool* graph_is_destroyed) { + Status InitializeOpenClApi(GraphFloat32* graph, + std::unique_ptr* builder, + bool* graph_is_destroyed) { *graph_is_destroyed = false; cl::InferenceEnvironmentOptions env_options; cl::InferenceEnvironmentProperties properties; @@ -208,11 +207,11 @@ class Delegate { options, std::move(*graph), builder)); TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Initialized OpenCL-based API."); - return absl::OkStatus(); + return OkStatus(); } - absl::Status InitializeOpenGlApi(GraphFloat32* graph, - std::unique_ptr* builder) { + Status InitializeOpenGlApi(GraphFloat32* graph, + std::unique_ptr* builder) { gl::InferenceEnvironmentOptions env_options; gl::InferenceEnvironmentProperties properties; RETURN_IF_ERROR( @@ -227,7 +226,7 @@ class Delegate { enforce_same_thread_ = true; TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Initialized OpenGL-based API."); - return absl::OkStatus(); + return OkStatus(); } TfLiteDelegate delegate_ = { @@ -270,7 +269,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = gpu_delegate->Prepare(context, params); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Init: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return nullptr; } return gpu_delegate; @@ -295,7 +294,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = GetDelegate(node)->Invoke(context); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Invoke: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/lite/delegates/gpu/gl/api.cc b/tensorflow/lite/delegates/gpu/gl/api.cc index f50f3458a8f..f9adbf253c1 100644 --- a/tensorflow/lite/delegates/gpu/gl/api.cc +++ b/tensorflow/lite/delegates/gpu/gl/api.cc @@ -58,20 +58,20 @@ class InferenceContextImpl : public InferenceContext { explicit InferenceContextImpl(std::unique_ptr runtime) : runtime_(std::move(runtime)) {} - absl::Status Execute() final { + Status Execute() final { std::lock_guard lock(guard_); if (state_ != InferenceContextState::NOT_STARTED) { - return absl::FailedPreconditionError("InferenceContext is not reset"); + return FailedPreconditionError("InferenceContext is not reset"); } state_ = InferenceContextState::IN_PROGRESS; return runtime_->Execute(); } - absl::Status Reset() final { + Status Reset() final { std::lock_guard lock(guard_); // TODO(akulik): 
should Reset not return Status? state_ = InferenceContextState::NOT_STARTED; - return absl::OkStatus(); + return OkStatus(); } RuntimeStats stats() const final { return runtime_->stats(); } @@ -94,10 +94,10 @@ class InferenceContextWithBatchImpl : public InferenceContext { refs_(std::move(refs)), runtime_(std::move(runtime)) {} - absl::Status Execute() final { + Status Execute() final { std::lock_guard lock(guard_); if (state_ != InferenceContextState::NOT_STARTED) { - return absl::FailedPreconditionError("InferenceContext is not reset"); + return FailedPreconditionError("InferenceContext is not reset"); } state_ = InferenceContextState::IN_PROGRESS; @@ -112,7 +112,7 @@ class InferenceContextWithBatchImpl : public InferenceContext { if (!buffer) continue; if (buffer->bytes_size() % byte_size) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "Object ", id, " does not match expected byte size: ", byte_size)); } @@ -120,7 +120,7 @@ class InferenceContextWithBatchImpl : public InferenceContext { if (num_batches == 0) { num_batches = b; } else if (num_batches != b) { - return absl::InvalidArgumentError(absl::StrCat( + return InvalidArgumentError(absl::StrCat( "Object ", id, " size does not match expected batch size: ", b, " vs ", num_batches)); } @@ -135,7 +135,7 @@ class InferenceContextWithBatchImpl : public InferenceContext { if (buffer) { auto ref = refs_->FindBuffer(id); if (!ref) { - return absl::InvalidArgumentError( + return InvalidArgumentError( absl::StrCat("Reference to ", id, " is not found")); } RETURN_IF_ERROR(buffer->MakeView(b * byte_size, byte_size, ref)); @@ -143,14 +143,14 @@ class InferenceContextWithBatchImpl : public InferenceContext { } RETURN_IF_ERROR(runtime_->Execute()); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status Reset() final { + Status Reset() final { std::lock_guard lock(guard_); state_ = InferenceContextState::NOT_STARTED; // TODO(akulik): should Reset not return Status? - return absl::OkStatus(); + return OkStatus(); } RuntimeStats stats() const final { return runtime_->stats(); } @@ -197,8 +197,8 @@ class CompiledModelImpl explicit CompiledModelImpl(const GpuInfo& gpu_info) : gpu_info_(gpu_info) {} // Called while compiling shaders from scratch - absl::Status Add(const WorkgroupsCalculator& workgroup_calculator, - ShaderCode code) { + Status Add(const WorkgroupsCalculator& workgroup_calculator, + ShaderCode code) { // Calculate workgroup size. uint3 workgroup_size = workgroup_calculator.Calculate(code); uint3 num_workgroups = IntegralDivideRoundUp(code.workload, workgroup_size); @@ -220,13 +220,13 @@ class CompiledModelImpl num_workgroups, shader_idx, }); - return absl::OkStatus(); + return OkStatus(); } // Store full shader and compile it if necessary. 
// Returns full_shader_index - absl::Status AddFullShader(const std::string& partial_shader, - const uint3& workgroup_size, size_t* size) { + Status AddFullShader(const std::string& partial_shader, + const uint3& workgroup_size, size_t* size) { std::string shader_src = GetShaderHeader(workgroup_size) + partial_shader; auto it = shader_to_index_.find(shader_src); if (it == shader_to_index_.end()) { @@ -239,10 +239,10 @@ class CompiledModelImpl } else { *size = it->second; } - return absl::OkStatus(); + return OkStatus(); } - absl::Status NewRun( + Status NewRun( const RuntimeOptions& options, const ObjectManager* objects, CommandQueue* command_queue, std::unique_ptr* inference_context) const final { @@ -273,16 +273,15 @@ class CompiledModelImpl *inference_context = absl::make_unique(std::move(runtime)); } - return absl::OkStatus(); + return OkStatus(); } #ifndef TFLITE_GPU_BINARY_RELEASE // Called on deserialization - absl::Status OnProgram(const std::vector& parameters, - const std::vector& objects, - const uint3& workgroup_size, - const uint3& num_workgroups, - size_t partial_shader_index) final { + Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, const uint3& num_workgroups, + size_t partial_shader_index) final { for (auto& object : objects) { if (IsRef(object)) { object_sizes_[GetRef(object)] = ByteSizeOf(object); @@ -299,10 +298,10 @@ class CompiledModelImpl num_workgroups, shader_idx, }); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Serialize( + Status Serialize( std::vector* serialized_compiled_model) const final { SerializedCompiledModelBuilder builder; @@ -339,13 +338,13 @@ class CompiledModelImpl auto data = builder.Finalize(options); serialized_compiled_model->insert(serialized_compiled_model->end(), data.begin(), data.end()); - return absl::OkStatus(); + return OkStatus(); } - absl::Status OnShader(absl::Span shader_src) final { + Status OnShader(absl::Span shader_src) final { std::string source(shader_src.data(), shader_src.size()); partial_shaders_.push_back(source); - return absl::OkStatus(); + return OkStatus(); } void OnOptions(const CompiledModelOptions& options) final { @@ -372,48 +371,45 @@ class CompiledModelImpl }; } // namespace -absl::Status Compile(const CompilationOptions& options, - const GraphFloat32& model, - const std::unordered_set& tflite_graph_io, - const NodeShader& node_shader, - const WorkgroupsCalculator& workgroup_calculator, - std::unique_ptr* compiled_model) { +Status Compile(const CompilationOptions& options, const GraphFloat32& model, + const std::unordered_set& tflite_graph_io, + const NodeShader& node_shader, + const WorkgroupsCalculator& workgroup_calculator, + std::unique_ptr* compiled_model) { if (!IsBatchMatchesForAllValues(model)) { - return absl::InvalidArgumentError( - "Only identical batch dimension is supported"); + return InvalidArgumentError("Only identical batch dimension is supported"); } GpuInfo gpu_info; RETURN_IF_ERROR(RequestGpuInfo(&gpu_info)); if (!IsOpenGl31OrAbove(gpu_info)) { - return absl::InternalError( + return InternalError( "OpenGL ES 3.1 or above is required to use OpenGL inference."); } auto compiled_model_impl = absl::make_unique(gpu_info); compiled_model_impl->set_dynamic_batch(options.dynamic_batch); auto compiler = NewCompiler(&node_shader, &gpu_info, options); - RETURN_IF_ERROR(compiler->Compile( - model, tflite_graph_io, [&](ShaderCode code) -> absl::Status { + RETURN_IF_ERROR( + compiler->Compile(model, tflite_graph_io, 
[&](ShaderCode code) -> Status { return compiled_model_impl->Add(workgroup_calculator, std::move(code)); })); *compiled_model = std::move(compiled_model_impl); - return absl::OkStatus(); + return OkStatus(); } #ifndef TFLITE_GPU_BINARY_RELEASE -absl::Status ReadSerializedModel( - const std::vector& serialized_model, - std::unique_ptr* compiled_model) { +Status ReadSerializedModel(const std::vector& serialized_model, + std::unique_ptr* compiled_model) { GpuInfo gpu_info; RETURN_IF_ERROR(RequestGpuInfo(&gpu_info)); if (!IsOpenGl31OrAbove(gpu_info)) { - return absl::InternalError( + return InternalError( "OpenGL ES 3.1 or above is required to use OpenGL inference."); } auto compiled_model_impl = absl::make_unique(gpu_info); RETURN_IF_ERROR(DeserializeCompiledModel( absl::MakeConstSpan(serialized_model), compiled_model_impl.get())); *compiled_model = std::move(compiled_model_impl); - return absl::OkStatus(); + return OkStatus(); } #endif // TFLITE_GPU_BINARY_RELEASE diff --git a/tensorflow/lite/delegates/gpu/gl/api.h b/tensorflow/lite/delegates/gpu/gl/api.h index c37eb9b7772..78b277852d0 100644 --- a/tensorflow/lite/delegates/gpu/gl/api.h +++ b/tensorflow/lite/delegates/gpu/gl/api.h @@ -51,7 +51,7 @@ class CompiledModel { // // NewRun call as well as subsequent calls to InferenceContext methods should // be done from the same EGL context. - virtual absl::Status NewRun( + virtual Status NewRun( const RuntimeOptions& options, const ObjectManager* objects, CommandQueue* command_queue, std::unique_ptr* inference_context) const = 0; @@ -59,25 +59,23 @@ class CompiledModel { #ifndef TFLITE_GPU_BINARY_RELEASE // Serializes compiled model to a string. // @return true if serialization finished successfully. - virtual absl::Status Serialize( + virtual Status Serialize( std::vector* serialized_compiled_model) const = 0; #endif // TFLITE_GPU_BINARY_RELEASE }; // Turns the given model into "compiled" form that is suitable for inference. -absl::Status Compile(const CompilationOptions& options, - const GraphFloat32& model, - const std::unordered_set& tflite_graph_io, - const NodeShader& node_shader, - const WorkgroupsCalculator& workgroup_calculator, - std::unique_ptr* compiled_model); +Status Compile(const CompilationOptions& options, const GraphFloat32& model, + const std::unordered_set& tflite_graph_io, + const NodeShader& node_shader, + const WorkgroupsCalculator& workgroup_calculator, + std::unique_ptr* compiled_model); #ifndef TFLITE_GPU_BINARY_RELEASE // Reads serialized representation previously created with // CompiledModel::Serialize call. -absl::Status ReadSerializedModel( - const std::vector& serialized_model, - std::unique_ptr* compiled_model); +Status ReadSerializedModel(const std::vector& serialized_model, + std::unique_ptr* compiled_model); #endif // TFLITE_GPU_BINARY_RELEASE // Encapsulates everything needed for one or more inference executions done @@ -91,13 +89,13 @@ class InferenceContext { virtual RuntimeStats stats() const = 0; // Executes inference. - virtual absl::Status Execute() = 0; + virtual Status Execute() = 0; // Asks context to reset it for another round. Keep in mind that does not // affect inputs nor outputs which are not cleared, so it is possible to // re-use them. // It is an error to call Reset while previous run is still in progress. 
- virtual absl::Status Reset() = 0; + virtual Status Reset() = 0; }; } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/api2.cc b/tensorflow/lite/delegates/gpu/gl/api2.cc index 64e301338e1..68bfa42411f 100644 --- a/tensorflow/lite/delegates/gpu/gl/api2.cc +++ b/tensorflow/lite/delegates/gpu/gl/api2.cc @@ -50,16 +50,16 @@ std::string GetShaderHeader(uint3 localsize) { } // Wraps given SSBO into GlBuffer object that does not have ownership. -absl::Status WrapSSBO(OpenGlBuffer ssbo, GlBuffer* buffer) { +Status WrapSSBO(OpenGlBuffer ssbo, GlBuffer* buffer) { int64_t size_bytes; RETURN_IF_ERROR(GetSSBOSize(ssbo.id, &size_bytes)); *buffer = GlBuffer(GL_SHADER_STORAGE_BUFFER, ssbo.id, size_bytes, 0, false); - return absl::OkStatus(); + return OkStatus(); } -absl::Status MaybeAllocateGlBuffer(const TensorObjectDef& def, GlBuffer* ssbo) { +Status MaybeAllocateGlBuffer(const TensorObjectDef& def, GlBuffer* ssbo) { if (def.object_def.object_type != gpu::ObjectType::OPENGL_SSBO) { - return absl::InvalidArgumentError("Tensor object is not GL SSBO"); + return InvalidArgumentError("Tensor object is not GL SSBO"); } const uint32_t num_elements = NumElements(def); switch (def.object_def.data_type) { @@ -68,10 +68,10 @@ absl::Status MaybeAllocateGlBuffer(const TensorObjectDef& def, GlBuffer* ssbo) { case DataType::FLOAT16: return CreateReadWriteShaderStorageBuffer(num_elements, ssbo); default: - return absl::InternalError( + return InternalError( "Unable to create new GL SSBO. Unsupported data type."); } - return absl::OkStatus(); + return OkStatus(); } // Does one-step conversion between internal and external objects. @@ -89,59 +89,58 @@ class DefaultTensorTie : public TensorTie { converter_builder.IsSupported(def.external_def, def.internal_def); } - static absl::Status New(const TensorTieDef& def, - TensorObjectConverterBuilder* converter_builder, - ObjectManager* objects, - std::unique_ptr* tie) { + static Status New(const TensorTieDef& def, + TensorObjectConverterBuilder* converter_builder, + ObjectManager* objects, std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def, TensorObject{}, objects); RETURN_IF_ERROR(tie_impl->Init(converter_builder)); *tie = std::move(tie_impl); - return absl::OkStatus(); + return OkStatus(); } - static absl::Status New(const TensorTieDef& def, - TensorObjectConverterBuilder* converter_builder, - TensorObject internal_object, - std::unique_ptr* tie) { + static Status New(const TensorTieDef& def, + TensorObjectConverterBuilder* converter_builder, + TensorObject internal_object, + std::unique_ptr* tie) { if (!IsValid(def.internal_def, internal_object)) { - return absl::InternalError("Internal object does not match definition."); + return InternalError("Internal object does not match definition."); } auto tie_impl = absl::make_unique(def, internal_object, nullptr); RETURN_IF_ERROR(tie_impl->Init(converter_builder)); *tie = std::move(tie_impl); - return absl::OkStatus(); + return OkStatus(); } - absl::Status CopyToExternalObject() final { + Status CopyToExternalObject() final { if (!converter_to_) { - return absl::OkStatus(); + return OkStatus(); } return converter_to_->Convert(internal_obj_, GetExternalObject()); } - absl::Status CopyFromExternalObject() final { + Status CopyFromExternalObject() final { if (!converter_from_) { - return absl::OkStatus(); + return OkStatus(); } return converter_from_->Convert(GetExternalObject(), internal_obj_); } - absl::Status SetExternalObject(TensorObject obj) final { + Status SetExternalObject(TensorObject obj) final { 
if (!def().external_def.object_def.user_provided) { - return absl::InvalidArgumentError("External object is read-only"); + return InvalidArgumentError("External object is read-only"); } if (!IsValid(def().external_def, obj)) { - return absl::InvalidArgumentError("Given object is not valid"); + return InvalidArgumentError("Given object is not valid"); } // TODO(akulik): external object should propagate to internal. if (IsSameDef()) { - return absl::UnimplementedError("Not supported"); + return UnimplementedError("Not supported"); } external_obj_ = obj; - return absl::OkStatus(); + return OkStatus(); } TensorObject GetExternalObject() final { return external_obj_; } @@ -160,8 +159,7 @@ class DefaultTensorTie : public TensorTie { internal_def.data_layout == DataLayout::DHWC4 && def().external_def.dimensions.c == 4); } - - absl::Status Init(TensorObjectConverterBuilder* converter_builder) { + Status Init(TensorObjectConverterBuilder* converter_builder) { // First check is an object is user provided. const auto& external_def = def().external_def.object_def; @@ -176,7 +174,7 @@ class DefaultTensorTie : public TensorTie { if (external_def.user_provided) { if (is_same_def) { - return absl::OkStatus(); + return OkStatus(); } // Object is provided by a user, but runtime expects different object // type. Therefore, we have to allocate internal object and convert. @@ -188,19 +186,19 @@ class DefaultTensorTie : public TensorTie { // Object is NOT provided by a user, but it matches definition expected // by runtime. Conversion is not needed. external_obj_ = internal_obj_; - return absl::OkStatus(); + return OkStatus(); } // Object is NOT provided by a user. return MaybeAllocateExternalObject(); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status MaybeAllocateInternalObject() { + Status MaybeAllocateInternalObject() { const TensorObjectDef& d = def().internal_def; if (d.object_def.user_provided) { - return absl::OkStatus(); + return OkStatus(); } switch (d.object_def.object_type) { case gpu::ObjectType::OPENGL_SSBO: { @@ -212,12 +210,12 @@ class DefaultTensorTie : public TensorTie { } // TODO(akulik): support textures as internal object when compiler permits default: - return absl::InternalError("Unexpected object type"); + return InternalError("Unexpected object type"); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status MaybeAllocateExternalObject() { + Status MaybeAllocateExternalObject() { const TensorObjectDef& d = def().external_def; switch (d.object_def.object_type) { case gpu::ObjectType::CPU_MEMORY: { @@ -234,9 +232,9 @@ class DefaultTensorTie : public TensorTie { break; } default: - return absl::InternalError("Unexpected object type"); + return InternalError("Unexpected object type"); } - return absl::OkStatus(); + return OkStatus(); } ObjectManager* objects_; @@ -268,27 +266,26 @@ class TwoStepTensorTie : public TensorTie { DefaultTensorTie::IsSupported(defs.second, converter_builder); } - static absl::Status New(const TensorTieDef& def, - TensorObjectConverterBuilder* converter_builder, - ObjectManager* objects, - std::unique_ptr* tie) { + static Status New(const TensorTieDef& def, + TensorObjectConverterBuilder* converter_builder, + ObjectManager* objects, std::unique_ptr* tie) { auto tie_impl = absl::make_unique(def); RETURN_IF_ERROR(tie_impl->Init(converter_builder, objects)); *tie = std::move(tie_impl); - return absl::OkStatus(); + return OkStatus(); } - absl::Status CopyToExternalObject() final { + Status CopyToExternalObject() final { 
RETURN_IF_ERROR(inner_tie_->CopyToExternalObject()); return outer_tie_->CopyToExternalObject(); } - absl::Status CopyFromExternalObject() final { + Status CopyFromExternalObject() final { RETURN_IF_ERROR(outer_tie_->CopyFromExternalObject()); return inner_tie_->CopyFromExternalObject(); } - absl::Status SetExternalObject(TensorObject obj) final { + Status SetExternalObject(TensorObject obj) final { return outer_tie_->SetExternalObject(obj); } @@ -324,8 +321,8 @@ class TwoStepTensorTie : public TensorTie { return std::make_pair(outer_def, inner_def); } - absl::Status Init(TensorObjectConverterBuilder* converter_builder, - ObjectManager* objects) { + Status Init(TensorObjectConverterBuilder* converter_builder, + ObjectManager* objects) { auto defs = MakeOuterInnerDefs(def()); RETURN_IF_ERROR(DefaultTensorTie::New(defs.second, converter_builder, objects, &inner_tie_)); @@ -349,8 +346,8 @@ class TensorTieFactory { TwoStepTensorTie::IsSupported(def, *converter_builder_)); } - absl::Status NewTensorTie(const TensorTieDef& def, ObjectManager* objects, - std::unique_ptr* tie) { + Status NewTensorTie(const TensorTieDef& def, ObjectManager* objects, + std::unique_ptr* tie) { auto converter = converter_builder_.get(); if (DefaultTensorTie::IsSupported(def, *converter)) { return DefaultTensorTie::New(def, converter, objects, tie); @@ -358,7 +355,7 @@ class TensorTieFactory { if (TwoStepTensorTie::IsSupported(def, *converter)) { return TwoStepTensorTie::New(def, converter, objects, tie); } - return absl::UnimplementedError("Unsupported tensor tie definition."); + return UnimplementedError("Unsupported tensor tie definition."); } private: @@ -371,16 +368,16 @@ class InferenceRunnerImpl : public InferenceRunner { std::unique_ptr objects) : runtime_(std::move(runtime)), objects_(std::move(objects)) {} - absl::Status Initialize(const std::vector& inputs, - const std::vector& outputs, - TensorTieFactory* tie_factory) { + Status Initialize(const std::vector& inputs, + const std::vector& outputs, + TensorTieFactory* tie_factory) { RETURN_IF_ERROR(LinkTensors(inputs, tie_factory, &inputs_)); RETURN_IF_ERROR(LinkTensors(outputs, tie_factory, &outputs_)); for (const auto& def : outputs) { output_to_cpu_ |= def.external_def.object_def.object_type == gpu::ObjectType::CPU_MEMORY; } - return absl::OkStatus(); + return OkStatus(); } std::vector inputs() const override { @@ -391,37 +388,37 @@ class InferenceRunnerImpl : public InferenceRunner { return GetExternalDefinitions(outputs_); } - absl::Status GetInputObject(int index, TensorObject* object) override { + Status GetInputObject(int index, TensorObject* object) override { if (index < 0 || index >= inputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } *object = inputs_[index]->GetExternalObject(); - return absl::OkStatus(); + return OkStatus(); } - absl::Status GetOutputObject(int index, TensorObject* object) override { + Status GetOutputObject(int index, TensorObject* object) override { if (index < 0 || index >= outputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } *object = outputs_[index]->GetExternalObject(); - return absl::OkStatus(); + return OkStatus(); } - absl::Status SetInputObject(int index, TensorObject object) override { + Status SetInputObject(int index, TensorObject object) override { if (index < 0 || index >= inputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return 
OutOfRangeError("Index is out of range"); } return inputs_[index]->SetExternalObject(object); } - absl::Status SetOutputObject(int index, TensorObject object) override { + Status SetOutputObject(int index, TensorObject object) override { if (index < 0 || index >= outputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } return outputs_[index]->SetExternalObject(object); } - absl::Status Run() override { + Status Run() override { for (auto& obj : inputs_) { RETURN_IF_ERROR(obj->CopyFromExternalObject()); } @@ -433,20 +430,20 @@ class InferenceRunnerImpl : public InferenceRunner { if (output_to_cpu_) { RETURN_IF_ERROR(runtime_->command_queue()->WaitForCompletion()); } - return absl::OkStatus(); + return OkStatus(); } private: - absl::Status LinkTensors(const std::vector& defs, - TensorTieFactory* tie_factory, - std::vector>* objects) { + Status LinkTensors(const std::vector& defs, + TensorTieFactory* tie_factory, + std::vector>* objects) { objects->reserve(defs.size()); for (auto& def : defs) { std::unique_ptr object; RETURN_IF_ERROR(tie_factory->NewTensorTie(def, objects_.get(), &object)); objects->push_back(std::move(object)); } - return absl::OkStatus(); + return OkStatus(); } static std::vector GetExternalDefinitions( @@ -477,10 +474,10 @@ class InferenceBuilderImpl : public InferenceBuilder { gpu_info_(gpu_info), tie_factory_(env_options_) {} - absl::Status Initialize() { + Status Initialize() { inputs_ = LinkTensors(graph_.inputs()); outputs_ = LinkTensors(graph_.outputs()); - return absl::OkStatus(); + return OkStatus(); } std::vector inputs() const final { @@ -491,42 +488,40 @@ class InferenceBuilderImpl : public InferenceBuilder { return GetExternalDefinitions(outputs_); } - absl::Status SetInputShape(int index, const Dimensions& dimensions) final { + Status SetInputShape(int index, const Dimensions& dimensions) final { if (index < 0 || index >= inputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } - return absl::UnimplementedError("Changing input shapes is not supported"); + return UnimplementedError("Changing input shapes is not supported"); } - absl::Status SetInputObjectDef(int index, ObjectDef new_def) final { + Status SetInputObjectDef(int index, ObjectDef new_def) final { if (index < 0 || index >= inputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } auto def = inputs_[index]; def.external_def.object_def = new_def; if (!tie_factory_.IsSupported(def)) { - return absl::InvalidArgumentError( - "New object definition is not supported."); + return InvalidArgumentError("New object definition is not supported."); } inputs_[index] = def; - return absl::OkStatus(); + return OkStatus(); } - absl::Status SetOutputObjectDef(int index, ObjectDef new_def) final { + Status SetOutputObjectDef(int index, ObjectDef new_def) final { if (index < 0 || index >= outputs_.size()) { - return absl::OutOfRangeError("Index is out of range"); + return OutOfRangeError("Index is out of range"); } auto def = outputs_[index]; def.external_def.object_def = new_def; if (!tie_factory_.IsSupported(def)) { - return absl::InvalidArgumentError( - "New object definition is not supported."); + return InvalidArgumentError("New object definition is not supported."); } outputs_[index] = def; - return absl::OkStatus(); + return OkStatus(); } - absl::Status Build(std::unique_ptr* runner) final { + Status 
Build(std::unique_ptr* runner) final { auto kernels = NewNodeShaderRegistry(); CompilationOptions compiler_options; compiler_options.allow_precision_loss = @@ -556,7 +551,7 @@ class InferenceBuilderImpl : public InferenceBuilder { std::move(runtime), std::move(external_objects)); RETURN_IF_ERROR(runner_impl->Initialize(inputs_, outputs_, &tie_factory_)); RETURN_IF_ERROR( - compiler->Compile(graph_, {}, [&](ShaderCode code) -> absl::Status { + compiler->Compile(graph_, {}, [&](ShaderCode code) -> Status { auto workgroup = workgroup_calculator->Calculate(code); size_t shader_index; std::string shader_src = @@ -579,7 +574,7 @@ class InferenceBuilderImpl : public InferenceBuilder { })); RETURN_IF_ERROR(runtime_ptr->PrepareForExecution()); *runner = std::move(runner_impl); - return absl::OkStatus(); + return OkStatus(); } private: @@ -629,39 +624,39 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { explicit InferenceEnvironmentImpl(const InferenceEnvironmentOptions& options) : env_options_(options) {} - absl::Status Init() { + Status Init() { RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&egl_env_)); RETURN_IF_ERROR(RequestGpuInfo(&gpu_info_)); properties_.is_opengl_available = IsOpenGl31OrAbove(gpu_info_); if (!properties_.is_opengl_available) { - return absl::InternalError( + return InternalError( "OpenGL ES 3.1 or above is required to use OpenGL inference."); } if (!env_options_.queue) { queue_ = NewCommandQueue(gpu_info_); env_options_.queue = queue_.get(); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status NewInferenceBuilder( - GraphFloat32&& model, const InferenceOptions& options, - std::unique_ptr* builder) final { + Status NewInferenceBuilder(GraphFloat32&& model, + const InferenceOptions& options, + std::unique_ptr* builder) final { if (!IsValid(options)) { - return absl::InvalidArgumentError("InferenceOptions are invalid."); + return InvalidArgumentError("InferenceOptions are invalid."); } InferenceOptions resolved_options = options; ResolveAutoPriority(&resolved_options); if (!IsBatchMatchesForAllValues(model)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Only identical batch dimension is supported"); } auto builder_impl = absl::make_unique( env_options_, resolved_options, std::move(model), &gpu_info_); RETURN_IF_ERROR(builder_impl->Initialize()); *builder = std::move(builder_impl); - return absl::OkStatus(); + return OkStatus(); } const InferenceEnvironmentProperties& properties() const { @@ -678,18 +673,18 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { } // namespace -absl::Status NewInferenceEnvironment( +Status NewInferenceEnvironment( const InferenceEnvironmentOptions& options, std::unique_ptr* environment, InferenceEnvironmentProperties* properties) { auto env_impl = absl::make_unique(options); - absl::Status status = env_impl->Init(); + Status status = env_impl->Init(); if (properties) { *properties = env_impl->properties(); } RETURN_IF_ERROR(status); *environment = std::move(env_impl); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/api2.h b/tensorflow/lite/delegates/gpu/gl/api2.h index 05062064dd6..ac58fef0ffa 100644 --- a/tensorflow/lite/delegates/gpu/gl/api2.h +++ b/tensorflow/lite/delegates/gpu/gl/api2.h @@ -41,7 +41,7 @@ class InferenceEnvironment { public: virtual ~InferenceEnvironment() = default; - virtual absl::Status NewInferenceBuilder( + virtual Status NewInferenceBuilder( GraphFloat32&& model, const InferenceOptions& 
options, std::unique_ptr* builder) = 0; }; @@ -52,7 +52,7 @@ struct InferenceEnvironmentOptions { // Creates a new OpenGL environment that needs to stay around until all // inference runners are destroyed. -absl::Status NewInferenceEnvironment( +Status NewInferenceEnvironment( const InferenceEnvironmentOptions& options, std::unique_ptr* environment, InferenceEnvironmentProperties* properties /* optional */); diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.cc b/tensorflow/lite/delegates/gpu/gl/command_queue.cc index 8500a50859c..87823761127 100644 --- a/tensorflow/lite/delegates/gpu/gl/command_queue.cc +++ b/tensorflow/lite/delegates/gpu/gl/command_queue.cc @@ -30,18 +30,17 @@ namespace { class DefaultCommandQueue : public CommandQueue { public: - absl::Status Dispatch(const GlProgram& program, - const uint3& workgroups) override { + Status Dispatch(const GlProgram& program, const uint3& workgroups) override { RETURN_IF_ERROR(program.Dispatch(workgroups)); return TFLITE_GPU_CALL_GL(glMemoryBarrier, GL_ALL_BARRIER_BITS); } - absl::Status WaitForCompletion() override { + Status WaitForCompletion() override { // TODO(akulik): Maybe let the user choose which wait method to use. return GlActiveSyncWait(); } - absl::Status Flush() override { return absl::OkStatus(); } + Status Flush() override { return OkStatus(); } }; // On Adreno do flush periodically as this affects performance. Command queue @@ -55,27 +54,26 @@ class AdrenoCommandQueue : public DefaultCommandQueue { explicit AdrenoCommandQueue(int flush_every_n) : flush_every_n_(flush_every_n) {} - absl::Status Dispatch(const GlProgram& program, - const uint3& workgroups) final { + Status Dispatch(const GlProgram& program, const uint3& workgroups) final { RETURN_IF_ERROR(DefaultCommandQueue::Dispatch(program, workgroups)); if ((++program_counter_ % flush_every_n_) == 0) { glFlush(); } - return absl::OkStatus(); + return OkStatus(); } - absl::Status WaitForCompletion() override { + Status WaitForCompletion() override { program_counter_ = 0; return DefaultCommandQueue::WaitForCompletion(); } - absl::Status Flush() final { + Status Flush() final { // Flush exactly once after the last dispatch. if (program_counter_ != 0) { program_counter_ = 0; glFlush(); } - return absl::OkStatus(); + return OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.h b/tensorflow/lite/delegates/gpu/gl/command_queue.h index d9bff04a837..6695852fc86 100644 --- a/tensorflow/lite/delegates/gpu/gl/command_queue.h +++ b/tensorflow/lite/delegates/gpu/gl/command_queue.h @@ -35,14 +35,14 @@ class CommandQueue { virtual ~CommandQueue() = default; // Dispatches a program. It may or may not call glFlush. - virtual absl::Status Dispatch(const GlProgram& program, - const uint3& workgroups) = 0; + virtual Status Dispatch(const GlProgram& program, + const uint3& workgroups) = 0; // Called at the end of dispatching of all programs. - virtual absl::Status Flush() = 0; + virtual Status Flush() = 0; // Waits until all programs dispatched prior this call are completed. - virtual absl::Status WaitForCompletion() = 0; + virtual Status WaitForCompletion() = 0; }; // By default memory barrier is inserted after every dispatch. 
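The CommandQueue comments above imply a fixed calling protocol: one Dispatch per program, a single Flush after the last dispatch, and WaitForCompletion before any results are read back. A minimal caller-side sketch of that protocol against the Status-returning interface; the RunPrograms helper and its program/workgroup vectors are hypothetical names used only for illustration:

Status RunPrograms(CommandQueue* queue,
                   const std::vector<GlProgram>& programs,
                   const std::vector<uint3>& workgroups) {
  // One Dispatch per program; the queue may or may not flush internally.
  for (size_t i = 0; i < programs.size(); ++i) {
    RETURN_IF_ERROR(queue->Dispatch(programs[i], workgroups[i]));
  }
  // Flush exactly once after the last dispatch...
  RETURN_IF_ERROR(queue->Flush());
  // ...and block until all dispatched work completes before reading results.
  return queue->WaitForCompletion();
}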
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.cc b/tensorflow/lite/delegates/gpu/gl/compiler.cc index a5f5b35f2d2..cef8139fe1e 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler.cc @@ -102,9 +102,9 @@ class CompilerImpl : public Compiler { } } - absl::Status Compile(const GraphFloat32& graph, - const std::unordered_set& tflite_graph_io, - const ShaderCodeCallback& callback) final { + Status Compile(const GraphFloat32& graph, + const std::unordered_set& tflite_graph_io, + const ShaderCodeCallback& callback) final { // It is important to have ids in a compiled graph identical to the given // graph. RETURN_IF_ERROR(graph.MakeExactCopy(&compiled_graph_)); @@ -129,22 +129,22 @@ class CompilerImpl : public Compiler { if (options_.fuse_operations) { FuseAutoOutputWithInline fuse_inline; if (!transformer.Apply("fuse_auto_with_inline", &fuse_inline)) { - return absl::InternalError("fuse_auto_with_inline failed"); + return InternalError("fuse_auto_with_inline failed"); } FuseInplaceUpdate fuse_inplace; if (!transformer.Apply("fuse_inplace_update", &fuse_inplace)) { - return absl::InternalError("fuse_inplace failed"); + return InternalError("fuse_inplace failed"); } if (options_.auto_input_fusion) { FuseAutoInput fuse_auto_input; if (!transformer.Apply("fuse_auto_input", &fuse_auto_input)) { - return absl::InternalError("fuse_auto_input failed"); + return InternalError("fuse_auto_input failed"); } } } RemoveUnusedInplaceUpdates remove_inplace_updates; if (!transformer.Apply("remove_inplace_updates", &remove_inplace_updates)) { - return absl::InternalError("remove_inplace_updates failed"); + return InternalError("remove_inplace_updates failed"); } // Prepare internal objects. @@ -176,7 +176,7 @@ class CompilerImpl : public Compiler { auto shape = outputs[0]->tensor.shape; for (auto output : outputs) { if (shape != output->tensor.shape) { - return absl::FailedPreconditionError( + return FailedPreconditionError( "Workload uint3() requires all output sizes to match"); } } @@ -274,7 +274,7 @@ class CompilerImpl : public Compiler { RETURN_IF_ERROR(codegen.Build(std::move(attr), &shader_code)); RETURN_IF_ERROR(callback(std::move(shader_code))); } - return absl::OkStatus(); + return OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.h b/tensorflow/lite/delegates/gpu/gl/compiler.h index 7769890b769..e8b434869e2 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler.h @@ -31,7 +31,7 @@ namespace tflite { namespace gpu { namespace gl { -using ShaderCodeCallback = std::function; +using ShaderCodeCallback = std::function; class Compiler { public: @@ -40,9 +40,9 @@ class Compiler { // Goes over a graph and generates OpenGL shaders for the given graph. // Callback is called for every generated shader. Callback may execute shaders // as they come or store them elsewhere to execute later. 
- virtual absl::Status Compile(const GraphFloat32& graph, - const std::unordered_set& tflite_graph_io, - const ShaderCodeCallback& callback) = 0; + virtual Status Compile(const GraphFloat32& graph, + const std::unordered_set& tflite_graph_io, + const ShaderCodeCallback& callback) = 0; }; std::unique_ptr NewCompiler( diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc index 4048a07d087..923b0bd47ec 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc @@ -25,8 +25,8 @@ namespace tflite { namespace gpu { namespace gl { -absl::Status MergeCode(CompiledNodeAttributes* attr, - CompiledNodeAttributes* merged_attr) { +Status MergeCode(CompiledNodeAttributes* attr, + CompiledNodeAttributes* merged_attr) { // build a map of known names. std::unordered_set known_names; for (const auto& parameter : merged_attr->code.parameters) { @@ -56,7 +56,7 @@ absl::Status MergeCode(CompiledNodeAttributes* attr, std::back_inserter(merged_attr->code.parameters)); std::move(attr->node_indices.begin(), attr->node_indices.end(), std::back_inserter(merged_attr->node_indices)); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h index 8d36504d0c3..d41a734f4e2 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h @@ -42,8 +42,8 @@ struct CompiledNodeAttributes { // Moves all code objects, parameters and node indices from attr to merged_attr. // Parameters and objects in attr.code.source_code are renamed to ensure // uniqueness. 
-absl::Status MergeCode(CompiledNodeAttributes* attr, - CompiledNodeAttributes* merged_attr); +Status MergeCode(CompiledNodeAttributes* attr, + CompiledNodeAttributes* merged_attr); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc index 55e6d94eb7d..01ea764b0b0 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc @@ -46,8 +46,8 @@ absl::string_view PastSubstr(absl::string_view s, absl::string_view subs) { } // namespace -absl::Status TextPreprocessor::Rewrite(const std::string& input, - std::string* output) { +Status TextPreprocessor::Rewrite(const std::string& input, + std::string* output) { absl::string_view s = input; std::string result; while (true) { @@ -57,7 +57,7 @@ absl::Status TextPreprocessor::Rewrite(const std::string& input, break; } if (inline_block.size() == 1) { - return absl::NotFoundError("Unable to find end of inline block"); + return NotFoundError("Unable to find end of inline block"); } s = PastSubstr(s, inline_block); bool processed = false; @@ -74,20 +74,20 @@ absl::Status TextPreprocessor::Rewrite(const std::string& input, processed = true; break; case RewriteStatus::ERROR: - return absl::InternalError(absl::StrCat("Error while rewriting '", - inline_block, "': ", result)); + return InternalError(absl::StrCat("Error while rewriting '", + inline_block, "': ", result)); } } if (!processed) { if (!keep_unknown_rewrites_) { - return absl::NotFoundError(absl::StrCat( - "Didn't find inline rewrite for '", inline_block, "'")); + return NotFoundError(absl::StrCat("Didn't find inline rewrite for '", + inline_block, "'")); } absl::StrAppend(&result, inline_block); } } *output = std::move(result); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h index 29fad004d3c..f01698e784f 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h @@ -58,7 +58,7 @@ class TextPreprocessor { } // input and output may point to the same object. 
- absl::Status Rewrite(const std::string& input, std::string* output); + Status Rewrite(const std::string& input, std::string* output); private: const char inline_delimiter_; diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc index 956f6afae28..674002b74b2 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc @@ -174,17 +174,17 @@ class ObjectRewriter : public InlineRewrite { } // namespace -absl::Status Rename(const NameFunctor& name_func, GeneratedCode* code) { +Status Rename(const NameFunctor& name_func, GeneratedCode* code) { VariableRewriter variable_rewriter("$", name_func); ObjectRewriter object_rewriter("$", name_func); for (auto&& uniform_parameter : code->parameters) { if (!variable_rewriter.AddVariable(std::move(uniform_parameter))) { - return absl::InternalError("Variable name already exists"); + return InternalError("Variable name already exists"); } } for (auto&& object : code->objects) { if (!object_rewriter.AddObject(object.first, std::move(object.second))) { - return absl::InternalError("Object name already exists"); + return InternalError("Object name already exists"); } } TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true); @@ -195,7 +195,7 @@ absl::Status Rename(const NameFunctor& name_func, GeneratedCode* code) { code->source_code = source_code; code->parameters = variable_rewriter.GetUniformParameters(); code->objects = object_rewriter.GetObjects(); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.h b/tensorflow/lite/delegates/gpu/gl/compiler/rename.h index e38ade1a3b9..06921dbe3da 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/rename.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.h @@ -32,7 +32,7 @@ using NameFunctor = std::function; // Rewrites source code, objects and parameters with the new names supplied // by the given functor. 
-absl::Status Rename(const NameFunctor& name_func, GeneratedCode* code); +Status Rename(const NameFunctor& name_func, GeneratedCode* code); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc index e473f9e77ff..e6100919097 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc @@ -32,8 +32,8 @@ ShaderCodegen::ShaderCodegen(const CompilationOptions& options, const GpuInfo& gpu_info) : options_(options), gpu_type_(gpu_info.type) {} -absl::Status ShaderCodegen::Build(CompiledNodeAttributes attr, - ShaderCode* shader_code) const { +Status ShaderCodegen::Build(CompiledNodeAttributes attr, + ShaderCode* shader_code) const { VariableAccessor variable_accessor(options_.inline_parameters, options_.vulkan_support); ObjectAccessor object_accessor(gpu_type_ == GpuType::MALI, @@ -41,18 +41,18 @@ absl::Status ShaderCodegen::Build(CompiledNodeAttributes attr, const auto add_object = [&](const std::string& name, Object&& object) { if (!object_accessor.AddObject(name, std::forward(object))) { - return absl::AlreadyExistsError(absl::StrCat("Object \"", name, "\"")); + return AlreadyExistsError(absl::StrCat("Object \"", name, "\"")); } - return absl::OkStatus(); + return OkStatus(); }; const auto add_uniform_parameter = [&](Variable&& variable) { const std::string name = variable.name; if (!variable_accessor.AddUniformParameter(std::move(variable))) { - return absl::AlreadyExistsError( + return AlreadyExistsError( absl::StrCat("Uniform parameter \"", name, "\"")); } - return absl::OkStatus(); + return OkStatus(); }; for (auto&& object : attr.code.objects) { @@ -62,8 +62,7 @@ absl::Status ShaderCodegen::Build(CompiledNodeAttributes attr, for (auto&& variable : attr.code.shared_variables) { const std::string name = variable.name; if (!variable_accessor.AddSharedVariable(std::move(variable))) { - return absl::AlreadyExistsError( - absl::StrCat("Shared variable \"", name, "\"")); + return AlreadyExistsError(absl::StrCat("Shared variable \"", name, "\"")); } } @@ -170,7 +169,7 @@ absl::Status ShaderCodegen::Build(CompiledNodeAttributes attr, ShaderCode(variable_accessor.GetUniformParameters(), object_accessor.GetObjects(), attr.code.workload, attr.code.workgroup, partial_source_code, attr.node_indices); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h index 12d2708d221..c4f09a3b6b9 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h @@ -39,8 +39,7 @@ class ShaderCodegen { ShaderCodegen(const CompilationOptions& options, const GpuInfo& gpu_info); // Builds final program representation. 
- absl::Status Build(CompiledNodeAttributes attr, - ShaderCode* shader_code) const; + Status Build(CompiledNodeAttributes attr, ShaderCode* shader_code) const; private: const CompilationOptions options_; diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc index fc86b0f3cb1..3b37ba26058 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc @@ -31,7 +31,7 @@ namespace tflite { namespace gpu { namespace gl { -absl::Status ConverterBhwcToPhwc4::Create(ConverterBhwcToPhwc4* converter) { +Status ConverterBhwcToPhwc4::Create(ConverterBhwcToPhwc4* converter) { uint3 workgroup_size = uint3(4, 4, 4); std::string shader_source = GetShaderHeader(workgroup_size) + R"( layout(std430) buffer; @@ -69,24 +69,22 @@ absl::Status ConverterBhwcToPhwc4::Create(ConverterBhwcToPhwc4* converter) { GlProgram program; RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); *converter = ConverterBhwcToPhwc4(std::move(program), workgroup_size); - return absl::OkStatus(); + return OkStatus(); } -absl::Status ConverterBhwcToPhwc4::Convert(const BHWC& shape, - const GlBuffer& source, - CommandQueue* command_queue, - GlBuffer* destination) { +Status ConverterBhwcToPhwc4::Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue, + GlBuffer* destination) { if (source.bytes_size() < BytesForBHWC(shape)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "BhwcToPhwc4: Input data size does not match expected size."); } if (destination->bytes_size() < BytesForPHWC4(shape)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "BhwcToPhwc4: output data size does not match expected size."); } if (shape.b != 1) { - return absl::UnimplementedError( - "BhwcToPhwc4: Batch size is not equal to 1."); + return UnimplementedError("BhwcToPhwc4: Batch size is not equal to 1."); } uint3 workload = uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)); uint3 num_workgroups = IntegralDivideRoundUp(workload, workgroup_size_); diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h index 9f699433a50..9d9e6402ffa 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h @@ -32,11 +32,11 @@ class ConverterBhwcToPhwc4 { // Creates invalid object. 
ConverterBhwcToPhwc4() : program_(), workgroup_size_() {} - static absl::Status Create(ConverterBhwcToPhwc4* converter); + static Status Create(ConverterBhwcToPhwc4* converter); - absl::Status Convert(const BHWC& shape, const GlBuffer& source, - CommandQueue* command_queue /* optional */, - GlBuffer* destination); + Status Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue /* optional */, + GlBuffer* destination); private: explicit ConverterBhwcToPhwc4(GlProgram program, const uint3& workgroup_size) diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc index 73ab9f67d94..6fc424047a1 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc @@ -41,7 +41,7 @@ inline std::vector GenerateFloats(float multiplier, int size) { return v; } -absl::Status RunTest(const BHWC& shape) { +Status RunTest(const BHWC& shape) { // Create random input and calculate expected output for it. std::vector input = GenerateFloats(0.01, shape.DimensionsProduct()); std::vector output(GetElementsSizeForPHWC4(shape), 0); @@ -71,9 +71,9 @@ absl::Status RunTest(const BHWC& shape) { RETURN_IF_ERROR(output_buffer.Read( absl::MakeSpan(converted_output.data(), converted_output.size()))); if (output != converted_output) { - return absl::InternalError("Outputs don't match"); + return InternalError("Outputs don't match"); } - return absl::OkStatus(); + return OkStatus(); } TEST(HwcToPhwc4, Smoke) { diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc index 5a9f51c0425..c63fee9f8bd 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc @@ -31,7 +31,7 @@ namespace tflite { namespace gpu { namespace gl { -absl::Status ConverterPhwc4ToBhwc::Create(ConverterPhwc4ToBhwc* converter) { +Status ConverterPhwc4ToBhwc::Create(ConverterPhwc4ToBhwc* converter) { uint3 workgroup_size = uint3(4, 4, 4); std::string shader_source = GetShaderHeader(workgroup_size) + R"( layout(std430) buffer; @@ -62,24 +62,22 @@ absl::Status ConverterPhwc4ToBhwc::Create(ConverterPhwc4ToBhwc* converter) { GlProgram program; RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); *converter = ConverterPhwc4ToBhwc(std::move(program), workgroup_size); - return absl::OkStatus(); + return OkStatus(); } -absl::Status ConverterPhwc4ToBhwc::Convert(const BHWC& shape, - const GlBuffer& source, - CommandQueue* command_queue, - GlBuffer* destination) { +Status ConverterPhwc4ToBhwc::Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue, + GlBuffer* destination) { if (source.bytes_size() < BytesForPHWC4(shape)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Phwc4ToBhwc: Input data size does not match expected size."); } if (destination->bytes_size() < BytesForBHWC(shape)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Phwc4ToBhwc: output data size does not match expected size."); } if (shape.b != 1) { - return absl::UnimplementedError( - "Phwc4ToBhwc: Batch size is not equal to 1."); + return UnimplementedError("Phwc4ToBhwc: Batch size is not equal to 1."); } uint3 workload = uint3(shape.w, shape.h, shape.c); diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h 
b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h index d9a4dd34ee8..c8b181223ae 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h @@ -32,11 +32,11 @@ class ConverterPhwc4ToBhwc { // Creates invalid object. ConverterPhwc4ToBhwc() : program_(), workgroup_size_() {} - static absl::Status Create(ConverterPhwc4ToBhwc* converter); + static Status Create(ConverterPhwc4ToBhwc* converter); - absl::Status Convert(const BHWC& shape, const GlBuffer& source, - CommandQueue* command_queue /* optional */, - GlBuffer* destination); + Status Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue /* optional */, + GlBuffer* destination); private: explicit ConverterPhwc4ToBhwc(GlProgram program, const uint3& workgroup_size) diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc index 34346e3ce9d..6f969bb7801 100644 --- a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc @@ -41,7 +41,7 @@ inline std::vector GenerateFloats(float multiplier, int size) { return v; } -absl::Status RunTest(const BHWC& shape) { +Status RunTest(const BHWC& shape) { // Create random input and calculate expected output for it. std::vector input = GenerateFloats(0.01, GetElementsSizeForPHWC4(shape)); @@ -72,9 +72,9 @@ absl::Status RunTest(const BHWC& shape) { RETURN_IF_ERROR(output_buffer.Read( absl::MakeSpan(converted_output.data(), converted_output.size()))); if (output != converted_output) { - return absl::InternalError("Outputs don't match"); + return InternalError("Outputs don't match"); } - return absl::OkStatus(); + return OkStatus(); } TEST(Phwc4ToHwc, Smoke) { diff --git a/tensorflow/lite/delegates/gpu/gl/egl_context.cc b/tensorflow/lite/delegates/gpu/gl/egl_context.cc index f01bafcacff..46fbed24291 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_context.cc +++ b/tensorflow/lite/delegates/gpu/gl/egl_context.cc @@ -26,19 +26,19 @@ namespace gpu { namespace gl { namespace { -absl::Status GetConfig(EGLDisplay display, const EGLint* attributes, - EGLConfig* config) { +Status GetConfig(EGLDisplay display, const EGLint* attributes, + EGLConfig* config) { EGLint config_count; bool chosen = eglChooseConfig(display, attributes, config, 1, &config_count); RETURN_IF_ERROR(GetOpenGlErrors()); if (!chosen || config_count == 0) { - return absl::InternalError("No EGL error, but eglChooseConfig failed."); + return InternalError("No EGL error, but eglChooseConfig failed."); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateContext(EGLDisplay display, EGLContext shared_context, - EGLConfig config, EglContext* egl_context) { +Status CreateContext(EGLDisplay display, EGLContext shared_context, + EGLConfig config, EglContext* egl_context) { static const EGLint attributes[] = {EGL_CONTEXT_CLIENT_VERSION, 3, #ifdef _DEBUG // Add debugging bit EGL_CONTEXT_FLAGS_KHR, @@ -49,10 +49,10 @@ absl::Status CreateContext(EGLDisplay display, EGLContext shared_context, eglCreateContext(display, config, shared_context, attributes); RETURN_IF_ERROR(GetOpenGlErrors()); if (context == EGL_NO_CONTEXT) { - return absl::InternalError("No EGL error, but eglCreateContext failed."); + return InternalError("No EGL error, but eglCreateContext failed."); } *egl_context = EglContext(context, display, config, true); - return absl::OkStatus(); + 
return OkStatus(); } bool HasExtension(EGLDisplay display, const char* name) { @@ -93,36 +93,34 @@ EglContext& EglContext::operator=(EglContext&& other) { return *this; } -absl::Status EglContext::MakeCurrent(EGLSurface read, EGLSurface write) { +Status EglContext::MakeCurrent(EGLSurface read, EGLSurface write) { bool is_made_current = eglMakeCurrent(display_, write, read, context_); RETURN_IF_ERROR(GetOpenGlErrors()); if (!is_made_current) { - return absl::InternalError("No EGL error, but eglMakeCurrent failed."); + return InternalError("No EGL error, but eglMakeCurrent failed."); } - return absl::OkStatus(); + return OkStatus(); } bool EglContext::IsCurrent() const { return context_ == eglGetCurrentContext(); } -absl::Status CreateConfiglessContext(EGLDisplay display, - EGLContext shared_context, - EglContext* egl_context) { +Status CreateConfiglessContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context) { if (!HasExtension(display, "EGL_KHR_no_config_context")) { - return absl::UnavailableError("EGL_KHR_no_config_context not supported"); + return UnavailableError("EGL_KHR_no_config_context not supported"); } return CreateContext(display, shared_context, EGL_NO_CONFIG_KHR, egl_context); } -absl::Status CreateSurfacelessContext(EGLDisplay display, - EGLContext shared_context, - EglContext* egl_context) { +Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context) { if (!HasExtension(display, "EGL_KHR_create_context")) { - return absl::UnavailableError("EGL_KHR_create_context not supported"); + return UnavailableError("EGL_KHR_create_context not supported"); } if (!HasExtension(display, "EGL_KHR_surfaceless_context")) { - return absl::UnavailableError("EGL_KHR_surfaceless_context not supported"); + return UnavailableError("EGL_KHR_surfaceless_context not supported"); } const EGLint attributes[] = {EGL_RENDERABLE_TYPE, EGL_OPENGL_ES3_BIT_KHR, EGL_NONE}; @@ -131,8 +129,8 @@ absl::Status CreateSurfacelessContext(EGLDisplay display, return CreateContext(display, shared_context, config, egl_context); } -absl::Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, - EglContext* egl_context) { +Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context) { const EGLint attributes[] = { EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_BIND_TO_TEXTURE_RGB, EGL_TRUE, EGL_RENDERABLE_TYPE, EGL_OPENGL_ES3_BIT_KHR, diff --git a/tensorflow/lite/delegates/gpu/gl/egl_context.h b/tensorflow/lite/delegates/gpu/gl/egl_context.h index a93f1fdc4c4..72c53d2dd2e 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_context.h +++ b/tensorflow/lite/delegates/gpu/gl/egl_context.h @@ -61,9 +61,9 @@ class EglContext { // Make this EglContext the current EGL context on this thread, replacing // the existing current. - absl::Status MakeCurrent(EGLSurface read, EGLSurface write); + Status MakeCurrent(EGLSurface read, EGLSurface write); - absl::Status MakeCurrentSurfaceless() { + Status MakeCurrentSurfaceless() { return MakeCurrent(EGL_NO_SURFACE, EGL_NO_SURFACE); } @@ -86,16 +86,14 @@ class EglContext { // It uses the EGL_KHR_no_config_context extension to create a no config context // since most modern hardware supports the extension. 
-absl::Status CreateConfiglessContext(EGLDisplay display, - EGLContext shared_context, - EglContext* egl_context); +Status CreateConfiglessContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context); -absl::Status CreateSurfacelessContext(EGLDisplay display, - EGLContext shared_context, - EglContext* egl_context); +Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context); -absl::Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, - EglContext* egl_context); +Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/egl_environment.cc b/tensorflow/lite/delegates/gpu/gl/egl_environment.cc index 8ae75acd933..baf6002e6c1 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_environment.cc +++ b/tensorflow/lite/delegates/gpu/gl/egl_environment.cc @@ -28,28 +28,28 @@ namespace { // TODO(akulik): detect power management event when all contexts are destroyed // and OpenGL ES is reinitialized. See eglMakeCurrent -absl::Status InitDisplay(EGLDisplay* egl_display) { +Status InitDisplay(EGLDisplay* egl_display) { RETURN_IF_ERROR( TFLITE_GPU_CALL_EGL(eglGetDisplay, egl_display, EGL_DEFAULT_DISPLAY)); if (*egl_display == EGL_NO_DISPLAY) { - return absl::UnavailableError("eglGetDisplay returned nullptr"); + return UnavailableError("eglGetDisplay returned nullptr"); } bool is_initialized; RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglInitialize, &is_initialized, *egl_display, nullptr, nullptr)); if (!is_initialized) { - return absl::InternalError("No EGL error, but eglInitialize failed"); + return InternalError("No EGL error, but eglInitialize failed"); } - return absl::OkStatus(); + return OkStatus(); } } // namespace -absl::Status EglEnvironment::NewEglEnvironment( +Status EglEnvironment::NewEglEnvironment( std::unique_ptr* egl_environment) { *egl_environment = absl::make_unique(); RETURN_IF_ERROR((*egl_environment)->Init()); - return absl::OkStatus(); + return OkStatus(); } EglEnvironment::~EglEnvironment() { @@ -61,12 +61,12 @@ EglEnvironment::~EglEnvironment() { } } -absl::Status EglEnvironment::Init() { +Status EglEnvironment::Init() { bool is_bound; RETURN_IF_ERROR( TFLITE_GPU_CALL_EGL(eglBindAPI, &is_bound, EGL_OPENGL_ES_API)); if (!is_bound) { - return absl::InternalError("No EGL error, but eglBindAPI failed"); + return InternalError("No EGL error, but eglBindAPI failed"); } // Re-use context and display if it was created on this thread. @@ -77,7 +77,7 @@ absl::Status EglEnvironment::Init() { } else { RETURN_IF_ERROR(InitDisplay(&display_)); - absl::Status status = InitConfiglessContext(); + Status status = InitConfiglessContext(); if (!status.ok()) { status = InitSurfacelessContext(); } @@ -94,30 +94,33 @@ absl::Status EglEnvironment::Init() { } // TODO(akulik): when do we need ForceSyncTurning? 
ForceSyncTurning(); - return absl::OkStatus(); + return OkStatus(); } -absl::Status EglEnvironment::InitConfiglessContext() { +Status EglEnvironment::InitConfiglessContext() { RETURN_IF_ERROR(CreateConfiglessContext(display_, EGL_NO_CONTEXT, &context_)); return context_.MakeCurrentSurfaceless(); } -absl::Status EglEnvironment::InitSurfacelessContext() { +Status EglEnvironment::InitSurfacelessContext() { RETURN_IF_ERROR( CreateSurfacelessContext(display_, EGL_NO_CONTEXT, &context_)); - RETURN_IF_ERROR(context_.MakeCurrentSurfaceless()); + Status status = context_.MakeCurrentSurfaceless(); + if (!status.ok()) { + return status; + } // PowerVR support EGL_KHR_surfaceless_context, but glFenceSync crashes on // PowerVR when it is surface-less. RETURN_IF_ERROR(RequestGpuInfo(&gpu_info_)); if (gpu_info_.type == GpuType::POWERVR) { - return absl::UnavailableError( + return UnavailableError( "Surface-less context is not properly supported on powervr."); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status EglEnvironment::InitPBufferContext() { +Status EglEnvironment::InitPBufferContext() { RETURN_IF_ERROR(CreatePBufferContext(display_, EGL_NO_CONTEXT, &context_)); RETURN_IF_ERROR(CreatePbufferRGBSurface(context_.config(), display_, 1, 1, &surface_read_)); diff --git a/tensorflow/lite/delegates/gpu/gl/egl_environment.h b/tensorflow/lite/delegates/gpu/gl/egl_environment.h index cb6616496dd..fa7ca047b6e 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_environment.h +++ b/tensorflow/lite/delegates/gpu/gl/egl_environment.h @@ -36,7 +36,7 @@ namespace gl { // EGL environment needs to be created once per thread. class EglEnvironment { public: - static absl::Status NewEglEnvironment( + static Status NewEglEnvironment( std::unique_ptr* egl_environment); EglEnvironment() = default; @@ -47,10 +47,10 @@ class EglEnvironment { const GpuInfo& gpu_info() const { return gpu_info_; } private: - absl::Status Init(); - absl::Status InitConfiglessContext(); - absl::Status InitSurfacelessContext(); - absl::Status InitPBufferContext(); + Status Init(); + Status InitConfiglessContext(); + Status InitSurfacelessContext(); + Status InitPBufferContext(); EGLDisplay display_ = EGL_NO_DISPLAY; EglSurface surface_draw_; diff --git a/tensorflow/lite/delegates/gpu/gl/egl_surface.cc b/tensorflow/lite/delegates/gpu/gl/egl_surface.cc index d0f062af392..eaccea6411e 100644 --- a/tensorflow/lite/delegates/gpu/gl/egl_surface.cc +++ b/tensorflow/lite/delegates/gpu/gl/egl_surface.cc @@ -44,9 +44,9 @@ void EglSurface::Invalidate() { } } -absl::Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, - uint32_t height, uint32_t width, - EglSurface* egl_surface) { +Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, + uint32_t height, uint32_t width, + EglSurface* egl_surface) { const EGLint pbuffer_attributes[] = {EGL_WIDTH, static_cast(width), EGL_HEIGHT, @@ -60,11 +60,10 @@ absl::Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, eglCreatePbufferSurface(display, config, pbuffer_attributes); RETURN_IF_ERROR(GetOpenGlErrors()); if (surface == EGL_NO_SURFACE) { - return absl::InternalError( - "No EGL error, but eglCreatePbufferSurface failed"); + return InternalError("No EGL error, but eglCreatePbufferSurface failed"); } *egl_surface = EglSurface(surface, display); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/egl_surface.h b/tensorflow/lite/delegates/gpu/gl/egl_surface.h index 5d39aed33fb..793dc7a9dc6 100644 --- 
a/tensorflow/lite/delegates/gpu/gl/egl_surface.h +++ b/tensorflow/lite/delegates/gpu/gl/egl_surface.h @@ -56,9 +56,9 @@ class EglSurface { }; // Creates off-screen pbuffer-based surface of the given height and width. -absl::Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, - uint32_t height, uint32_t width, - EglSurface* egl_surface); +Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, + uint32_t height, uint32_t width, + EglSurface* egl_surface); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc b/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc index 1de49676219..509cadca60d 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc @@ -21,10 +21,9 @@ namespace tflite { namespace gpu { namespace gl { -absl::Status CopyBuffer(const GlBuffer& read_buffer, - const GlBuffer& write_buffer) { +Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer) { if (read_buffer.bytes_size() != write_buffer.bytes_size()) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Read buffer does not match write buffer size."); } gl_buffer_internal::BufferBinder read_buffer_binder(GL_COPY_READ_BUFFER, @@ -36,7 +35,7 @@ absl::Status CopyBuffer(const GlBuffer& read_buffer, write_buffer.offset(), read_buffer.bytes_size()); } -absl::Status GetSSBOSize(GLuint id, int64_t* size_bytes) { +Status GetSSBOSize(GLuint id, int64_t* size_bytes) { GLuint prev_id; RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetIntegerv, GL_SHADER_STORAGE_BUFFER_BINDING, @@ -76,19 +75,19 @@ void GlBuffer::Invalidate() { } } -absl::Status GlBuffer::BindToIndex(uint32_t index) const { +Status GlBuffer::BindToIndex(uint32_t index) const { return TFLITE_GPU_CALL_GL(glBindBufferRange, target_, index, id_, offset_, bytes_size_); } -absl::Status GlBuffer::MakeView(size_t offset, size_t bytes_size, - GlBuffer* gl_buffer) { +Status GlBuffer::MakeView(size_t offset, size_t bytes_size, + GlBuffer* gl_buffer) { if (offset + bytes_size > bytes_size_) { - return absl::OutOfRangeError("GlBuffer view is out of range."); + return OutOfRangeError("GlBuffer view is out of range."); } *gl_buffer = GlBuffer(target_, id_, bytes_size, offset_ + offset, /*has_ownership=*/false); - return absl::OkStatus(); + return OkStatus(); } GlBuffer GlBuffer::MakeRef() { @@ -122,13 +121,12 @@ GlPersistentBuffer::~GlPersistentBuffer() { glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); } -absl::Status CreatePersistentBuffer(size_t size, - GlPersistentBuffer* gl_buffer) { +Status CreatePersistentBuffer(size_t size, GlPersistentBuffer* gl_buffer) { PFNGLBUFFERSTORAGEEXTPROC glBufferStorageEXT = nullptr; glBufferStorageEXT = reinterpret_cast( eglGetProcAddress("glBufferStorageEXT")); if (!glBufferStorageEXT) { - return absl::UnavailableError("glBufferStorageEXT is not supported"); + return UnavailableError("glBufferStorageEXT is not supported"); } gl_buffer_internal::BufferId id; gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id()); @@ -142,7 +140,7 @@ absl::Status CreatePersistentBuffer(size_t size, GL_MAP_READ_BIT | GL_MAP_WRITE_BIT | GL_MAP_PERSISTENT_BIT_EXT)); *gl_buffer = GlPersistentBuffer{ GL_SHADER_STORAGE_BUFFER, id.Release(), size, 0, true, data}; - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer.h b/tensorflow/lite/delegates/gpu/gl/gl_buffer.h index 3225679ec5a..a7e19abde70 100644 --- 
a/tensorflow/lite/delegates/gpu/gl/gl_buffer.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer.h @@ -60,31 +60,30 @@ class GlBuffer { // Reads data from buffer into CPU memory. Data should point to a region that // has at least bytes_size available. template - absl::Status Read(absl::Span data) const; + Status Read(absl::Span data) const; // Writes data to a buffer. template - absl::Status Write(absl::Span data); + Status Write(absl::Span data); // Maps GPU memory to CPU address space and calls reader that may read from // that memory. template - absl::Status MappedRead( - const std::function)>& reader) const; + Status MappedRead( + const std::function)>& reader) const; // Maps GPU memory to CPU address space and calls writer that may write into // that memory. template - absl::Status MappedWrite( - const std::function)>& writer); + Status MappedWrite(const std::function)>& writer); - absl::Status MakeView(size_t offset, size_t bytes_size, GlBuffer* gl_buffer); + Status MakeView(size_t offset, size_t bytes_size, GlBuffer* gl_buffer); // Makes a copy without ownership of the buffer. GlBuffer MakeRef(); // Binds a buffer to an index. - absl::Status BindToIndex(uint32_t index) const; + Status BindToIndex(uint32_t index) const; // Releases the ownership of the buffer object. void Release() { has_ownership_ = false; } @@ -113,10 +112,9 @@ class GlBuffer { bool has_ownership_; }; -absl::Status CopyBuffer(const GlBuffer& read_buffer, - const GlBuffer& write_buffer); +Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer); -absl::Status GetSSBOSize(GLuint id, int64_t* size_bytes); +Status GetSSBOSize(GLuint id, int64_t* size_bytes); // Creates new shader storage buffer that will be modified and used many // times. @@ -124,20 +122,20 @@ absl::Status GetSSBOSize(GLuint id, int64_t* size_bytes); // See https://www.khronos.org/opengl/wiki/Shader_Storage_Buffer_Object for // details. template -absl::Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, - GlBuffer* gl_buffer); +Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, + GlBuffer* gl_buffer); // Creates new shader storage buffer that will be filled with data once which // will be used many times. template -absl::Status CreateReadOnlyShaderStorageBuffer(absl::Span data, - GlBuffer* gl_buffer); +Status CreateReadOnlyShaderStorageBuffer(absl::Span data, + GlBuffer* gl_buffer); // Adapts raw Buffer::Read method to read data into a vector. template -absl::Status AppendFromBuffer(const GlBuffer& buffer, std::vector* data) { +Status AppendFromBuffer(const GlBuffer& buffer, std::vector* data) { if (buffer.bytes_size() % sizeof(T) != 0) { - return absl::InvalidArgumentError("Buffer is not aligned"); + return InvalidArgumentError("Buffer is not aligned"); } size_t num_elements = buffer.bytes_size() / sizeof(T); data->resize(data->size() + num_elements); @@ -169,7 +167,7 @@ class GlPersistentBuffer : public GlBuffer { }; // Creates read-write persistent buffer with valid CPU pointer -absl::Status CreatePersistentBuffer(size_t size, GlPersistentBuffer* gl_buffer); +Status CreatePersistentBuffer(size_t size, GlPersistentBuffer* gl_buffer); //////////////////////////////////////////////////////////////////////////////// // Implementation details are below. 
@@ -245,8 +243,8 @@ class BufferMapper { } // namespace gl_buffer_internal template -absl::Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, - GlBuffer* gl_buffer) { +Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, + GlBuffer* gl_buffer) { gl_buffer_internal::BufferId id; gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id()); // TODO(akulik): benchmark DYNAMIC vs STREAM buffer @@ -255,12 +253,12 @@ absl::Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, GL_STREAM_COPY)); *gl_buffer = GlBuffer{GL_SHADER_STORAGE_BUFFER, id.Release(), num_elements * sizeof(T), 0, true}; - return absl::OkStatus(); + return OkStatus(); } template -absl::Status CreateReadOnlyShaderStorageBuffer(absl::Span data, - GlBuffer* gl_buffer) { +Status CreateReadOnlyShaderStorageBuffer(absl::Span data, + GlBuffer* gl_buffer) { gl_buffer_internal::BufferId id; gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id()); RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glBufferData, GL_SHADER_STORAGE_BUFFER, @@ -268,26 +266,26 @@ absl::Status CreateReadOnlyShaderStorageBuffer(absl::Span data, GL_STATIC_READ)); *gl_buffer = GlBuffer{GL_SHADER_STORAGE_BUFFER, id.Release(), data.size() * sizeof(T), 0, true}; - return absl::OkStatus(); + return OkStatus(); } template -absl::Status GlBuffer::Read(absl::Span data) const { +Status GlBuffer::Read(absl::Span data) const { if (data.size() * sizeof(T) < bytes_size()) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Read from buffer failed. Destination data is shorter than buffer."); } // TODO(akulik): glCopyBufferSubData is actually available in ES 3.1, try it. return MappedRead([this, data](absl::Span src) { std::memcpy(data.data(), src.data(), bytes_size()); - return absl::OkStatus(); + return OkStatus(); }); } template -absl::Status GlBuffer::Write(absl::Span data) { +Status GlBuffer::Write(absl::Span data) { if (data.size() * sizeof(T) > bytes_size_) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Write to buffer failed. 
Source data is larger than buffer."); } gl_buffer_internal::BufferBinder binder(target_, id_); @@ -296,10 +294,10 @@ absl::Status GlBuffer::Write(absl::Span data) { } template -absl::Status GlBuffer::MappedRead( - const std::function d)>& reader) const { +Status GlBuffer::MappedRead( + const std::function d)>& reader) const { if (bytes_size_ % sizeof(T) != 0) { - return absl::InvalidArgumentError("Buffer is not aligned"); + return InvalidArgumentError("Buffer is not aligned"); } gl_buffer_internal::BufferBinder binder(target_, id_); gl_buffer_internal::BufferMapper mapper(target_, offset_, bytes_size_, @@ -312,10 +310,10 @@ absl::Status GlBuffer::MappedRead( } template -absl::Status GlBuffer::MappedWrite( - const std::function d)>& writer) { +Status GlBuffer::MappedWrite( + const std::function d)>& writer) { if (bytes_size_ % sizeof(T) != 0) { - return absl::InvalidArgumentError("Buffer is not aligned"); + return InvalidArgumentError("Buffer is not aligned"); } gl_buffer_internal::BufferBinder binder(target_, id_); gl_buffer_internal::BufferMapper mapper(target_, offset_, bytes_size_, diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc b/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc index 863f5ec6020..1d8031fcf39 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc @@ -89,7 +89,7 @@ TEST(Buffer, SubView) { GlBuffer view1; ASSERT_TRUE(buffer.MakeView(4, 16, &view1).ok()); GlBuffer view2; - EXPECT_FALSE(view1.MakeView(1, 16, &view2).ok()); + EXPECT_NE(view1.MakeView(1, 16, &view2), OkStatus()); ASSERT_TRUE(view1.MakeView(2, 2, &view2).ok()); EXPECT_FALSE(view2.has_ownership()); diff --git a/tensorflow/lite/delegates/gpu/gl/gl_call.h b/tensorflow/lite/delegates/gpu/gl/gl_call.h index 1a392d6aca3..a8a81bae608 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_call.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_call.h @@ -53,13 +53,12 @@ namespace gl_call_internal { template struct Caller { template - absl::Status operator()(const std::string& context, F func, ErrorF error_func, - T* result, Params&&... params) { + Status operator()(const std::string& context, F func, ErrorF error_func, + T* result, Params&&... params) { *result = func(std::forward(params)...); const auto status = error_func(); - if (status.ok()) return absl::OkStatus(); - return absl::Status(status.code(), - std::string(status.message()) + ": " + context); + if (status.ok()) return OkStatus(); + return Status(status.code(), status.error_message() + ": " + context); } }; @@ -67,27 +66,25 @@ struct Caller { template<> struct Caller { template - absl::Status operator()(const std::string& context, F func, ErrorF error_func, - Params&&... params) { + Status operator()(const std::string& context, F func, ErrorF error_func, + Params&&... params) { func(std::forward(params)...); const auto status = error_func(); - if (status.ok()) return absl::OkStatus(); - return absl::Status(status.code(), - std::string(status.message()) + ": " + context); + if (status.ok()) return OkStatus(); + return Status(status.code(), status.error_message() + ": " + context); } }; template -absl::Status CallAndCheckError(const std::string& context, F func, - ErrorF error_func, ResultT* result, - ParamsT&&... params) { +Status CallAndCheckError(const std::string& context, F func, ErrorF error_func, + ResultT* result, ParamsT&&... 
params) { return Caller()(context, func, error_func, result, std::forward(params)...); } template -absl::Status CallAndCheckError(const std::string& context, F func, - ErrorF error_func, Params&&... params) { +Status CallAndCheckError(const std::string& context, F func, ErrorF error_func, + Params&&... params) { return Caller()(context, func, error_func, std::forward(params)...); } diff --git a/tensorflow/lite/delegates/gpu/gl/gl_errors.cc b/tensorflow/lite/delegates/gpu/gl/gl_errors.cc index 3ad6be8a25e..1a40e38ea9c 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_errors.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_errors.cc @@ -58,83 +58,83 @@ struct ErrorFormatter { // TODO(akulik): create new error space for GL error. -absl::Status GetOpenGlErrors() { +Status GetOpenGlErrors() { auto error = glGetError(); if (error == GL_NO_ERROR) { - return absl::OkStatus(); + return OkStatus(); } auto error2 = glGetError(); if (error2 == GL_NO_ERROR) { - return absl::InternalError(ErrorToString(error)); + return InternalError(ErrorToString(error)); } std::vector errors = {error, error2}; for (error = glGetError(); error != GL_NO_ERROR; error = glGetError()) { errors.push_back(error); } - return absl::InternalError(absl::StrJoin(errors, ",", ErrorFormatter())); + return InternalError(absl::StrJoin(errors, ",", ErrorFormatter())); } -absl::Status GetEglError() { +Status GetEglError() { EGLint error = eglGetError(); switch (error) { case EGL_SUCCESS: - return absl::OkStatus(); + return OkStatus(); case EGL_NOT_INITIALIZED: - return absl::InternalError( + return InternalError( "EGL is not initialized, or could not be initialized, for the " "specified EGL display connection."); case EGL_BAD_ACCESS: - return absl::InternalError( + return InternalError( "EGL cannot access a requested resource (for example a context is " "bound in another thread)."); case EGL_BAD_ALLOC: - return absl::InternalError( + return InternalError( "EGL failed to allocate resources for the requested operation."); case EGL_BAD_ATTRIBUTE: - return absl::InternalError( + return InternalError( "An unrecognized attribute or attribute value was passed in the " "attribute list."); case EGL_BAD_CONTEXT: - return absl::InternalError( + return InternalError( "An EGLContext argument does not name a valid EGL rendering " "context."); case EGL_BAD_CONFIG: - return absl::InternalError( + return InternalError( "An EGLConfig argument does not name a valid EGL frame buffer " "configuration."); case EGL_BAD_CURRENT_SURFACE: - return absl::InternalError( + return InternalError( "The current surface of the calling thread is a window, pixel buffer " "or pixmap that is no longer valid."); case EGL_BAD_DISPLAY: - return absl::InternalError( + return InternalError( "An EGLDisplay argument does not name a valid EGL display " "connection."); case EGL_BAD_SURFACE: - return absl::InternalError( + return InternalError( "An EGLSurface argument does not name a valid surface (window, pixel " "buffer or pixmap) configured for GL rendering."); case EGL_BAD_MATCH: - return absl::InternalError( + return InternalError( "Arguments are inconsistent (for example, a valid context requires " "buffers not supplied by a valid surface)."); case EGL_BAD_PARAMETER: - return absl::InternalError("One or more argument values are invalid."); + return InternalError("One or more argument values are invalid."); case EGL_BAD_NATIVE_PIXMAP: - return absl::InternalError( + return InternalError( "A NativePixmapType argument does not refer to a valid native " "pixmap."); case 
EGL_BAD_NATIVE_WINDOW: - return absl::InternalError( + return InternalError( "A NativeWindowType argument does not refer to a valid native " "window."); case EGL_CONTEXT_LOST: - return absl::InternalError( + return InternalError( "A power management event has occurred. The application must destroy " "all contexts and reinitialize OpenGL ES state and objects to " "continue rendering."); } - return absl::UnknownError("EGL error: " + std::to_string(error)); + return UnknownError("EGL error: " + std::to_string(error)); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_errors.h b/tensorflow/lite/delegates/gpu/gl/gl_errors.h index 761eddd8901..978e642abaa 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_errors.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_errors.h @@ -23,10 +23,10 @@ namespace gpu { namespace gl { // @return recent opengl errors and packs them into Status. -absl::Status GetOpenGlErrors(); +Status GetOpenGlErrors(); // @return the error of the last called EGL function in the current thread. -absl::Status GetEglError(); +Status GetEglError(); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/gl_program.cc b/tensorflow/lite/delegates/gpu/gl/gl_program.cc index d6e56ca64c4..def82357a6a 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_program.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_program.cc @@ -29,19 +29,19 @@ namespace gpu { namespace gl { namespace { -absl::Status CreateNewProgramId(GLuint* program_id) { +Status CreateNewProgramId(GLuint* program_id) { RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glCreateProgram, program_id)); if (!*program_id) { - return absl::UnknownError("Can't create opengl program: 0 program_id"); + return UnknownError("Can't create opengl program: 0 program_id"); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CheckProgramLinked(GLuint program_id) { +Status CheckProgramLinked(GLuint program_id) { GLint linked; glGetProgramiv(program_id, GL_LINK_STATUS, &linked); if (linked == GL_TRUE) { - return absl::OkStatus(); + return OkStatus(); } GLint info_size; glGetProgramiv(program_id, GL_INFO_LOG_LENGTH, &info_size); @@ -49,26 +49,26 @@ absl::Status CheckProgramLinked(GLuint program_id) { errors.resize(info_size + 1 /* plus \0 */); glGetProgramInfoLog(program_id, info_size + 1, nullptr, &errors[0]); // TODO(akulik): use glValidateProgram to gather more info. 
- return absl::UnavailableError("Program is not properly linked: " + errors); + return UnavailableError("Program is not properly linked: " + errors); } struct ParameterSetter { - absl::Status operator()(int value) { + Status operator()(int value) { return TFLITE_GPU_CALL_GL(glProgramUniform1i, program_id, uniform_id, value); } - absl::Status operator()(const int2& value) { + Status operator()(const int2& value) { return TFLITE_GPU_CALL_GL(glProgramUniform2i, program_id, uniform_id, value.x, value.y); } - absl::Status operator()(const int4& value) { + Status operator()(const int4& value) { return TFLITE_GPU_CALL_GL(glProgramUniform4i, program_id, uniform_id, value.x, value.y, value.z, value.w); } - absl::Status operator()(const std::vector& value) { + Status operator()(const std::vector& value) { std::vector ints(value.size() * 2, 0); for (int i = 0; i < value.size(); ++i) { ints[i * 2] = value[i].x; @@ -78,32 +78,32 @@ struct ParameterSetter { ints.size(), ints.data()); } - absl::Status operator()(unsigned int value) { + Status operator()(unsigned int value) { return TFLITE_GPU_CALL_GL(glProgramUniform1ui, program_id, uniform_id, value); } - absl::Status operator()(const uint4& value) { + Status operator()(const uint4& value) { return TFLITE_GPU_CALL_GL(glProgramUniform4ui, program_id, uniform_id, value.x, value.y, value.z, value.w); } - absl::Status operator()(float value) { + Status operator()(float value) { return TFLITE_GPU_CALL_GL(glProgramUniform1f, program_id, uniform_id, value); } - absl::Status operator()(const float2& value) { + Status operator()(const float2& value) { return TFLITE_GPU_CALL_GL(glProgramUniform2f, program_id, uniform_id, value.x, value.y); } - absl::Status operator()(const float4& value) { + Status operator()(const float4& value) { return TFLITE_GPU_CALL_GL(glProgramUniform4f, program_id, uniform_id, value.x, value.y, value.z, value.w); } - absl::Status operator()(const std::vector& value) { + Status operator()(const std::vector& value) { std::vector floats(value.size() * 4, 0); for (int i = 0; i < value.size(); ++i) { floats[i * 4] = value[i].x; @@ -121,8 +121,8 @@ struct ParameterSetter { } // namespace -absl::Status GlProgram::CreateWithShader(const GlShader& shader, - GlProgram* gl_program) { +Status GlProgram::CreateWithShader(const GlShader& shader, + GlProgram* gl_program) { GLuint program_id; RETURN_IF_ERROR(CreateNewProgramId(&program_id)); @@ -136,11 +136,11 @@ absl::Status GlProgram::CreateWithShader(const GlShader& shader, RETURN_IF_ERROR(CheckProgramLinked(program.id())); *gl_program = std::move(program); - return absl::OkStatus(); + return OkStatus(); } -absl::Status GlProgram::CreateWithBinaryShader(const BinaryShader& shader, - GlProgram* gl_program) { +Status GlProgram::CreateWithBinaryShader(const BinaryShader& shader, + GlProgram* gl_program) { GLuint program_id; RETURN_IF_ERROR(CreateNewProgramId(&program_id)); @@ -154,15 +154,15 @@ absl::Status GlProgram::CreateWithBinaryShader(const BinaryShader& shader, RETURN_IF_ERROR(CheckProgramLinked(program.id())); *gl_program = std::move(program); - return absl::OkStatus(); + return OkStatus(); } -absl::Status GlProgram::GetBinary(BinaryShader* binary_shader) { +Status GlProgram::GetBinary(BinaryShader* binary_shader) { GLint size = 0; RETURN_IF_ERROR( TFLITE_GPU_CALL_GL(glGetProgramiv, id_, GL_PROGRAM_BINARY_LENGTH, &size)); if (!size) { - return absl::InternalError("Getting binary size failed."); + return InternalError("Getting binary size failed."); } // TODO(akulik): call // 
glProgramParameteri(id_, GL_PROGRAM_BINARY_RETRIEVABLE_HINT, GL_TRUE) @@ -174,10 +174,10 @@ absl::Status GlProgram::GetBinary(BinaryShader* binary_shader) { &returned_size, &format, reinterpret_cast(&binary[0]))); if (size != returned_size) { - return absl::InternalError("Getting binary is failed."); + return InternalError("Getting binary is failed."); } *binary_shader = BinaryShader(format, std::move(binary)); - return absl::OkStatus(); + return OkStatus(); } GlProgram::GlProgram(GlProgram&& program) : id_(program.id_) { @@ -201,16 +201,16 @@ GlProgram& GlProgram::operator=(GlProgram&& program) { GlProgram::~GlProgram() { Invalidate(); } -absl::Status GlProgram::SetParameter(const Variable& param) { +Status GlProgram::SetParameter(const Variable& param) { GLint uniform_location; RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetUniformLocation, &uniform_location, id_, param.name.c_str())); return absl::visit(ParameterSetter{id_, uniform_location}, param.value); } -absl::Status GlProgram::Dispatch(const uint3& workgroups) const { +Status GlProgram::Dispatch(const uint3& workgroups) const { if (workgroups.x == 0 || workgroups.y == 0 || workgroups.z == 0) { - return absl::InvalidArgumentError("Invalid workgroups"); + return InvalidArgumentError("Invalid workgroups"); } RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glUseProgram, id_)); return TFLITE_GPU_CALL_GL(glDispatchCompute, workgroups.x, workgroups.y, diff --git a/tensorflow/lite/delegates/gpu/gl/gl_program.h b/tensorflow/lite/delegates/gpu/gl/gl_program.h index 892cb8e0850..dfd6bde4c59 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_program.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_program.h @@ -40,13 +40,12 @@ class GlProgram { // a program. Thus, if this call returns a program, one may set parameters and // finally execute a program. // therefore it needs to be handled elsewhere. - static absl::Status CreateWithShader(const GlShader& shader, - GlProgram* gl_program); + static Status CreateWithShader(const GlShader& shader, GlProgram* gl_program); // Same as CreateWithShader but takes compiled shader in a binary form, // therefore compilation step is avoided. - static absl::Status CreateWithBinaryShader(const BinaryShader& shader, - GlProgram* gl_program); + static Status CreateWithBinaryShader(const BinaryShader& shader, + GlProgram* gl_program); // move-only GlProgram(GlProgram&& program); @@ -60,12 +59,12 @@ class GlProgram { // Returns a binary representation for a shader currently attached and linked // into this program. 
- absl::Status GetBinary(BinaryShader* binary_shader); + Status GetBinary(BinaryShader* binary_shader); - absl::Status SetParameter(const Variable& param); + Status SetParameter(const Variable& param); // Executes program - absl::Status Dispatch(const uint3& workgroups) const; + Status Dispatch(const uint3& workgroups) const; bool is_valid() const { return id_ != 0; } diff --git a/tensorflow/lite/delegates/gpu/gl/gl_shader.cc b/tensorflow/lite/delegates/gpu/gl/gl_shader.cc index e3823a24d93..32391749985 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_shader.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_shader.cc @@ -42,9 +42,9 @@ GlShader& GlShader::operator=(GlShader&& shader) { GlShader::~GlShader() { Invalidate(); } -absl::Status GlShader::CompileShader(GLenum shader_type, - const std::string& shader_source, - GlShader* gl_shader) { +Status GlShader::CompileShader(GLenum shader_type, + const std::string& shader_source, + GlShader* gl_shader) { // NOTE: code compilation can fail due to gl errors happened before GLuint shader_id; RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glCreateShader, &shader_id, shader_type)); @@ -64,12 +64,12 @@ absl::Status GlShader::CompileShader(GLenum shader_type, glGetShaderiv(shader.id(), GL_INFO_LOG_LENGTH, &info_log_len); std::string errors(info_log_len, 0); glGetShaderInfoLog(shader.id(), info_log_len, nullptr, &errors[0]); - return absl::InternalError("Shader compilation failed: " + errors + - "\nProblem shader is:\n" + shader_source); + return InternalError("Shader compilation failed: " + errors + + "\nProblem shader is:\n" + shader_source); } *gl_shader = std::move(shader); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_shader.h b/tensorflow/lite/delegates/gpu/gl/gl_shader.h index 45adc59207b..d0ec421bb16 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_shader.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_shader.h @@ -33,9 +33,9 @@ class GlShader { // // @param shader_type is one of GL_VERTEX_SHADER, GL_FRAGMENT_SHADER, or // GL_COMPUTE_SHADER. - static absl::Status CompileShader(GLenum shader_type, - const std::string& shader_source, - GlShader* gl_shader); + static Status CompileShader(GLenum shader_type, + const std::string& shader_source, + GlShader* gl_shader); GlShader() : id_(0) {} diff --git a/tensorflow/lite/delegates/gpu/gl/gl_sync.cc b/tensorflow/lite/delegates/gpu/gl/gl_sync.cc index 89d3a88d16f..92caaa5c78a 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_sync.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_sync.cc @@ -25,7 +25,7 @@ namespace tflite { namespace gpu { namespace gl { -absl::Status GlSyncWait() { +Status GlSyncWait() { GlSync sync; RETURN_IF_ERROR(GlSync::NewSync(&sync)); // Flush sync and loop afterwards without it. @@ -37,16 +37,16 @@ absl::Status GlSyncWait() { break; case GL_CONDITION_SATISFIED: case GL_ALREADY_SIGNALED: - return absl::OkStatus(); + return OkStatus(); case GL_WAIT_FAILED: return GetOpenGlErrors(); } status = glClientWaitSync(sync.sync(), 0, /* timeout ns = */ 10000000); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status GlActiveSyncWait() { +Status GlActiveSyncWait() { GlSync sync; RETURN_IF_ERROR(GlSync::NewSync(&sync)); // Since creating a Sync object is itself a GL command it *must* be flushed. 
@@ -59,7 +59,7 @@ absl::Status GlActiveSyncWait() { break; case GL_CONDITION_SATISFIED: case GL_ALREADY_SIGNALED: - return absl::OkStatus(); + return OkStatus(); case GL_WAIT_FAILED: return GetOpenGlErrors(); } @@ -69,7 +69,7 @@ absl::Status GlActiveSyncWait() { while (true) { glGetSynciv(sync.sync(), GL_SYNC_STATUS, sizeof(GLint), nullptr, &result); if (result == GL_SIGNALED) { - return absl::OkStatus(); + return OkStatus(); } #ifdef __ARM_ACLE // Try to save CPU power by yielding CPU to another thread. @@ -78,7 +78,7 @@ absl::Status GlActiveSyncWait() { } } -absl::Status GlShaderSync::NewSync(GlShaderSync* gl_sync) { +Status GlShaderSync::NewSync(GlShaderSync* gl_sync) { GlShaderSync sync; RETURN_IF_ERROR(CreatePersistentBuffer(sizeof(int), &sync.flag_buffer_)); static const std::string* kCode = new std::string(R"(#version 310 es @@ -94,16 +94,16 @@ absl::Status GlShaderSync::NewSync(GlShaderSync* gl_sync) { RETURN_IF_ERROR(GlShader::CompileShader(GL_COMPUTE_SHADER, *kCode, &shader)); RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &sync.flag_program_)); *gl_sync = std::move(sync); - return absl::OkStatus(); + return OkStatus(); } // How it works: GPU writes a buffer and CPU checks the buffer value to be // changed. The buffer is accessible for writing by GPU and reading by CPU // simultaneously - persistent buffer or buffer across shild context can be used // for that. -absl::Status GlShaderSync::Wait() { +Status GlShaderSync::Wait() { if (!flag_buffer_.is_valid()) { - return absl::UnavailableError("GlShaderSync is not initialized."); + return UnavailableError("GlShaderSync is not initialized."); } RETURN_IF_ERROR(flag_buffer_.BindToIndex(0)); volatile int* flag_ptr_ = reinterpret_cast(flag_buffer_.data()); @@ -115,7 +115,7 @@ absl::Status GlShaderSync::Wait() { // Wait for the value is being updated by the shader. while (*flag_ptr_ != 1) { } - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_sync.h b/tensorflow/lite/delegates/gpu/gl/gl_sync.h index 8b5d910910d..dadb4b1192f 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_sync.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_sync.h @@ -32,12 +32,12 @@ namespace gl { // GlSync is moveable but not copyable. class GlSync { public: - static absl::Status NewSync(GlSync* gl_sync) { + static Status NewSync(GlSync* gl_sync) { GLsync sync; RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glFenceSync, &sync, GL_SYNC_GPU_COMMANDS_COMPLETE, 0)); *gl_sync = GlSync(sync); - return absl::OkStatus(); + return OkStatus(); } // Creates invalid object. @@ -75,12 +75,12 @@ class GlSync { }; // Waits until GPU is done with processing. -absl::Status GlSyncWait(); +Status GlSyncWait(); // Waits until all commands are flushed and then performs active waiting by // spinning a thread and checking sync status. It leads to shorter wait time // (up to tens of ms) but consumes more CPU. -absl::Status GlActiveSyncWait(); +Status GlActiveSyncWait(); // CPU checks the value in the buffer that is going to be written by GPU. The // persistent buffer is used for the simultaneous access to the buffer by GPU @@ -88,9 +88,9 @@ absl::Status GlActiveSyncWait(); // is not supported by the device. 
class GlShaderSync { public: - static absl::Status NewSync(GlShaderSync* gl_sync); + static Status NewSync(GlShaderSync* gl_sync); GlShaderSync() {} - absl::Status Wait(); + Status Wait(); private: GlProgram flag_program_; diff --git a/tensorflow/lite/delegates/gpu/gl/gl_texture.cc b/tensorflow/lite/delegates/gpu/gl/gl_texture.cc index 0267a52e44f..eb20deca758 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_texture.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_texture.cc @@ -120,31 +120,31 @@ void GlTexture::Invalidate() { } } -absl::Status GlTexture::BindImage(uint32_t index, GLenum access) const { +Status GlTexture::BindImage(uint32_t index, GLenum access) const { return TFLITE_GPU_CALL_GL(glBindImageTexture, index, id_, /* level = */ 0, /* layered = */ GL_TRUE, layer_, access, format_); } -absl::Status GlTexture::BindAsReadonlyImage(uint32_t index) const { +Status GlTexture::BindAsReadonlyImage(uint32_t index) const { return BindImage(index, GL_READ_ONLY); } -absl::Status GlTexture::BindAsWriteonlyImage(uint32_t index) const { +Status GlTexture::BindAsWriteonlyImage(uint32_t index) const { return BindImage(index, GL_WRITE_ONLY); } -absl::Status GlTexture::BindAsReadWriteImage(uint32_t index) const { +Status GlTexture::BindAsReadWriteImage(uint32_t index) const { return BindImage(index, GL_READ_WRITE); } -absl::Status GlTexture::BindAsSampler2D(uint32_t index) const { +Status GlTexture::BindAsSampler2D(uint32_t index) const { RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glActiveTexture, GL_TEXTURE0 + index)); return TFLITE_GPU_CALL_GL(glBindTexture, GL_TEXTURE_2D, id_); } namespace { -absl::Status SetTextureWrapAndFilter(GLenum target, GLenum texture_format) { +Status SetTextureWrapAndFilter(GLenum target, GLenum texture_format) { if (texture_format == GL_RGBA32F) { RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, GL_TEXTURE_WRAP_S, GL_REPEAT)); @@ -177,16 +177,14 @@ absl::Status SetTextureWrapAndFilter(GLenum target, GLenum texture_format) { RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, GL_TEXTURE_MIN_FILTER, GL_LINEAR)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateReadOnlyRgba2dImageTexture(DataType data_type, - const uint2& size, - const void* data, - size_t byte_size, - GlTexture* gl_texture) { +Status CreateReadOnlyRgba2dImageTexture(DataType data_type, const uint2& size, + const void* data, size_t byte_size, + GlTexture* gl_texture) { if (byte_size != /* RGBA=*/4 * SizeOf(data_type) * size.x * size.y) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Creating image texture failed. Source data size is not matching " "expected dimensions."); } @@ -204,16 +202,14 @@ absl::Status CreateReadOnlyRgba2dImageTexture(DataType data_type, 0, 0, size.x, size.y, format, type, data)); *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, 0, /*owned=*/true); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateReadOnlyRgba3dImageTexture(DataType data_type, - const uint3& size, - const void* data, - size_t byte_size, - GlTexture* gl_texture) { +Status CreateReadOnlyRgba3dImageTexture(DataType data_type, const uint3& size, + const void* data, size_t byte_size, + GlTexture* gl_texture) { if (byte_size != /* RGBA=*/4 * SizeOf(data_type) * size.x * size.y * size.z) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Creating image texture failed. 
Source data is larger than dimensions " "product."); } @@ -232,54 +228,53 @@ absl::Status CreateReadOnlyRgba3dImageTexture(DataType data_type, type, data)); *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, 0, /*owned=*/true); - return absl::OkStatus(); + return OkStatus(); } } // namespace -absl::Status CreateReadOnlyImageTexture(const uint2& size, - absl::Span data, - GlTexture* gl_texture) { +Status CreateReadOnlyImageTexture(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba2dImageTexture(DataType::FLOAT32, size, data.data(), data.size() * sizeof(float), gl_texture); } -absl::Status CreateReadOnlyImageTexture(const uint3& size, - absl::Span data, - GlTexture* gl_texture) { +Status CreateReadOnlyImageTexture(const uint3& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba3dImageTexture(DataType::FLOAT32, size, data.data(), data.size() * sizeof(float), gl_texture); } -absl::Status CreateReadOnlyImageTextureU8(const uint2& size, - absl::Span data, - GlTexture* gl_texture) { +Status CreateReadOnlyImageTextureU8(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba2dImageTexture(DataType::UINT8, size, data.data(), data.size() * sizeof(uint8_t), gl_texture); } -absl::Status CreateReadOnlyImageTextureF16(const uint2& size, - absl::Span data, - GlTexture* gl_texture) { +Status CreateReadOnlyImageTextureF16(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba2dImageTexture(DataType::FLOAT16, size, data.data(), data.size() * sizeof(uint16_t), gl_texture); } -absl::Status CreateReadOnlyImageTextureF16(const uint3& size, - absl::Span data, - GlTexture* gl_texture) { +Status CreateReadOnlyImageTextureF16(const uint3& size, + absl::Span data, + GlTexture* gl_texture) { return CreateReadOnlyRgba3dImageTexture(DataType::FLOAT16, size, data.data(), data.size() * sizeof(uint16_t), gl_texture); } -absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, - const uint2& size, - GlTexture* gl_texture) { +Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint2& size, + GlTexture* gl_texture) { const GLenum kTarget = GL_TEXTURE_2D; const GLenum internal_format = ToTextureInternalFormat(data_type); gl_texture_internal::TextureId id; @@ -292,12 +287,11 @@ absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, /* layer = */ 0, /* owned = */ true); - return absl::OkStatus(); + return OkStatus(); } -absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, - const uint3& size, - GlTexture* gl_texture) { +Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint3& size, + GlTexture* gl_texture) { const GLenum kTarget = GL_TEXTURE_2D_ARRAY; GLenum internal_format = ToTextureInternalFormat(data_type); gl_texture_internal::TextureId id; @@ -311,7 +305,7 @@ absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, /* layer = */ 0, /* owned = */ true); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/gl_texture.h b/tensorflow/lite/delegates/gpu/gl/gl_texture.h index 60e22b47229..951b22f23f1 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_texture.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_texture.h @@ -57,16 +57,16 @@ class GlTexture { ~GlTexture(); // Binds a texture as an image to the 
given index. - absl::Status BindAsReadonlyImage(uint32_t index) const; + Status BindAsReadonlyImage(uint32_t index) const; // Bind texture as an image for write access at given index. - absl::Status BindAsWriteonlyImage(uint32_t index) const; + Status BindAsWriteonlyImage(uint32_t index) const; // Bind texture as an image for read-write access at given index. - absl::Status BindAsReadWriteImage(uint32_t index) const; + Status BindAsReadWriteImage(uint32_t index) const; // Binds a texture as a sampler to the given index. - absl::Status BindAsSampler2D(uint32_t index) const; + Status BindAsSampler2D(uint32_t index) const; GLenum target() const { return target_; } @@ -87,7 +87,7 @@ class GlTexture { private: void Invalidate(); - absl::Status BindImage(uint32_t index, GLenum access) const; + Status BindImage(uint32_t index, GLenum access) const; GLuint id_; GLenum target_; @@ -101,55 +101,53 @@ class GlTexture { // will be used for reading. // // @param size defines 2D image texture size where each pixel is RGBA. -absl::Status CreateReadOnlyImageTexture(const uint2& size, - absl::Span data, - GlTexture* gl_texture); +Status CreateReadOnlyImageTexture(const uint2& size, + absl::Span data, + GlTexture* gl_texture); // Creates new 2D image texture that will be filled with float16 data once which // will be used for reading. // // @param size defines 2D image texture size where each pixel is RGBA. -absl::Status CreateReadOnlyImageTextureF16(const uint2& size, - absl::Span data, - GlTexture* gl_texture); +Status CreateReadOnlyImageTextureF16(const uint2& size, + absl::Span data, + GlTexture* gl_texture); // Creates new 2D image texture that will be filled with uint8 data once which // will be used for reading. // // @param size defines 2D image texture size where each pixel is RGBA. -absl::Status CreateReadOnlyImageTextureU8(const uint2& size, - absl::Span data, - GlTexture* gl_texture); +Status CreateReadOnlyImageTextureU8(const uint2& size, + absl::Span data, + GlTexture* gl_texture); // Creates new 3D RGBA image texture that will be filled with float32 data once // which will be used for reading. // // @param size defines 3D image texture size where each pixel is RGBA. -absl::Status CreateReadOnlyImageTexture(const uint3& size, - absl::Span data, - GlTexture* gl_texture); +Status CreateReadOnlyImageTexture(const uint3& size, + absl::Span data, + GlTexture* gl_texture); // Creates new 3D RGBA image texture that will be filled with float16 data once // which will be used for reading. // // @param size defines 3D image texture size where each pixel is RGBA. -absl::Status CreateReadOnlyImageTextureF16(const uint3& size, - absl::Span data, - GlTexture* gl_texture); +Status CreateReadOnlyImageTextureF16(const uint3& size, + absl::Span data, + GlTexture* gl_texture); // Creates new RGBA 2D image texture // // @param size defines 2D image texture size where each pixel is RGBA. -absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, - const uint2& size, - GlTexture* gl_texture); +Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint2& size, + GlTexture* gl_texture); // Creates new RGBA 3D image texture // // @param size defines 3D image texture size where each pixel is RGBA. 
-absl::Status CreateReadWriteRgbaImageTexture(DataType data_type, - const uint3& size, - GlTexture* gl_texture); +Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint3& size, + GlTexture* gl_texture); GLenum ToTextureFormat(DataType type); diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/add.cc b/tensorflow/lite/delegates/gpu/gl/kernels/add.cc index 135253112ba..12124a8cc57 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/add.cc @@ -34,8 +34,8 @@ namespace { class Add : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast(ctx.node->operation.attributes); auto adds = absl::get_if>(&attr.param); auto scalar = absl::get_if(&attr.param); @@ -60,13 +60,13 @@ class Add : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } std::string code = "value_0 = value_0"; for (int index = 1; index < inputs.size(); ++index) { if (inputs[index]->tensor.shape != inputs[0]->tensor.shape) { - return absl::InvalidArgumentError("Shapes are not equal"); + return InvalidArgumentError("Shapes are not equal"); } absl::StrAppend(&code, " + value_", index); } @@ -81,7 +81,7 @@ class Add : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } if (scalar) { @@ -111,7 +111,7 @@ class Add : public NodeShader { }; } - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc b/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc index 43afab2922e..a97d618e0b6 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc @@ -67,10 +67,10 @@ class AlignedConcatByChannels : public NodeShader { return true; } - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { if (!IsSupported(ctx)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "This case is not supported by aligned concat"); } auto inputs = ctx.graph->FindInputs(ctx.node->id); @@ -94,7 +94,7 @@ class AlignedConcatByChannels : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; @@ -127,10 +127,10 @@ class ConcatByAnyChannel : public NodeShader { return true; } - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { if (!IsSupported(ctx)) { - return absl::UnimplementedError("This case is not supported by concat"); + return UnimplementedError("This case is not supported by concat"); } auto inputs = ctx.graph->FindInputs(ctx.node->id); @@ -182,7 +182,7 @@ class ConcatByAnyChannel : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::ONLY_DEFINITIONS, }; - return absl::OkStatus(); + return OkStatus(); } private: @@ -348,8 +348,8 @@ class FlatConcatByHeight : public NodeShader { return true; } - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* 
generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto inputs = ctx.graph->FindInputs(ctx.node->id); std::string code; std::vector params; @@ -382,7 +382,7 @@ class FlatConcatByHeight : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; @@ -415,8 +415,8 @@ class FlatConcatByWidth : public NodeShader { return true; } - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto inputs = ctx.graph->FindInputs(ctx.node->id); std::string code; std::vector params; @@ -449,22 +449,21 @@ class FlatConcatByWidth : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; class FlatConcat : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { if (FlatConcatByHeight::IsSupported(ctx)) { return flat_concat_by_height_.GenerateCode(ctx, generated_code); } if (FlatConcatByWidth::IsSupported(ctx)) { return flat_concat_by_width_.GenerateCode(ctx, generated_code); } - return absl::InvalidArgumentError( - "This case is not supported by flat concat"); + return InvalidArgumentError("This case is not supported by flat concat"); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc index 5c88402c1d1..0b18a4c4246 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc @@ -37,8 +37,8 @@ namespace { class Convolution : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto attr = absl::any_cast( ctx.node->operation.attributes); @@ -139,7 +139,7 @@ class Convolution : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; @@ -160,24 +160,24 @@ int SelectMultiplier(int32_t input_width, class Convolution1x1 : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = absl::any_cast( ctx.node->operation.attributes); if (attr.weights.shape.h != 1 || attr.weights.shape.w != 1) { - return absl::UnimplementedError("Height and width should be 1."); + return UnimplementedError("Height and width should be 1."); } if (attr.dilations.h != 1 || attr.dilations.w != 1) { - return absl::UnimplementedError("Dilations are not supported."); + return UnimplementedError("Dilations are not supported."); } if (attr.strides.h != 1 || attr.strides.w != 1) { - return absl::UnimplementedError("Strides are not supported."); + return UnimplementedError("Strides are not supported."); } if (attr.padding.appended.h != 0 || 
attr.padding.appended.w != 0 || attr.padding.prepended.h != 0 || attr.padding.prepended.w != 0) { - return absl::UnimplementedError("Padding is not supported."); + return UnimplementedError("Padding is not supported."); } int multiplier = SelectMultiplier(input->tensor.shape.w, ctx); @@ -280,7 +280,7 @@ class Convolution1x1 : public NodeShader { /*output=*/multiplier == 1 ? IOStructure::AUTO : IOStructure::ONLY_DEFINITIONS, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc index bc4c61075a3..189beedf815 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc @@ -31,11 +31,11 @@ namespace gl { namespace { // Wraps given SSBO into GlBuffer object that does not have ownership. -absl::Status WrapSSBO(OpenGlBuffer ssbo, GlBuffer* buffer) { +Status WrapSSBO(OpenGlBuffer ssbo, GlBuffer* buffer) { int64_t size_bytes; RETURN_IF_ERROR(GetSSBOSize(ssbo.id, &size_bytes)); *buffer = GlBuffer(GL_SHADER_STORAGE_BUFFER, ssbo.id, size_bytes, 0, false); - return absl::OkStatus(); + return OkStatus(); } std::string GetShaderHeader(const uint3& localsize) { @@ -49,12 +49,12 @@ class OpenGlConverterImpl : public TensorObjectConverter { explicit OpenGlConverterImpl(CommandQueue* command_queue) : command_queue_(command_queue) {} - virtual absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def) = 0; + virtual Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def) = 0; protected: - absl::Status InitializeProgram(const uint3& workgroup_size, - const std::string& shader_source) { + Status InitializeProgram(const uint3& workgroup_size, + const std::string& shader_source) { workgroup_size_ = workgroup_size; GlShader shader; RETURN_IF_ERROR(GlShader::CompileShader( @@ -63,7 +63,7 @@ class OpenGlConverterImpl : public TensorObjectConverter { return GlProgram::CreateWithShader(shader, &program_); } - absl::Status Dispatch(const uint3& workload) { + Status Dispatch(const uint3& workload) { uint3 num_workgroups = IntegralDivideRoundUp(workload, workgroup_size_); if (command_queue_) { return command_queue_->Dispatch(program_, num_workgroups); @@ -103,12 +103,12 @@ class FromTensorConverter : public OpenGlConverterImpl { input.data_layout == DataLayout::DHWC4; } - absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def) final { + Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def) final { shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h, output_def.dimensions.w, output_def.dimensions.c); if (shape_.b != 1) { - return absl::UnimplementedError( + return UnimplementedError( "FromTensorConverter: Batch size != 1 is not supported."); } @@ -135,18 +135,18 @@ class FromTensorConverter : public OpenGlConverterImpl { })"); } - absl::Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto output = absl::get_if(&output_obj); if (!output || !output->id) { - return absl::InvalidArgumentError("Missing output in converter"); + return InvalidArgumentError("Missing output in converter"); } auto input = absl::get_if(&input_obj); if (!input || !input->id) { - return absl::InvalidArgumentError("Missing input in converter"); + return InvalidArgumentError("Missing input in 
converter"); } if (input->id == output->id) { - return absl::InvalidArgumentError("Can not execute inplace conversion"); + return InvalidArgumentError("Can not execute inplace conversion"); } GlBuffer input_ssbo; RETURN_IF_ERROR(WrapSSBO(*input, &input_ssbo)); @@ -154,11 +154,11 @@ class FromTensorConverter : public OpenGlConverterImpl { RETURN_IF_ERROR(WrapSSBO(*output, &output_ssbo)); if (input_ssbo.bytes_size() != SizeInBytesDHWC4(shape_)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "FromTensorConverter: input data size does not match expected size."); } if (output_ssbo.bytes_size() != SizeInBytesBHWC(shape_)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "FromTensorConverter: output data size does not match expected " "size."); } @@ -191,12 +191,12 @@ class ToTensorConverter : public OpenGlConverterImpl { output.data_layout == DataLayout::DHWC4; } - absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def) final { + Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def) final { shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h, output_def.dimensions.w, output_def.dimensions.c); if (shape_.b != 1) { - return absl::UnimplementedError( + return UnimplementedError( "FromTensorConverter: Batch size != 1 is not supported."); } @@ -230,18 +230,18 @@ class ToTensorConverter : public OpenGlConverterImpl { })"); } - absl::Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto output = absl::get_if(&output_obj); if (!output || !output->id) { - return absl::InvalidArgumentError("Missing output in converter"); + return InvalidArgumentError("Missing output in converter"); } auto input = absl::get_if(&input_obj); if (!input || !input->id) { - return absl::InvalidArgumentError("Missing input in converter"); + return InvalidArgumentError("Missing input in converter"); } if (input->id == output->id) { - return absl::InvalidArgumentError("Can not execute inplace conversion"); + return InvalidArgumentError("Can not execute inplace conversion"); } GlBuffer input_ssbo; RETURN_IF_ERROR(WrapSSBO(*input, &input_ssbo)); @@ -249,11 +249,11 @@ class ToTensorConverter : public OpenGlConverterImpl { RETURN_IF_ERROR(WrapSSBO(*output, &output_ssbo)); if (input_ssbo.bytes_size() != SizeInBytesBHWC(shape_)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "ToTensorConverter: input data size does not match expected size."); } if (output_ssbo.bytes_size() != SizeInBytesDHWC4(shape_)) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "ToTensorConverter: output data size does not match expected size."); } auto d = IntegralDivideRoundUp(shape_.c, 4); @@ -279,19 +279,19 @@ class TrivialCopier : public TensorObjectConverter { input.data_layout == output.data_layout; } - absl::Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto ssbo_input = absl::get_if(&input_obj); auto ssbo_output = absl::get_if(&output_obj); if (ssbo_input && ssbo_output) { return Copy(*ssbo_input, *ssbo_output); } - return absl::InternalError("Unexpected object"); + return InternalError("Unexpected object"); } - absl::Status Copy(OpenGlBuffer input, OpenGlBuffer output) { + Status Copy(OpenGlBuffer input, OpenGlBuffer output) { if 
(input.id == output.id) { - return absl::OkStatus(); + return OkStatus(); } GlBuffer input_obj; RETURN_IF_ERROR(WrapSSBO(input, &input_obj)); @@ -313,8 +313,8 @@ class CpuCopier : public TensorObjectConverter { input.object_type == ObjectType::OPENGL_SSBO)); } - absl::Status Convert(const TensorObject& input_obj, - const TensorObject& output_obj) override { + Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { auto cpu_input = absl::get_if(&input_obj); auto cpu_output = absl::get_if(&output_obj); if (cpu_input) { @@ -335,7 +335,7 @@ class CpuCopier : public TensorObjectConverter { static_cast(cpu_output->data), cpu_output->size_bytes)); } } - return absl::InternalError("Unexpected object"); + return InternalError("Unexpected object"); } }; @@ -355,7 +355,7 @@ class TensorConverterBuilderImpl : public TensorObjectConverterBuilder { ToTensorConverter::IsSupported(input_def, output_def)); } - absl::Status MakeConverter( + Status MakeConverter( const TensorObjectDef& input, const TensorObjectDef& output, std::unique_ptr* converter) final { std::unique_ptr impl; @@ -363,22 +363,20 @@ class TensorConverterBuilderImpl : public TensorObjectConverterBuilder { const auto& output_def = output.object_def; if (TrivialCopier::IsSupported(input_def, output_def)) { *converter = absl::make_unique(); - return absl::OkStatus(); - } - if (CpuCopier::IsSupported(input_def, output_def)) { + return OkStatus(); + } else if (CpuCopier::IsSupported(input_def, output_def)) { *converter = absl::make_unique(); - return absl::OkStatus(); - } - if (FromTensorConverter::IsSupported(input_def, output_def)) { + return OkStatus(); + } else if (FromTensorConverter::IsSupported(input_def, output_def)) { impl = absl::make_unique(command_queue_); } else if (ToTensorConverter::IsSupported(input_def, output_def)) { impl = absl::make_unique(command_queue_); } else { - return absl::UnimplementedError("Unsupported conversion"); + return UnimplementedError("Unsupported conversion"); } RETURN_IF_ERROR(impl->Init(input, output)); *converter = std::move(impl); - return absl::OkStatus(); + return OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc index 5f14f093c55..daba2f6d9ef 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc @@ -45,7 +45,7 @@ Dimensions ToDimensions(const BHWC& shape) { return Dimensions(shape.b, shape.h, shape.w, shape.c); } -absl::Status RunFromTensorTest(const BHWC& shape) { +Status RunFromTensorTest(const BHWC& shape) { // Create random input and calculate expected output for it. std::vector input = GenerateFloats(0.01, GetElementsSizeForPHWC4(shape)); @@ -85,9 +85,9 @@ absl::Status RunFromTensorTest(const BHWC& shape) { RETURN_IF_ERROR(output_buffer.Read( absl::MakeSpan(converted_output.data(), converted_output.size()))); if (output != converted_output) { - return absl::InternalError("Outputs don't match"); + return InternalError("Outputs don't match"); } - return absl::OkStatus(); + return OkStatus(); } TEST(FromTensor, Smoke) { @@ -103,7 +103,7 @@ TEST(FromTensor, Smoke) { } } -absl::Status RunToTensorTest(const BHWC& shape) { +Status RunToTensorTest(const BHWC& shape) { // Create random input and calculate expected output for it. 
std::vector input = GenerateFloats(0.01, shape.DimensionsProduct()); std::vector output(GetElementsSizeForPHWC4(shape), 0); @@ -142,9 +142,9 @@ absl::Status RunToTensorTest(const BHWC& shape) { RETURN_IF_ERROR(output_buffer.Read( absl::MakeSpan(converted_output.data(), converted_output.size()))); if (output != converted_output) { - return absl::InternalError("Outputs don't match"); + return InternalError("Outputs don't match"); } - return absl::OkStatus(); + return OkStatus(); } TEST(ToTensor, Smoke) { diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc index 38ddbf361b4..a8d71a943b7 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc @@ -36,8 +36,8 @@ namespace { class DepthwiseConvolution : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto attr = absl::any_cast( ctx.node->operation.attributes); @@ -146,7 +146,7 @@ class DepthwiseConvolution : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc index aa254770535..35b233cbdcc 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc @@ -31,8 +31,8 @@ class ElementwiseOneArgument : public NodeShader { public: explicit ElementwiseOneArgument(OperationType operation_type) : operation_type_(operation_type) {} - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::string source; switch (operation_type_) { case OperationType::ABS: @@ -89,8 +89,7 @@ class ElementwiseOneArgument : public NodeShader { source = "value_0 = tanh(value_0);"; break; default: - return absl::InvalidArgumentError( - "Incorrect elementwise operation type."); + return InvalidArgumentError("Incorrect elementwise operation type."); } *generated_code = { /*parameters=*/{}, @@ -102,7 +101,7 @@ class ElementwiseOneArgument : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } private: @@ -145,8 +144,8 @@ class ElementwiseTwoArguments : public NodeShader { return true; } - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::vector parameters; std::vector> objects; std::string argument0, argument1; @@ -160,7 +159,7 @@ class ElementwiseTwoArguments : public NodeShader { const ElementwiseAttributes* attr = absl::any_cast( &ctx.node->operation.attributes); if (!attr) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Couldn't read attributes for the scalar of const vector case."); } auto* tensor = @@ -168,7 +167,7 @@ class ElementwiseTwoArguments : public NodeShader { &attr->param); auto* scalar = absl::get_if(&attr->param); if (!tensor && !scalar) { - return absl::InvalidArgumentError( + return 
InvalidArgumentError( "Couldn't read scalar of const vector data from the attributes."); } @@ -209,7 +208,7 @@ class ElementwiseTwoArguments : public NodeShader { break; } default: - return absl::InvalidArgumentError( + return InvalidArgumentError( "Incorrect elementwise with scalar operation type."); } source = absl::Substitute(source, argument0, argument1); @@ -223,7 +222,7 @@ class ElementwiseTwoArguments : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc index a8246515247..f4ad5b8cc0a 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc @@ -34,8 +34,8 @@ namespace { class FullyConnectedBuffers : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast( ctx.node->operation.attributes); @@ -106,7 +106,7 @@ class FullyConnectedBuffers : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::ONLY_DEFINITIONS, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc index 7179ba00581..e248cdfb31a 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc @@ -43,8 +43,8 @@ namespace { // class LstmNodeShader : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::string code = R"( vec4 prev_state = $input_data_1[gid.x, gid.y, gid.z]$; @@ -80,7 +80,7 @@ class LstmNodeShader : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc index c8961eee087..2e977625489 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc @@ -33,8 +33,8 @@ namespace { class MaxUnpooling : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast( ctx.node->operation.attributes); std::vector parameters = { @@ -66,7 +66,7 @@ class MaxUnpooling : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc index e94c952ffaa..9328351f169 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc @@ -32,11 +32,11 @@ namespace { class Mean : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status 
GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast(ctx.node->operation.attributes); if (attr.dims != std::set({Axis::HEIGHT, Axis::WIDTH})) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Mean calculation is supported only for height and width."); } @@ -72,7 +72,7 @@ class Mean : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc index 6e825dc862d..7de4caea81d 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc @@ -52,8 +52,8 @@ bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) { return shape1.h == 1 && shape1.w == 1 && shape0.c == shape1.c; } -absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { +Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { const auto inputs = ctx.graph->FindInputs(ctx.node->id); const auto& shape0 = inputs[0]->tensor.shape; const auto& shape1 = inputs[1]->tensor.shape; @@ -80,11 +80,11 @@ absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } -absl::Status GenerateMultiplyScalarCode( - const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) { +Status GenerateMultiplyScalarCode(const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { auto attr = absl::any_cast(ctx.node->operation.attributes); auto muls = absl::get_if>(&attr.param); @@ -103,7 +103,7 @@ absl::Status GenerateMultiplyScalarCode( }; } else { if (!muls) { - return absl::InvalidArgumentError("Empty parameters for Multiplication."); + return InvalidArgumentError("Empty parameters for Multiplication."); } auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape; *generated_code = { @@ -120,13 +120,13 @@ absl::Status GenerateMultiplyScalarCode( }; } - return absl::OkStatus(); + return OkStatus(); } class Multiply : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { if (IsApplyMaskSupported(ctx)) { return GenerateApplyMaskCode(ctx, generated_code); } else { diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc b/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc index 3fc84aa675e..14fe55d943a 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc @@ -34,22 +34,22 @@ namespace { class Pad : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto attr = absl::any_cast(ctx.node->operation.attributes); if (attr.type != PaddingContentType::ZEROS && attr.type != PaddingContentType::REFLECT) { - return absl::UnimplementedError( + return UnimplementedError( "Only ZERO and REFLECT padding types are supported."); } if (attr.appended.h < 0 || attr.appended.w < 0 || attr.appended.c 
< 0 || attr.prepended.h < 0 || attr.prepended.w < 0 || attr.prepended.c < 0) { - return absl::UnimplementedError("Negative padding is not supported."); + return UnimplementedError("Negative padding is not supported."); } if (attr.appended.b != 0 || attr.prepended.b != 0) { - return absl::UnimplementedError("Padding for BATCH is not supported."); + return UnimplementedError("Padding for BATCH is not supported."); } std::vector parameters = { {"input_data_0_h", input->tensor.shape.h}, @@ -130,7 +130,7 @@ class Pad : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc index 5c6aefcde1c..8f140c33fca 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc @@ -31,14 +31,14 @@ namespace gpu { namespace gl { namespace { -absl::Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr, - const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { +Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr, + const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; if (attr.padding.prepended.h > attr.kernel.h || attr.padding.prepended.w > attr.kernel.w) { - return absl::InvalidArgumentError("Padding is bigger than kernel."); + return InvalidArgumentError("Padding is bigger than kernel."); } std::vector parameters = { @@ -94,12 +94,12 @@ absl::Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr, /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } -absl::Status GenerateAveragePoolingCode( - const Pooling2DAttributes& attr, const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { +Status GenerateAveragePoolingCode(const Pooling2DAttributes& attr, + const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; std::vector parameters = { @@ -136,13 +136,13 @@ absl::Status GenerateAveragePoolingCode( /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } class Pooling : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { const auto& attr = absl::any_cast(ctx.node->operation.attributes); switch (attr.type) { @@ -151,7 +151,7 @@ class Pooling : public NodeShader { case PoolingType::MAX: return GenerateMaxPoolingCode(attr, ctx, generated_code); default: - return absl::InvalidArgumentError("Incorrect attributes' type."); + return InvalidArgumentError("Incorrect attributes' type."); } } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc index 28f8551f530..88078935ee2 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc @@ -35,17 +35,17 @@ namespace { class PReLULinearAlpha : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const 
final { auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = absl::any_cast(ctx.node->operation.attributes); auto alpha = absl::get_if>(&attr.alpha); if (!alpha) { - return absl::InvalidArgumentError("Alpha is missing"); + return InvalidArgumentError("Alpha is missing"); } if (alpha->shape.v != output->tensor.shape.c) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Alpha shape does not match the number of channels."); } @@ -79,26 +79,25 @@ class PReLULinearAlpha : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; class PReLUFull : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = absl::any_cast(ctx.node->operation.attributes); auto alpha = absl::get_if>(&attr.alpha); if (!alpha) { - return absl::InvalidArgumentError("Alpha is missing"); + return InvalidArgumentError("Alpha is missing"); } if (alpha->shape.h != output->tensor.shape.h || alpha->shape.w != output->tensor.shape.w || alpha->shape.c != output->tensor.shape.c) { - return absl::InvalidArgumentError( - "Alpha shape does not match input shape."); + return InvalidArgumentError("Alpha shape does not match input shape."); } auto shape = output->tensor.shape; @@ -142,14 +141,14 @@ class PReLUFull : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; class PReLU : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast(ctx.node->operation.attributes); auto alpha = absl::get_if>(&attr.alpha); diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc index 1d45e07aeee..3f21124aee9 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc @@ -31,8 +31,8 @@ namespace { class QuantizeAndDequantize : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::string code; // Constants code += "vec4 scale = vec4($quant_scale$);"; @@ -59,7 +59,7 @@ class QuantizeAndDequantize : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc index 8f6de92acd8..6903abc0b26 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc @@ -120,19 +120,19 @@ class Registry : public NodeShader { ~Registry() final = default; - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { std::vector errors; auto it = shaders_.find(ctx.node->operation.type); if 
(it != shaders_.end()) { for (auto& shader : it->second) { const auto status = shader->GenerateCode(ctx, generated_code); if (status.ok()) return status; - errors.push_back(std::string(status.message())); + errors.push_back(status.error_message()); } } - return absl::NotFoundError(absl::StrCat( - "Suitable node shader is not found: ", absl::StrJoin(errors, ", "))); + return NotFoundError(absl::StrCat("Suitable node shader is not found: ", + absl::StrJoin(errors, ", "))); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc b/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc index a9357968a90..a8e006ed151 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc @@ -33,8 +33,8 @@ namespace { class ReLU : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto attr = absl::any_cast(ctx.node->operation.attributes); // clamp(value, min(0, alpha * value), clip) std::vector params; @@ -62,7 +62,7 @@ class ReLU : public NodeShader { /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc b/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc index 9734ff14a1e..cd01417cff5 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc @@ -32,19 +32,19 @@ namespace { class Reshape : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; if (input->tensor.shape.DimensionsProduct() != output->tensor.shape.DimensionsProduct()) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Number of elements in input & output tensors don't match."); } auto attr = absl::any_cast(ctx.node->operation.attributes); if (attr.new_shape != output->tensor.shape) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Dimensions for output does not match new_shape attribute"); } @@ -80,7 +80,7 @@ class Reshape : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc index 004ae14fe8b..33d59518987 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc @@ -33,8 +33,10 @@ namespace { class Resize : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Resize() {} + + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = @@ -42,15 +44,15 @@ class Resize : public NodeShader { if (input->tensor.shape.w > output->tensor.shape.w || input->tensor.shape.h > output->tensor.shape.h) { - return absl::InvalidArgumentError("Output size is less than input size."); + return 
InvalidArgumentError("Output size is less than input size."); } if (output->tensor.shape.w != attr.new_shape.w || output->tensor.shape.h != attr.new_shape.h) { - return absl::InvalidArgumentError( + return InvalidArgumentError( "Output size does not match new_size in attributes."); } if (input->tensor.shape.c != output->tensor.shape.c) { - return absl::InvalidArgumentError("Input/output channels mismatch."); + return InvalidArgumentError("Input/output channels mismatch."); } if (input->tensor.shape.h == 1 && input->tensor.shape.w == 1) { // Copy a single element from input. @@ -64,7 +66,7 @@ class Resize : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } std::vector parameters = { {"input_data_0_h", input->tensor.shape.h}, @@ -105,7 +107,7 @@ class Resize : public NodeShader { value_0 = $input_data_0[coord.x, coord.y, gid.z]$; )"; } else { - return absl::InvalidArgumentError("Unknown sampling type"); + return InvalidArgumentError("Unknown sampling type"); } *generated_code = { /*parameters=*/std::move(parameters), @@ -117,7 +119,7 @@ class Resize : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc b/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc index ab4497c4b62..d0fe1923d4e 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc @@ -33,8 +33,8 @@ namespace { class Slice : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; auto attr = @@ -107,7 +107,7 @@ class Slice : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc index b6c8e144a09..e59343df7b6 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc @@ -41,19 +41,17 @@ float4 GetMask(int num_channels) { class Softmax : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { const auto* input = ctx.graph->FindInputs(ctx.node->id)[0]; const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0]; const auto& attr = absl::any_cast( ctx.node->operation.attributes); if (input->tensor.shape != output->tensor.shape) { - return absl::InvalidArgumentError( - "Input and output shapes do not match."); + return InvalidArgumentError("Input and output shapes do not match."); } if (attr.axis != Axis::CHANNELS) { - return absl::UnimplementedError( - "Softmax is only supported for channels axis."); + return UnimplementedError("Softmax is only supported for channels axis."); } return input->tensor.shape.h == 1 && input->tensor.shape.w == 1 ? 
GenerateCodeFor1x1(ctx, generated_code) @@ -61,8 +59,8 @@ class Softmax : public NodeShader { } private: - absl::Status GenerateCodeFor1x1(const GenerationContext& ctx, - GeneratedCode* generated_code) const { + Status GenerateCodeFor1x1(const GenerationContext& ctx, + GeneratedCode* generated_code) const { const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0]; const int depth = IntegralDivideRoundUp(output->tensor.shape.c, 4); std::vector shared_variables = { @@ -135,11 +133,11 @@ class Softmax : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::ONLY_DEFINITIONS, }; - return absl::OkStatus(); + return OkStatus(); } - absl::Status GenerateCodeGeneral(const GenerationContext& ctx, - GeneratedCode* generated_code) const { + Status GenerateCodeGeneral(const GenerationContext& ctx, + GeneratedCode* generated_code) const { const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0]; std::vector parameters = { {"src_depth", IntegralDivideRoundUp(output->tensor.shape.c, 4)}, @@ -174,7 +172,7 @@ class Softmax : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::ONLY_DEFINITIONS, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.cc b/tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.cc index b1e650a1ffc..1d49da0e3fa 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.cc @@ -31,8 +31,8 @@ namespace { class SpaceToDepth : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { const auto attr = absl::any_cast(ctx.node->operation.attributes); const auto& input_data_0 = ctx.graph->FindInputs(ctx.node->id)[0]->tensor; @@ -60,7 +60,7 @@ class SpaceToDepth : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; } // namespace diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc index e9abec7eec6..de6e324017d 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc @@ -68,9 +68,9 @@ bool SingleOpModel::PopulateTensor(int index, std::vector&& data) { return true; } -absl::Status SingleOpModel::Invoke(const CompilationOptions& compile_options, - const RuntimeOptions& runtime_options, - const NodeShader& shader) { +Status SingleOpModel::Invoke(const CompilationOptions& compile_options, + const RuntimeOptions& runtime_options, + const NodeShader& shader) { std::unique_ptr env; RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env)); @@ -125,10 +125,10 @@ absl::Status SingleOpModel::Invoke(const CompilationOptions& compile_options, CopyFromPHWC4Buffer(*objects.FindBuffer(output->id), &tensor)); outputs_.push_back(std::move(tensor)); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status SingleOpModel::Invoke(const NodeShader& shader) { +Status SingleOpModel::Invoke(const NodeShader& shader) { return Invoke(CompilationOptions(), RuntimeOptions(), shader); } diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h index 42a789020df..c917220d075 100644 --- 
a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h +++ b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h @@ -48,10 +48,10 @@ class SingleOpModel { bool PopulateTensor(int index, std::vector&& data); - absl::Status Invoke(const NodeShader& shader); - absl::Status Invoke(const CompilationOptions& compile_options, - const RuntimeOptions& runtime_options, - const NodeShader& shader); + Status Invoke(const NodeShader& shader); + Status Invoke(const CompilationOptions& compile_options, + const RuntimeOptions& runtime_options, + const NodeShader& shader); const std::vector& GetOutput(int index) const { return outputs_[index].data; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc index eb28672d49f..7fcfde4f92a 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc @@ -35,8 +35,8 @@ namespace { class ConvolutionTransposedBuffers : public NodeShader { public: - absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const final { + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { auto input = ctx.graph->FindInputs(ctx.node->id)[0]; auto attr = absl::any_cast( ctx.node->operation.attributes); @@ -63,10 +63,10 @@ class ConvolutionTransposedBuffers : public NodeShader { ivec2 p0 = ($padding$ + $stride$ - gid.xy % $stride$) % $stride$; for (int y = p0.y; y < $kernel_size.y$; y += $stride.y$) { for (int x = p0.x; x < $kernel_size.x$; x += $stride.x$) { - - int i = int(float(y * $kernel_size.x$) + float(x)); + + int i = int(float(y * $kernel_size.x$) + float(x)); ivec2 idx = ivec2(vec2(gid.xy + ivec2(x, y)) - vec2($padding$)); - + if (IN_BOUNDS(idx, ivec2(0), ivec2($input_data_0_w$, $input_data_0_h$) * $stride$)) { ivec2 coord = idx / $stride$; for (int l = 0; l < $src_depth$; ++l) { @@ -94,7 +94,7 @@ class ConvolutionTransposedBuffers : public NodeShader { /*input=*/IOStructure::ONLY_DEFINITIONS, /*output=*/IOStructure::AUTO, }; - return absl::OkStatus(); + return OkStatus(); } }; diff --git a/tensorflow/lite/delegates/gpu/gl/node_shader.h b/tensorflow/lite/delegates/gpu/gl/node_shader.h index d98bdbf8914..38364656b7a 100644 --- a/tensorflow/lite/delegates/gpu/gl/node_shader.h +++ b/tensorflow/lite/delegates/gpu/gl/node_shader.h @@ -101,8 +101,8 @@ class NodeShader { }; // Generates shader code for a node. The code should be just a function body. 
- virtual absl::Status GenerateCode(const GenerationContext& ctx, - GeneratedCode* generated_code) const = 0; + virtual Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const = 0; // Limit the size of the const offsets array static constexpr int kMaxConstArraySize = 9; diff --git a/tensorflow/lite/delegates/gpu/gl/object_manager.cc b/tensorflow/lite/delegates/gpu/gl/object_manager.cc index c37be507b2b..4eca794a20a 100644 --- a/tensorflow/lite/delegates/gpu/gl/object_manager.cc +++ b/tensorflow/lite/delegates/gpu/gl/object_manager.cc @@ -24,22 +24,21 @@ namespace tflite { namespace gpu { namespace gl { -absl::Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, - GlBuffer* gl_buffer) { +Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, + GlBuffer* gl_buffer) { std::vector transposed(GetElementsSizeForPHWC4(tensor.shape)); RETURN_IF_ERROR( ConvertToPHWC4(tensor.data, tensor.shape, absl::MakeSpan(transposed))); return CreateReadOnlyShaderStorageBuffer(transposed, gl_buffer); } -absl::Status CreatePHWC4BufferFromTensorRef(const TensorRef& tensor_ref, - GlBuffer* gl_buffer) { +Status CreatePHWC4BufferFromTensorRef(const TensorRef& tensor_ref, + GlBuffer* gl_buffer) { return CreateReadWriteShaderStorageBuffer( GetElementsSizeForPHWC4(tensor_ref.shape), gl_buffer); } -absl::Status CopyFromPHWC4Buffer(const GlBuffer& buffer, - TensorFloat32* tensor) { +Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor) { return buffer.MappedRead( [tensor, &buffer](absl::Span data) { tensor->data.resize(tensor->shape.DimensionsProduct()); @@ -48,12 +47,12 @@ absl::Status CopyFromPHWC4Buffer(const GlBuffer& buffer, }); } -absl::Status ObjectManager::RegisterBuffer(uint32_t id, GlBuffer buffer) { +Status ObjectManager::RegisterBuffer(uint32_t id, GlBuffer buffer) { if (id >= buffers_.size()) { buffers_.resize(id + 1); } buffers_[id] = absl::make_unique(std::move(buffer)); - return absl::OkStatus(); + return OkStatus(); } void ObjectManager::RemoveBuffer(uint32_t id) { @@ -66,12 +65,12 @@ GlBuffer* ObjectManager::FindBuffer(uint32_t id) const { return id >= buffers_.size() ? nullptr : buffers_[id].get(); } -absl::Status ObjectManager::RegisterTexture(uint32_t id, GlTexture texture) { +Status ObjectManager::RegisterTexture(uint32_t id, GlTexture texture) { if (id >= textures_.size()) { textures_.resize(id + 1); } textures_[id] = absl::make_unique(std::move(texture)); - return absl::OkStatus(); + return OkStatus(); } void ObjectManager::RemoveTexture(uint32_t id) { diff --git a/tensorflow/lite/delegates/gpu/gl/object_manager.h b/tensorflow/lite/delegates/gpu/gl/object_manager.h index 0a7de28e1dc..8fa82871b50 100644 --- a/tensorflow/lite/delegates/gpu/gl/object_manager.h +++ b/tensorflow/lite/delegates/gpu/gl/object_manager.h @@ -41,7 +41,7 @@ namespace gl { class ObjectManager { public: // Moves ownership over the given buffer to the manager. - absl::Status RegisterBuffer(uint32_t id, GlBuffer buffer); + Status RegisterBuffer(uint32_t id, GlBuffer buffer); void RemoveBuffer(uint32_t id); @@ -49,7 +49,7 @@ class ObjectManager { GlBuffer* FindBuffer(uint32_t id) const; // Moves ownership over the given texture to the manager. - absl::Status RegisterTexture(uint32_t id, GlTexture texture); + Status RegisterTexture(uint32_t id, GlTexture texture); void RemoveTexture(uint32_t id); @@ -67,17 +67,17 @@ class ObjectManager { // Creates read-only buffer from the given tensor. Tensor data is converted to // PHWC4 layout. 
-absl::Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, - GlBuffer* gl_buffer); +Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, + GlBuffer* gl_buffer); // Creates read-write buffer for the given tensor shape, where data layout is // supposed to be PHWC4. -absl::Status CreatePHWC4BufferFromTensorRef(const TensorRef& tensor_ref, - GlBuffer* gl_buffer); +Status CreatePHWC4BufferFromTensorRef(const TensorRef& tensor_ref, + GlBuffer* gl_buffer); // Copies data from a buffer that holds data in PHWC4 layout to the given // tensor. -absl::Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor); +Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc index 0769a5014b4..7134fc010d0 100644 --- a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc @@ -28,7 +28,7 @@ namespace tflite { namespace gpu { namespace gl { -absl::Status RequestGpuInfo(GpuInfo* gpu_info) { +Status RequestGpuInfo(GpuInfo* gpu_info) { GpuInfo info; const GLubyte* renderer_name = glGetString(GL_RENDERER); @@ -73,7 +73,7 @@ absl::Status RequestGpuInfo(GpuInfo* gpu_info) { glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS, &info.max_array_texture_layers); RETURN_IF_ERROR(GetOpenGlErrors()); *gpu_info = info; - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h index f9d203e2325..4eba7a55c2a 100644 --- a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h +++ b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h @@ -28,7 +28,7 @@ namespace gl { // This method performs multiple GL calls, therefore, egl context needs to be // created upfront. 
-absl::Status RequestGpuInfo(GpuInfo* gpu_info); +Status RequestGpuInfo(GpuInfo* gpu_info); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.cc b/tensorflow/lite/delegates/gpu/gl/runtime.cc index 2a48b59c8d9..14e30389cf0 100644 --- a/tensorflow/lite/delegates/gpu/gl/runtime.cc +++ b/tensorflow/lite/delegates/gpu/gl/runtime.cc @@ -41,13 +41,13 @@ namespace gl { namespace { struct TextureF16Maker { - absl::Status operator()(const uint3& size) const { + Status operator()(const uint3& size) const { return CreateReadOnlyImageTextureF16(size, data, gl_texture); } - absl::Status operator()(const uint2& size) const { + Status operator()(const uint2& size) const { return CreateReadOnlyImageTextureF16(size, data, gl_texture); } - absl::Status operator()(const size_t& size) const { + Status operator()(const size_t& size) const { return CreateReadOnlyImageTextureF16(uint2(static_cast(size), 1U), data, gl_texture); } @@ -56,13 +56,13 @@ struct TextureF16Maker { }; struct TextureF32Maker { - absl::Status operator()(const uint3& size) const { + Status operator()(const uint3& size) const { return CreateReadOnlyImageTexture(size, data, gl_texture); } - absl::Status operator()(const uint2& size) const { + Status operator()(const uint2& size) const { return CreateReadOnlyImageTexture(size, data, gl_texture); } - absl::Status operator()(const size_t& size) const { + Status operator()(const size_t& size) const { return CreateReadOnlyImageTexture(uint2(static_cast(size), 1U), data, gl_texture); } @@ -70,21 +70,20 @@ struct TextureF32Maker { GlTexture* gl_texture; }; -absl::Status MakeGlTexture(const Object& object, const ObjectData& data, - GlTexture* gl_texture) { +Status MakeGlTexture(const Object& object, const ObjectData& data, + GlTexture* gl_texture) { if (object.access == AccessType::READ_WRITE || object.access == AccessType::WRITE) { - return absl::InvalidArgumentError("Read-write textures are not supported"); + return InvalidArgumentError("Read-write textures are not supported"); } if (object.data_type != DataType::FLOAT16 && object.data_type != DataType::FLOAT32) { - return absl::InvalidArgumentError( - "Textures support float16 or float32 only."); + return InvalidArgumentError("Textures support float16 or float32 only."); } switch (object.data_type) { case DataType::FLOAT16: { if (data.size() % 2 != 0) { - return absl::InvalidArgumentError("Texture size is not aligned"); + return InvalidArgumentError("Texture size is not aligned"); } return absl::visit( TextureF16Maker{ @@ -97,7 +96,7 @@ absl::Status MakeGlTexture(const Object& object, const ObjectData& data, } case DataType::FLOAT32: { if (data.size() % sizeof(float) != 0) { - return absl::InvalidArgumentError("Texture size is not aligned"); + return InvalidArgumentError("Texture size is not aligned"); } return absl::visit( TextureF32Maker{ @@ -109,18 +108,18 @@ absl::Status MakeGlTexture(const Object& object, const ObjectData& data, object.size); } default: - return absl::InvalidArgumentError("Unsupported textures data type."); + return InvalidArgumentError("Unsupported textures data type."); } } struct TextureRefMaker { - absl::Status operator()(const uint3& size) const { + Status operator()(const uint3& size) const { return CreateReadWriteRgbaImageTexture(type, size, gl_texture); } - absl::Status operator()(const uint2& size) const { + Status operator()(const uint2& size) const { return CreateReadWriteRgbaImageTexture(type, size, gl_texture); } - absl::Status operator()(const size_t& size) const { + 
Status operator()(const size_t& size) const { return CreateReadWriteRgbaImageTexture( type, uint2(static_cast(size), 1U), gl_texture); } @@ -129,38 +128,37 @@ struct TextureRefMaker { }; // Makes read-write gl texture -absl::Status MakeGlTextureRef(const Object& object, GlTexture* gl_texture) { +Status MakeGlTextureRef(const Object& object, GlTexture* gl_texture) { return absl::visit(TextureRefMaker{object.data_type, gl_texture}, object.size); } -absl::Status MakeGlBuffer(const Object& object, const ObjectData& data, - GlBuffer* gl_buffer) { +Status MakeGlBuffer(const Object& object, const ObjectData& data, + GlBuffer* gl_buffer) { if (data.size() % SizeOf(object.data_type) != 0) { - return absl::InvalidArgumentError("Buffer size is not aligned"); + return InvalidArgumentError("Buffer size is not aligned"); } return CreateReadOnlyShaderStorageBuffer(absl::MakeConstSpan(data), gl_buffer); } // Looks up an object with the given id. If found, makes a binding function. -absl::Status MakeBindingFunc(const Object& object, uint32_t id, - const ObjectManager& objects, - std::function* binding_func) { +Status MakeBindingFunc(const Object& object, uint32_t id, + const ObjectManager& objects, + std::function* binding_func) { const uint32_t binding = object.binding; switch (object.object_type) { case ObjectType::BUFFER: { auto ptr = objects.FindBuffer(id); if (!ptr) { - return absl::NotFoundError( - absl::StrCat("Buffer ", id, " is not found")); + return NotFoundError(absl::StrCat("Buffer ", id, " is not found")); } // Validate buffer. size_t size_in_bytes = ByteSizeOf(object); // TODO(akulik): make comparison != instead of < if (ptr->bytes_size() < size_in_bytes) { - return absl::FailedPreconditionError( + return FailedPreconditionError( absl::StrCat("Buffer ", id, " size in bytes ", ptr->bytes_size(), " < requested size_in_bytes ", size_in_bytes)); } @@ -170,16 +168,15 @@ absl::Status MakeBindingFunc(const Object& object, uint32_t id, case ObjectType::TEXTURE: { auto ptr = objects.FindTexture(id); if (!ptr) { - return absl::NotFoundError( - absl::StrCat("Texture ", id, " is not found")); + return NotFoundError(absl::StrCat("Texture ", id, " is not found")); } *binding_func = [=]() { return ptr->BindAsReadWriteImage(binding); }; break; } case ObjectType::UNKNOWN: - return absl::InvalidArgumentError("Unknown object type"); + return InvalidArgumentError("Unknown object type"); } - return absl::OkStatus(); + return OkStatus(); } } // namespace @@ -197,10 +194,10 @@ Runtime::Runtime(const RuntimeOptions& options, const GpuInfo& gpu_info, } } -absl::Status Runtime::AddProgram(const GlShader& shader, - const std::vector& parameters, - const std::vector& objects, - const uint3& num_workgroups) { +Status Runtime::AddProgram(const GlShader& shader, + const std::vector& parameters, + const std::vector& objects, + const uint3& num_workgroups) { GlProgram program; RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); @@ -220,10 +217,10 @@ absl::Status Runtime::AddProgram(const GlShader& shader, // Reference object could be provided externally as a model input/output // but also for debugging purposes. Otherwise all references are collected // and allocated later. 
- absl::Status status = MakeBindingFunc(object, GetRef(object), - *external_objects_, &binding_func); + Status status = MakeBindingFunc(object, GetRef(object), + *external_objects_, &binding_func); if (!status.ok()) { - if (absl::IsNotFound(status)) { + if (status.code() == StatusCode::kNotFound) { program.refs.push_back(object); continue; // don't add to binding. } @@ -241,10 +238,10 @@ absl::Status Runtime::AddProgram(const GlShader& shader, // All parameters once set stay with program, therefore, we only need to keep // program and bindings for execution. - return absl::OkStatus(); + return OkStatus(); } -absl::Status Runtime::AllocateInternalObject(const Object& object) { +Status Runtime::AllocateInternalObject(const Object& object) { const ObjectRef ref = GetRef(object); switch (object.object_type) { case ObjectType::BUFFER: { @@ -263,16 +260,15 @@ absl::Status Runtime::AllocateInternalObject(const Object& object) { break; } default: - return absl::InternalError("Unexpected internal object type"); + return InternalError("Unexpected internal object type"); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Runtime::AllocateConstObject(const Object& object, uint32_t* id) { +Status Runtime::AllocateConstObject(const Object& object, uint32_t* id) { const ObjectData* data = GetData(object); if (data == nullptr) { - return absl::InternalError( - "Unable to allocate reference as a const object"); + return InternalError("Unable to allocate reference as a const object"); } *id = next_const_id_++; switch (object.object_type) { @@ -293,12 +289,12 @@ absl::Status Runtime::AllocateConstObject(const Object& object, uint32_t* id) { break; } case ObjectType::UNKNOWN: - return absl::InternalError("Unknown object type"); + return InternalError("Unknown object type"); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Runtime::PrepareForExecution() { +Status Runtime::PrepareForExecution() { if (shared_readonly_buffer_ && !shared_readonly_buffer_->empty()) { GlBuffer shared_buffer; RETURN_IF_ERROR( @@ -324,10 +320,11 @@ absl::Status Runtime::PrepareForExecution() { // Check whether it is created already. BindFunc binding; ObjectRef ref = GetRef(object); - absl::Status status = - MakeBindingFunc(object, ref, internal_objects_, &binding); + Status status = MakeBindingFunc(object, ref, internal_objects_, &binding); if (!status.ok()) { - if (absl::IsNotFound(status)) return status; + if (status.code() != StatusCode::kNotFound) { + return status; + } RETURN_IF_ERROR(AllocateInternalObject(object)); RETURN_IF_ERROR( MakeBindingFunc(object, ref, internal_objects_, &binding)); @@ -336,7 +333,7 @@ absl::Status Runtime::PrepareForExecution() { } program.refs.clear(); } - return absl::OkStatus(); + return OkStatus(); } namespace { @@ -402,8 +399,8 @@ struct AddUsageRecordForTextureFunc { // We assume that AddUsageRecord for different objects is called in order of // program_id. 
-absl::Status AddUsageRecord(CombinedUsageRecords* usage_records, - const Object& object, const size_t program_id) { +Status AddUsageRecord(CombinedUsageRecords* usage_records, const Object& object, + const size_t program_id) { auto ref = GetRef(object); if (ref >= usage_records->usage_refs.size()) { usage_records->usage_refs.resize(ref + 1, kNotAssigned); @@ -419,17 +416,17 @@ absl::Status AddUsageRecord(CombinedUsageRecords* usage_records, } else { UpdateUsageRecord(&usage_records->buffers[usage_ref], program_id); } - return absl::OkStatus(); + return OkStatus(); } if (object.object_type == ObjectType::TEXTURE) { absl::visit(AddUsageRecordForTextureFunc{usage_records, ref, program_id}, object.size); - return absl::OkStatus(); + return OkStatus(); } - return absl::InternalError("Unexpected object type"); + return InternalError("Unexpected object type"); } -absl::Status ApplyBuffersAssignment( +Status ApplyBuffersAssignment( const ObjectsAssignment& assignment, const std::vector& global_ref_to_usage_rec, const std::vector& global_ref_to_object_ptr, @@ -465,11 +462,11 @@ absl::Status ApplyBuffersAssignment( } (*global_ref_to_shared_ref)[global_ref] = shared_ref; } - return absl::OkStatus(); + return OkStatus(); } template -absl::Status ApplyTexturesAssignment( +Status ApplyTexturesAssignment( const ObjectsAssignment& assignment, const std::vector& global_ref_to_usage_rec, const std::vector& global_ref_to_object_ptr, @@ -507,7 +504,7 @@ absl::Status ApplyTexturesAssignment( } (*global_ref_to_shared_ref)[global_ref] = shared_ref; } - return absl::OkStatus(); + return OkStatus(); } } // namespace @@ -515,8 +512,7 @@ absl::Status ApplyTexturesAssignment( // Assign shared objects to internal objects, using memory allocation // algorithms. Usage records for the algorithms are calculated separately for // each data type and object type. -absl::Status Runtime::AssignInternalObjects( - std::vector* shared_objects) { +Status Runtime::AssignInternalObjects(std::vector* shared_objects) { // Build tensor usage records, clusterized by object type and data type. std::map usage_records_by_data_type; std::vector global_ref_to_object_ptr; @@ -583,10 +579,10 @@ absl::Status Runtime::AssignInternalObjects( object.object = global_ref_to_shared_ref[GetRef(object)]; } } - return absl::OkStatus(); + return OkStatus(); } -absl::Status Runtime::Execute() { +Status Runtime::Execute() { for (const auto& descriptor : programs_) { for (auto& b : descriptor.bindings) { RETURN_IF_ERROR(b()); @@ -594,7 +590,7 @@ absl::Status Runtime::Execute() { RETURN_IF_ERROR(command_queue_->Dispatch(descriptor.program, descriptor.num_workgroups)); } - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.h b/tensorflow/lite/delegates/gpu/gl/runtime.h index 97f0f732834..b66a7fdfaa4 100644 --- a/tensorflow/lite/delegates/gpu/gl/runtime.h +++ b/tensorflow/lite/delegates/gpu/gl/runtime.h @@ -44,17 +44,17 @@ class Runtime { CommandQueue* command_queue, const ObjectManager* external_objects); // Takes parameters and objects and prepares GL program. - absl::Status AddProgram(const GlShader& shader, - const std::vector& parameters, - const std::vector& objects, - const uint3& num_workgroups); + Status AddProgram(const GlShader& shader, + const std::vector& parameters, + const std::vector& objects, + const uint3& num_workgroups); // Needs to be called once all programs and shaders has been added to runtime. 
- absl::Status PrepareForExecution(); + Status PrepareForExecution(); // Executes all compiled programs. // TODO(akulik): add more controls over execution. Execution policy? - absl::Status Execute(); + Status Execute(); // Gets access to objects created while executing generated code. const ObjectManager* internal_objects() const { return &internal_objects_; } @@ -72,14 +72,14 @@ class Runtime { } private: - absl::Status AllocateInternalObject(const Object& object); + Status AllocateInternalObject(const Object& object); - absl::Status AllocateConstObject(const Object& object, uint32_t* id); + Status AllocateConstObject(const Object& object, uint32_t* id); // Goes over objects in programs and decides how to allocate them to // minimize total allocated memory. Returns a collection of objects to be // allocated and shared by internal objects. - absl::Status AssignInternalObjects(std::vector* objects); + Status AssignInternalObjects(std::vector* objects); const RuntimeOptions options_; const GpuInfo gpu_info_; @@ -92,7 +92,7 @@ class Runtime { std::unique_ptr shared_readonly_buffer_; - using BindFunc = std::function; + using BindFunc = std::function; // Encapsulates a program and all object to bind before dispatch. struct CompiledProgramDescriptor { diff --git a/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h b/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h index 11b094637f2..d4f49d1952c 100644 --- a/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h +++ b/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h @@ -55,7 +55,7 @@ class SharedBufferData { bool empty() const { return shared_data_.empty(); } // Returns a single GlBuffer that owns entire shared data. - absl::Status CreateSharedGlBuffer(GlBuffer* gl_buffer) { + Status CreateSharedGlBuffer(GlBuffer* gl_buffer) { // Upload data to a buffer gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, buffer_id_.id()); @@ -64,7 +64,7 @@ class SharedBufferData { GL_STATIC_READ)); *gl_buffer = GlBuffer(GL_SHADER_STORAGE_BUFFER, buffer_id_.Release(), shared_data_.size(), 0, /*has_ownership=*/true); - return absl::OkStatus(); + return OkStatus(); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/serialization.cc b/tensorflow/lite/delegates/gpu/gl/serialization.cc index 7e15cf2d271..17db339fa98 100644 --- a/tensorflow/lite/delegates/gpu/gl/serialization.cc +++ b/tensorflow/lite/delegates/gpu/gl/serialization.cc @@ -390,15 +390,15 @@ absl::Span SerializedCompiledModelBuilder::Finalize( namespace { -absl::Status ParseParameter(const data::UniformParameter& fb_parameter, - Variable* parameter) { +Status ParseParameter(const data::UniformParameter& fb_parameter, + Variable* parameter) { parameter->name = fb_parameter.name()->str(); switch (fb_parameter.type()) { case data::ParameterType::INT32: { auto* ptr = fb_parameter.data_as_DataInt32(); if (ptr == nullptr) { - return absl::InvalidArgumentError("Unexpected data type '" + - parameter->name + "'"); + return InvalidArgumentError("Unexpected data type '" + parameter->name + + "'"); } switch (ptr->data()->size()) { case 1: @@ -412,16 +412,16 @@ absl::Status ParseParameter(const data::UniformParameter& fb_parameter, (*ptr->data())[2], (*ptr->data())[3]); break; default: - return absl::InvalidArgumentError("Unexpected size for parameter '" + - parameter->name + "'"); + return InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); } break; } case data::ParameterType::UINT32: { auto* ptr = fb_parameter.data_as_DataUint32(); if (ptr == 
nullptr) { - return absl::InvalidArgumentError("Unexpected data type '" + - parameter->name + "'"); + return InvalidArgumentError("Unexpected data type '" + parameter->name + + "'"); } switch (ptr->data()->size()) { case 1: @@ -432,16 +432,16 @@ absl::Status ParseParameter(const data::UniformParameter& fb_parameter, (*ptr->data())[2], (*ptr->data())[3]); break; default: - return absl::InvalidArgumentError("Unexpected size for parameter '" + - parameter->name + "'"); + return InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); } break; } case data::ParameterType::FLOAT32: { auto* ptr = fb_parameter.data_as_DataFloat(); if (ptr == nullptr) { - return absl::InvalidArgumentError("Unexpected data type '" + - parameter->name + "'"); + return InvalidArgumentError("Unexpected data type '" + parameter->name + + "'"); } switch (ptr->data()->size()) { case 1: @@ -455,21 +455,21 @@ absl::Status ParseParameter(const data::UniformParameter& fb_parameter, (*ptr->data())[2], (*ptr->data())[3]); break; default: - return absl::InvalidArgumentError("Unexpected size for parameter '" + - parameter->name + "'"); + return InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); } break; } case data::ParameterType::INT32_2: { auto* ptr = fb_parameter.data_as_DataInt32(); if (ptr == nullptr) { - return absl::InvalidArgumentError("Unexpected data type '" + - parameter->name + "'"); + return InvalidArgumentError("Unexpected data type '" + parameter->name + + "'"); } if (ptr->data()->size() % 2 != 0) { - return absl::InvalidArgumentError("Unexpected size for parameter '" + - parameter->name + "'"); + return InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); } std::vector values(ptr->data()->size() / 2); @@ -480,7 +480,7 @@ absl::Status ParseParameter(const data::UniformParameter& fb_parameter, break; } } - return absl::OkStatus(); + return OkStatus(); } DataType ToEnum(data::DataType type) { @@ -520,7 +520,7 @@ AccessType ToEnum(data::AccessType type) { } } -absl::Status ParseObject(const data::Object& fb_object, Object* object) { +Status ParseObject(const data::Object& fb_object, Object* object) { object->access = ToEnum(fb_object.access()); object->binding = fb_object.binding(); object->object_type = ToEnum(fb_object.type()); @@ -543,7 +543,7 @@ absl::Status ParseObject(const data::Object& fb_object, Object* object) { break; } case data::ObjectSize::NONE: - return absl::InvalidArgumentError("Texture size is not set"); + return InvalidArgumentError("Texture size is not set"); } switch (fb_object.object_type()) { @@ -560,10 +560,10 @@ absl::Status ParseObject(const data::Object& fb_object, Object* object) { break; } case data::ObjectVariant::NONE: { - return absl::InvalidArgumentError("Object is not set"); + return InvalidArgumentError("Object is not set"); } } - return absl::OkStatus(); + return OkStatus(); } CompiledModelOptions ParseParameters(const data::Parameters& fb_parameters) { @@ -574,11 +574,11 @@ CompiledModelOptions ParseParameters(const data::Parameters& fb_parameters) { } // namespace -absl::Status DeserializeCompiledModel(absl::Span serialized, - DeserializationHandler* handler) { +Status DeserializeCompiledModel(absl::Span serialized, + DeserializationHandler* handler) { flatbuffers::Verifier verifier(serialized.data(), serialized.size()); if (!data::VerifyCompiledModelBuffer(verifier)) { - return absl::InvalidArgumentError("Serialized model is corrupted."); + return InvalidArgumentError("Serialized model is 
corrupted."); } auto model = data::GetCompiledModel(serialized.data()); @@ -612,7 +612,7 @@ absl::Status DeserializeCompiledModel(absl::Span serialized, program->shader_index())); } handler->OnOptions(ParseParameters(*model->parameters())); - return absl::OkStatus(); + return OkStatus(); } } // namespace gl diff --git a/tensorflow/lite/delegates/gpu/gl/serialization.h b/tensorflow/lite/delegates/gpu/gl/serialization.h index 82b76a475f5..c3c88b4c462 100644 --- a/tensorflow/lite/delegates/gpu/gl/serialization.h +++ b/tensorflow/lite/delegates/gpu/gl/serialization.h @@ -67,19 +67,19 @@ class DeserializationHandler { public: virtual ~DeserializationHandler() = default; - virtual absl::Status OnShader(absl::Span shader_src) = 0; + virtual Status OnShader(absl::Span shader_src) = 0; - virtual absl::Status OnProgram(const std::vector& parameters, - const std::vector& objects, - const uint3& workgroup_size, - const uint3& num_workgroups, - size_t shader_index) = 0; + virtual Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, + const uint3& num_workgroups, + size_t shader_index) = 0; virtual void OnOptions(const CompiledModelOptions& options) = 0; }; -absl::Status DeserializeCompiledModel(absl::Span serialized, - DeserializationHandler* handler); +Status DeserializeCompiledModel(absl::Span serialized, + DeserializationHandler* handler); } // namespace gl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/serialization_test.cc b/tensorflow/lite/delegates/gpu/gl/serialization_test.cc index 37c08129139..25aa9be73b2 100644 --- a/tensorflow/lite/delegates/gpu/gl/serialization_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/serialization_test.cc @@ -45,19 +45,18 @@ struct ProgramDesc { }; struct Handler : public DeserializationHandler { - absl::Status OnShader(absl::Span shader_src) final { + Status OnShader(absl::Span shader_src) final { shaders.push_back(std::string(shader_src.data(), shader_src.size())); - return absl::OkStatus(); + return OkStatus(); } - absl::Status OnProgram(const std::vector& parameters, - const std::vector& objects, - const uint3& workgroup_size, - const uint3& num_workgroups, - size_t shader_index) final { + Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, const uint3& num_workgroups, + size_t shader_index) final { programs.push_back( {parameters, objects, workgroup_size, num_workgroups, shader_index}); - return absl::OkStatus(); + return OkStatus(); } void OnOptions(const CompiledModelOptions& o) final { options = o; } diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc index 5ebefb4a6eb..16aaafa5c94 100644 --- a/tensorflow/lite/delegates/gpu/gl_delegate.cc +++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc @@ -93,8 +93,7 @@ class Delegate { } } - absl::Status CopyFromBufferHandle(TfLiteBufferHandle handle, - TfLiteTensor* tensor) { + Status CopyFromBufferHandle(TfLiteBufferHandle handle, TfLiteTensor* tensor) { ValueRef ref; RETURN_IF_ERROR(FindObject(handle, &ref)); auto buffer = phwc4_objects_.FindBuffer(handle); @@ -106,8 +105,8 @@ class Delegate { }); } - absl::Status CopyToBufferHandle(TfLiteBufferHandle handle, - TfLiteTensor* tensor) const { + Status CopyToBufferHandle(TfLiteBufferHandle handle, + TfLiteTensor* tensor) const { ValueRef ref; RETURN_IF_ERROR(FindObject(handle, &ref)); auto buffer = phwc4_objects_.FindBuffer(handle); @@ -118,7 +117,7 @@ class Delegate { }); } - absl::Status 
BindBufferToTensor(GLuint ssbo, int tensor_index) { + Status BindBufferToTensor(GLuint ssbo, int tensor_index) { int64_t bytes_size; RETURN_IF_ERROR(GetSSBOSize(ssbo, &bytes_size)); return bhwc_objects_.RegisterBuffer( @@ -127,8 +126,8 @@ class Delegate { /* has_ownership = */ false)); } - absl::Status Prepare(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params) { + Status Prepare(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { // Extract TFLite delegate execution plan from the context and convert it // into FlowGraph32. GraphFloat32 graph; @@ -138,7 +137,7 @@ class Delegate { NullTransformationReporter reporter; ModelTransformer transformer(&graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + return InternalError("Graph general transformations failed"); } if (!env_) RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env_)); @@ -177,7 +176,7 @@ class Delegate { tflite_graph_io.insert(tensor_index); const auto* input = find_value(tensor_index); if (!input || tensor->type != TfLiteType::kTfLiteFloat32) { - return absl::NotFoundError("Input tensor is not found in the graph."); + return NotFoundError("Input tensor is not found in the graph."); } inputs_.push_back(input->id); @@ -216,8 +215,7 @@ class Delegate { tflite_graph_io.insert(tensor_index); const auto* output = find_value(tensor_index); if (!output || tensor->type != TfLiteType::kTfLiteFloat32) { - return absl::NotFoundError( - "Output tensor is not found in the graph."); + return NotFoundError("Output tensor is not found in the graph."); } outputs_.push_back(output->id); @@ -272,14 +270,14 @@ class Delegate { RETURN_IF_ERROR(compiled_model->NewRun(runtime_options, &phwc4_objects_, command_queue_.get(), &inference_context_)); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Invoke(TfLiteContext* context) { + Status Invoke(TfLiteContext* context) { const EGLContext egl_context_at_delegate_init = env_->context().context(); const EGLContext egl_context_at_delegate_invoke = eglGetCurrentContext(); if (egl_context_at_delegate_init != egl_context_at_delegate_invoke) { - return absl::FailedPreconditionError( + return FailedPreconditionError( "Delegate should run on the same thread where it was initialized."); } @@ -332,18 +330,18 @@ class Delegate { RETURN_IF_ERROR(CopyFromBufferHandle(id, &tensor)); } } - return absl::OkStatus(); + return OkStatus(); } TfLiteDelegate* tflite_delegate() { return &delegate_; } private: - absl::Status FindObject(ValueId id, ValueRef* ref) const { + Status FindObject(ValueId id, ValueRef* ref) const { if (id >= tensors_.size()) { - return absl::InvalidArgumentError("Invalid buffer id"); + return InvalidArgumentError("Invalid buffer id"); } *ref = tensors_[id]; - return absl::OkStatus(); + return OkStatus(); } TfLiteDelegate delegate_ = { @@ -389,7 +387,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = gpu_delegate->Prepare(context, params); if (status.ok()) return gpu_delegate; context->ReportError(context, "TfLiteGpuDelegate Prepare: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return nullptr; }, // .free @@ -403,7 +401,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = GetGpuDelegate(node)->Invoke(context); if (status.ok()) return kTfLiteOk; context->ReportError(context, "TfLiteGpuDelegate Invoke: %s", - 
std::string(status.message()).c_str()); + status.error_message().c_str()); return kTfLiteError; }, nullptr, // .profiling_string @@ -427,7 +425,7 @@ TfLiteStatus DelegateCopyFromBufferHandle(TfLiteContext* context, const auto status = gpu_delegate->CopyFromBufferHandle(buffer_handle, tensor); if (status.ok()) return kTfLiteOk; context->ReportError(context, "TfLiteGpuDelegate CopyFromBufferHandle: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return kTfLiteError; } @@ -440,7 +438,7 @@ TfLiteStatus DelegateCopyToBufferHandle(TfLiteContext* context, const auto status = gpu_delegate->CopyToBufferHandle(buffer_handle, tensor); if (status.ok()) return kTfLiteOk; context->ReportError(context, "TfLiteGpuDelegate CopyToBufferHandle: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return kTfLiteError; } diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index dedb2aa8df1..6abcbcaed4f 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -141,14 +141,13 @@ std::vector SelectSpaceToDepth( return SpaceToDepth(id, input_id, output_id, attr); } -absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, - const std::vector& inputs, - const std::vector& outputs, - const RuntimeOptions& options, - std::vector* tasks) { +Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, + const std::vector& inputs, + const std::vector& outputs, + const RuntimeOptions& options, + std::vector* tasks) { if (!IsBatchMatchesForAllValues(graph)) { - return absl::InvalidArgumentError( - "Only identical batch dimension is supported"); + return InvalidArgumentError("Only identical batch dimension is supported"); } int node_id = static_cast(node->id); auto op_type = OperationTypeFromString(node->operation.type); @@ -241,7 +240,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, case OperationType::PAD: { auto attr = absl::any_cast(node->operation.attributes); if (attr.appended.b != 0 || attr.prepended.b != 0) { - return absl::UnimplementedError("Padding for BATCH is not supported."); + return UnimplementedError("Padding for BATCH is not supported."); } *tasks = Padding(node_id, inputs[0], outputs[0], attr); break; @@ -278,8 +277,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, case OperationType::SOFTMAX: { auto attr = absl::any_cast(node->operation.attributes); if (attr.axis != Axis::CHANNELS) { - return absl::UnimplementedError( - "Softmax supports only CHANNELS dimension"); + return UnimplementedError("Softmax supports only CHANNELS dimension"); } *tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]); break; @@ -331,16 +329,15 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, case OperationType::SPACE_TO_BATCH: case OperationType::TRANSPOSE: case OperationType::UNKNOWN: - return absl::UnimplementedError("Unsupported op: " + - node->operation.type); + return UnimplementedError("Unsupported op: " + node->operation.type); } - return absl::OkStatus(); + return OkStatus(); } } // namespace -absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, - CompiledModel* compiled_model) { +Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, + CompiledModel* compiled_model) { for (const auto& node : graph.nodes()) { std::vector inputs; for (auto& input : graph.FindInputs(node->id)) { @@ -357,11 
+354,11 @@ absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, auto primary_status = RegisterPrimaryOps(graph, node, inputs, outputs, options, &tasks); if (!primary_status.ok()) { - return absl::UnimplementedError( - absl::Substitute("Unsupported op type: $0; custom registry error: " - "$1; primary registry error: $2;", - node->operation.type, custom_status.message(), - primary_status.message())); + return UnimplementedError(absl::Substitute( + "Unsupported op type: $0; custom registry error: " + "$1; primary registry error: $2;", + node->operation.type, custom_status.error_message(), + primary_status.error_message())); } } for (auto task : tasks) { @@ -369,7 +366,7 @@ absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, } compiled_model->insert(compiled_model->end(), tasks.begin(), tasks.end()); } - return absl::OkStatus(); + return OkStatus(); } } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/api.h b/tensorflow/lite/delegates/gpu/metal/api.h index c1c7648638c..dd3c423a612 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.h +++ b/tensorflow/lite/delegates/gpu/metal/api.h @@ -26,8 +26,8 @@ namespace gpu { namespace metal { // Builds CompiledModel out of GraphFloat32 graph using provided RuntimeOptions. -absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, - CompiledModel* compiled_model); +Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, + CompiledModel* compiled_model); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/common.h b/tensorflow/lite/delegates/gpu/metal/common.h index 6f4e94ed2e7..9d7d66176f6 100644 --- a/tensorflow/lite/delegates/gpu/metal/common.h +++ b/tensorflow/lite/delegates/gpu/metal/common.h @@ -39,9 +39,10 @@ id GetBestSupportedMetalDevice(); /// both. /// @discussion The function autoselects the maximum shader language version supported by the target /// OS. FastMath is enabled. -absl::Status CreateComputeProgram(id device, NSString* code, NSString* functionName, - NSDictionary* macros, - id* program); +::tflite::gpu::Status CreateComputeProgram(id device, NSString* code, + NSString* functionName, + NSDictionary* macros, + id* program); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/common.mm b/tensorflow/lite/delegates/gpu/metal/common.mm index cc5a98dfffc..7167430a343 100644 --- a/tensorflow/lite/delegates/gpu/metal/common.mm +++ b/tensorflow/lite/delegates/gpu/metal/common.mm @@ -34,9 +34,9 @@ namespace metal { id GetBestSupportedMetalDevice() { return MTLCreateSystemDefaultDevice(); } -absl::Status CreateComputeProgram(id device, NSString* code, NSString* functionName, - NSDictionary* macros, - id* program) { +Status CreateComputeProgram(id device, NSString* code, NSString* functionName, + NSDictionary* macros, + id* program) { MTLCompileOptions* options = [[MTLCompileOptions alloc] init]; // Runtime checks for the iOS version independently of minimum target iOS. 
@@ -70,14 +70,14 @@ absl::Status CreateComputeProgram(id device, NSString* code, NSString if (!library) { NSString* errorString = [NSString stringWithFormat:@"newLibraryWithSource: %@", [error localizedDescription]]; - return absl::InternalError([errorString UTF8String]); + return InternalError([errorString UTF8String]); } id function = [library newFunctionWithName:functionName]; if (!function) { NSString* errorString = [NSString stringWithFormat:@"newFunctionWithName: %@", [error localizedDescription]]; - return absl::InternalError([errorString UTF8String]); + return InternalError([errorString UTF8String]); } *program = [device newComputePipelineStateWithFunction:function error:&error]; @@ -85,9 +85,9 @@ absl::Status CreateComputeProgram(id device, NSString* code, NSString NSString* errorString = [NSString stringWithFormat:@"newComputePipelineStateWithFunction error: %@", [error localizedDescription]]; - return absl::InternalError([errorString UTF8String]); + return InternalError([errorString UTF8String]); } - return absl::OkStatus(); + return OkStatus(); } } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/common_test.mm b/tensorflow/lite/delegates/gpu/metal/common_test.mm index 7cedac0f799..18a495ebd18 100644 --- a/tensorflow/lite/delegates/gpu/metal/common_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/common_test.mm @@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include @@ -26,6 +25,7 @@ limitations under the License. using ::tflite::gpu::metal::GetBestSupportedMetalDevice; using ::tflite::gpu::metal::CreateComputeProgram; +using ::tflite::gpu::Status; @interface CommonTest : XCTestCase @@ -53,16 +53,16 @@ kernel void FunctionName(device TYPE* const src_buffer[[buffer(0)]], XCTAssertNotNil(device, @"The Metal device must exists on real device"); NSString* functionName = @"FunctionName"; id program; - absl::Status status; + Status status; NSDictionary* macrosFloat4 = @{@"TYPE" : @"float4"}; status = CreateComputeProgram(device, code, functionName, macrosFloat4, &program); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); XCTAssertNotNil(program); NSDictionary* macrosHalf4 = @{@"TYPE" : @"half4"}; status = CreateComputeProgram(device, code, functionName, macrosHalf4, &program); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); XCTAssertNotNil(program); // This compilation is intended to be incorrect diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc index 06cc10a0520..5545ad39161 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc @@ -564,10 +564,10 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { } // namespace -absl::Status ValidateOptimizeModel(const std::vector& input_buffers, - const std::vector& output_buffers, - const CompiledModel& input_vector, - CompiledModel* output) { +Status ValidateOptimizeModel(const std::vector& input_buffers, + const std::vector& output_buffers, + const CompiledModel& input_vector, + CompiledModel* output) { std::list input; input.insert(input.end(), input_vector.begin(), input_vector.end()); OptimizationInfo info; @@ -606,10 +606,10 @@ absl::Status ValidateOptimizeModel(const std::vector& input_buffers,
std::to_string(info.unused_input_buffer_ids.size()) + "\nMissing output buffers " + std::to_string(info.missing_output_buffer_ids.size()); - return absl::InternalError(message); + return InternalError(message); } for (const auto& chain : sorted_chains) output->push_back(FuseChain(chain)); - return absl::OkStatus(); + return OkStatus(); } } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.h b/tensorflow/lite/delegates/gpu/metal/compiled_model.h index 222534402d9..5f9982d0a66 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model.h +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.h @@ -31,10 +31,9 @@ using CompiledModel = std::vector; // Receives input CompiledModel, validates, optimizes it and returns output // CompiledModel. No shader compilation or memory allocation happen here, this // function just does high-level operations fusion. -absl::Status ValidateOptimizeModel(const std::vector& input_buffers, - const std::vector& output_buffers, - const CompiledModel& input, - CompiledModel* output); +Status ValidateOptimizeModel(const std::vector& input_buffers, + const std::vector& output_buffers, + const CompiledModel& input, CompiledModel* output); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm b/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm index 83870123321..59827ce2c08 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm @@ -183,7 +183,7 @@ static std::vector Add2Linkable(int id, ValueId input_ auto nodes = MulLinkable(1, 1, 2); std::vector model; auto status = ValidateOptimizeModel({1}, {2}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } // Outputs: one missing, one unused. @@ -195,8 +195,8 @@ static std::vector Add2Linkable(int id, ValueId input_ std::vector errorMessages = {"Input operations count 1", "Unused operations 1", "Unused inputs 1", "Missing output buffers 1"}; for (const std::string& message : errorMessages) { - bool doesContainMessage = std::string(status.message()).find(message) != std::string::npos; - XCTAssertTrue(doesContainMessage, @"%s", std::string(status.message()).c_str()); + bool doesContainMessage = status.error_message().find(message) != std::string::npos; + XCTAssertTrue(doesContainMessage, @"%s", status.error_message().c_str()); } } @@ -205,7 +205,7 @@ static std::vector Add2Linkable(int id, ValueId input_ auto nodes = MulLinkable(1, 1, 2); std::vector model; auto status = ValidateOptimizeModel({1}, {2, 3}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } // Unused input => empty graph, missing output. 
@@ -216,8 +216,8 @@ static std::vector Add2Linkable(int id, ValueId input_ std::vector errorMessages = {"Input operations count 1", "Unused operations 0", "Unused inputs 1", "Missing output buffers 1"}; for (const std::string& message : errorMessages) { - bool doesContainMessage = std::string(status.message()).find(message) != std::string::npos; - XCTAssertTrue(doesContainMessage, @"%s", std::string(status.message()).c_str()); + bool doesContainMessage = status.error_message().find(message) != std::string::npos; + XCTAssertTrue(doesContainMessage, @"%s", status.error_message().c_str()); } } @@ -228,7 +228,7 @@ static std::vector Add2Linkable(int id, ValueId input_ nodes.insert(nodes.end(), nodes2.begin(), nodes2.end()); std::vector model; auto status = ValidateOptimizeModel({1}, {3}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } // Two sequential operations. Not fused. @@ -238,14 +238,14 @@ static std::vector Add2Linkable(int id, ValueId input_ nodes.insert(nodes.end(), nodes2.begin(), nodes2.end()); std::vector model; auto status = ValidateOptimizeModel({1}, {3}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testAddOperationSuccess { auto nodes = Add2(1, 1, 2, 3); std::vector model; auto status = ValidateOptimizeModel({1, 2}, {3}, nodes, &model); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testAddOperationFused { @@ -254,7 +254,7 @@ static std::vector Add2Linkable(int id, ValueId input_ graph.insert(graph.end(), graph2.begin(), graph2.end()); std::vector model; auto status = ValidateOptimizeModel({1, 2}, {4}, graph, &model); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); XCTAssertTrue(model.size() == 1, @"Not fused, more than one task descriptor."); } @@ -266,7 +266,7 @@ static std::vector Add2Linkable(int id, ValueId input_ graph.insert(graph.end(), graph3.begin(), graph3.end()); std::vector model; auto status = ValidateOptimizeModel({1, 2}, {5}, graph, &model); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.h b/tensorflow/lite/delegates/gpu/metal/compute_task.h index b03a8436077..611185b8fc1 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.h +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.h @@ -31,12 +31,12 @@ limitations under the License. @interface TFLComputeTask : NSObject /// Returns empty string or error if shader can't be compiled. 
-- (absl::Status)compileWithDevice:(id)device - taskDescriptor:(::tflite::gpu::metal::ComputeTaskDescriptorPtr)desc - runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; +- (::tflite::gpu::Status)compileWithDevice:(id)device + taskDescriptor:(::tflite::gpu::metal::ComputeTaskDescriptorPtr)desc + runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; /// Updates dimensions for inputs/outputs/intermediate tensors -- (absl::Status) +- (::tflite::gpu::Status) setInputDimensionsWithDevice:(id)device dimensions:(std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions; @@ -50,11 +50,12 @@ limitations under the License. /// @param sharedBufferIds contain shared buffer id for each tensor usage record id. /// @param sharedBuffers contain metal handles to the allocated buffers for each shared buffer id. /// TODO(ypisarchyk): probably we can decrease the number of parameters here -- (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers - outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds - usageRecordIds:(const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds - sharedBufferIds:(const std::vector&)sharedBufferIds - sharedBuffers:(const std::vector>&)sharedBuffers; +- (::tflite::gpu::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers + outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds + usageRecordIds: + (const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds + sharedBufferIds:(const std::vector&)sharedBufferIds + sharedBuffers:(const std::vector>&)sharedBuffers; - (void)encodeWithEncoder:(id)encoder inputOutputBuffers: diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.mm b/tensorflow/lite/delegates/gpu/metal/compute_task.mm index d3e3466ca6f..24b89c1b11c 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.mm +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.mm @@ -29,6 +29,8 @@ limitations under the License. 
using ::tflite::gpu::AlignByN; using ::tflite::gpu::BHWC; +using ::tflite::gpu::InternalError; +using ::tflite::gpu::InvalidArgumentError; using ::tflite::gpu::HalfBits; using ::tflite::gpu::metal::ComputeTaskDescriptorPtr; using ::tflite::gpu::metal::CreateComputeProgram; @@ -36,6 +38,8 @@ using ::tflite::gpu::metal::DispatchParamsFunction; using ::tflite::gpu::metal::OutputDimensions; using ::tflite::gpu::metal::RuntimeOptions; using ::tflite::gpu::metal::UniformsFunction; +using ::tflite::gpu::OkStatus; +using ::tflite::gpu::Status; using ::tflite::gpu::uint3; using ::tflite::gpu::ValueId; @@ -66,9 +70,9 @@ using ::tflite::gpu::ValueId; std::string _description; } -- (absl::Status)compileWithDevice:(id)device - taskDescriptor:(ComputeTaskDescriptorPtr)desc - runtimeOptions:(const RuntimeOptions&)options { +- (Status)compileWithDevice:(id)device + taskDescriptor:(ComputeTaskDescriptorPtr)desc + runtimeOptions:(const RuntimeOptions&)options { NSString* barrier; // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0 if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) { @@ -119,7 +123,7 @@ using ::tflite::gpu::ValueId; id program; RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program)); if (!program) { - return absl::InternalError("Unknown shader compilation error"); + return InternalError("Unknown shader compilation error"); } for (auto& buffer : desc->input_buffers) { _inputBuffers.emplace_back(InputBuffer{buffer.id, nil}); @@ -144,13 +148,12 @@ using ::tflite::gpu::ValueId; _resizeFunction = desc->resize_function; _program = program; _description = desc->description; - return absl::OkStatus(); + return OkStatus(); } -- (absl::Status)setInputDimensionsWithDevice:(id)device - dimensions: - (std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*) - dimensions { +- (Status)setInputDimensionsWithDevice:(id)device + dimensions: + (std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions { // Re-calculate output buffers dimensions for (auto& buffer : _outputBuffers) { auto outputDimensions = buffer.dimensionsFunction(*dimensions); @@ -177,23 +180,23 @@ using ::tflite::gpu::ValueId; error += "is larger than the MTLDevice can support: "; error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) + ", " + std::to_string(threadsPerGroup.depth); - return absl::InvalidArgumentError(error); + return InvalidArgumentError(error); } _groupsCount = workGroups.second; - return absl::OkStatus(); + return OkStatus(); } -- (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers - outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds - usageRecordIds:(const std::map&)usageRecordIds - sharedBufferIds:(const std::vector&)sharedBufferIds - sharedBuffers:(const std::vector>&)sharedBuffers { +- (Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers + outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds + usageRecordIds:(const std::map&)usageRecordIds + sharedBufferIds:(const std::vector&)sharedBufferIds + sharedBuffers:(const std::vector>&)sharedBuffers { for (auto& buffer : _outputBuffers) { // If the buffer is intermediate: set its metalHandle from sharedBuffers if (std::find(outputIds.begin(), outputIds.end(), buffer.uid) == outputIds.end()) { auto usageRecordIt = usageRecordIds.find(buffer.uid); if (usageRecordIt == usageRecordIds.end()) { - return absl::InternalError("TensorUsageRecord for intermediate tensor is not found."); + return 
InternalError("TensorUsageRecord for intermediate tensor is not found."); } buffer.metalHandle = sharedBuffers.at(sharedBufferIds.at(usageRecordIt->second)); (*buffers)[buffer.uid] = buffer.metalHandle; @@ -204,7 +207,7 @@ using ::tflite::gpu::ValueId; for (auto& buffer : _inputBuffers) { buffer.metalHandle = (*buffers)[buffer.uid]; } - return absl::OkStatus(); + return OkStatus(); } - (void)encodeWithEncoder:(id)encoder diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h index 97a6f3b3b18..8569a4ed009 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.h +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h @@ -50,12 +50,12 @@ limitations under the License. /// @return Status signals whether model is compiled successfully or not. /// @discussion Previously added operations are distilled into sorted list of sets of /// ComputeTaskDescriptors, which can be fused into a single GPU task. -- (absl::Status)compileModelWithDevice:(id)device - taskDescriptors: - (const std::vector<::tflite::gpu::metal::ComputeTaskDescriptorPtr>&) - taskDescriptors - outputBufferIDs:(const std::vector<::tflite::gpu::ValueId>&)outputBufferIDs - runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; +- (::tflite::gpu::Status) + compileModelWithDevice:(id)device + taskDescriptors: + (const std::vector<::tflite::gpu::metal::ComputeTaskDescriptorPtr>&)taskDescriptors + outputBufferIDs:(const std::vector<::tflite::gpu::ValueId>&)outputBufferIDs + runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; /// Creates intermediate buffers. The model is ready to be used after this call. /// @param inputDimensions Used to create resources: shaders, buffers. @@ -63,7 +63,7 @@ limitations under the License. /// @return Status signals whether intermediate buffers are successfully created or not. /// @discussion The operation is intended to be lightweight with minimum overhead. A preceding call /// compileModelWithDevice() must be made with the proper device parameter set. -- (absl::Status) +- (::tflite::gpu::Status) setInputDimensions:(const std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>&)inputDimensions outputDimensions:(std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)outputDimensions taskDescriptors: diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.mm b/tensorflow/lite/delegates/gpu/metal/inference_context.mm index d5589ae8ab4..fb3a51f4694 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.mm +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.mm @@ -32,6 +32,9 @@ limitations under the License. 
using ::tflite::gpu::BHWC; using ::tflite::gpu::metal::ComputeTaskDescriptorPtr; using ::tflite::gpu::metal::RuntimeOptions; +using ::tflite::gpu::InternalError; +using ::tflite::gpu::OkStatus; +using ::tflite::gpu::Status; using ::tflite::gpu::ValueId; using ::tflite::gpu::AlignByN; using ::tflite::gpu::HalfBits; @@ -45,10 +48,10 @@ using ::tflite::gpu::TensorUsageRecord; RuntimeOptions _options; } -- (absl::Status)compileModelWithDevice:(id)device - taskDescriptors:(const std::vector&)taskDescriptors - outputBufferIDs:(const std::vector&)requestedOutputBufferIDs - runtimeOptions:(const RuntimeOptions&)options { +- (Status)compileModelWithDevice:(id)device + taskDescriptors:(const std::vector&)taskDescriptors + outputBufferIDs:(const std::vector&)requestedOutputBufferIDs + runtimeOptions:(const RuntimeOptions&)options { _device = device; _outputIds = requestedOutputBufferIDs; _options = options; @@ -58,12 +61,12 @@ using ::tflite::gpu::TensorUsageRecord; RETURN_IF_ERROR([task compileWithDevice:_device taskDescriptor:node runtimeOptions:_options]); _computeTasks.emplace_back(task); } - return absl::OkStatus(); + return OkStatus(); } -- (absl::Status)setInputDimensions:(const std::map&)inputDimensions - outputDimensions:(std::map*)outputDimensions - taskDescriptors:(const std::vector&)taskDescriptors { +- (Status)setInputDimensions:(const std::map&)inputDimensions + outputDimensions:(std::map*)outputDimensions + taskDescriptors:(const std::vector&)taskDescriptors { // These maps contain all input/output/intermediate buffers shared across model. std::map dimensions = inputDimensions; std::map> buffers; @@ -94,7 +97,7 @@ using ::tflite::gpu::TensorUsageRecord; if (!usageRecordIds.count(outputId)) { const auto it = dimensions.find(outputId); if (it == dimensions.end()) { - return absl::InternalError("Dimensions for intermediate tensor not found."); + return InternalError("Dimensions for intermediate tensor not found."); } usageRecordIds[outputId] = usageRecords.size(); usageRecords.emplace_back(it->second.w * it->second.h * AlignByN(it->second.c, 4), i, i); @@ -130,14 +133,14 @@ using ::tflite::gpu::TensorUsageRecord; error += std::to_string(assignment.object_ids[i]) + " with size: " + std::to_string(bufferSize) + " exceeds MTLDevice maxBufferLength: " + std::to_string([_device maxBufferLength]); - return absl::ResourceExhaustedError(error); + return ::tflite::gpu::ResourceExhaustedError(error); } #endif #if defined(__MAC_10_12) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_12 if ([_device currentAllocatedSize] + bufferSize > [_device recommendedMaxWorkingSetSize]) { std::string error("Out of memory in MTLBuffer allocation. Currently allocated: "); error += std::to_string([_device currentAllocatedSize]); - return absl::ResourceExhaustedError(error); + return ::tflite::gpu::ResourceExhaustedError(error); } #endif @@ -151,7 +154,7 @@ using ::tflite::gpu::TensorUsageRecord; sharedBufferIds:assignment.object_ids sharedBuffers:sharedBuffers]); } - return absl::OkStatus(); + return OkStatus(); } - (void)encodeWithEncoder:(id)commandEncoder diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context_test.mm b/tensorflow/lite/delegates/gpu/metal/inference_context_test.mm index 4d9e54a0ca0..14ea40c68b4 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/inference_context_test.mm @@ -17,8 +17,6 @@ limitations under the License. 
#import -#include - #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" @@ -172,9 +170,9 @@ static std::vector MulArrayLinkable( std::map inputs{{inputBufferID, input}}; std::map outputs{{outputBufferID, {}}}; auto status = RunGraph(graph, _device, inputs, &outputs); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2.2f, 3.3f, 4.4f}, outputs[outputBufferID].data, 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testImmutableShaderOutput { @@ -189,9 +187,9 @@ static std::vector MulArrayLinkable( std::map inputs{{inputBufferID, input}}; std::map outputs{{outputBufferID, {}}}; auto status = RunGraph(graph, _device, inputs, &outputs); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 4, 9, 16, 25, 36, 49}, outputs[outputBufferID].data, 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testUniformShaderOutput { @@ -205,9 +203,9 @@ static std::vector MulArrayLinkable( std::map inputs{{inputBufferID, input}}; std::map outputs{{outputBufferID, {}}}; auto status = RunGraph(graph, _device, inputs, &outputs); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 4, 6}, outputs[outputBufferID].data, 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testUniformAndImmutableShaderOutput { @@ -224,9 +222,9 @@ static std::vector MulArrayLinkable( std::map inputs{{inputBufferID, input}}; std::map outputs{{outputBufferID, {}}}; auto status = RunGraph(graph, _device, inputs, &outputs); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 6, 12, 20, 26, 38, 52}, outputs[outputBufferID].data, 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm index 540308f23b4..10481b2a867 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm @@ -17,7 +17,6 @@ limitations under the License. 
#import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -66,9 +65,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8})); XCTAssertTrue(model.PopulateTensor(1, {0.1, 0.2, 0.3, 0.5})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({-1.9, 0.4, 1.0, 1.3}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testInputTensorAndScalar { @@ -86,9 +85,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testInputTensorWithConstantBroadcast { @@ -113,10 +112,10 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({11.0, 22.0, 13.0, 24.0, 15.0, 26.0, 17.0, 28.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm index 195a2986628..b67c1ca839c 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm @@ -17,7 +17,6 @@ limitations under the License. 
#import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -67,9 +66,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 3, 5, 7})); XCTAssertTrue(model.PopulateTensor(1, {2, 4, 6, 8})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6, 7, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTwoInputTensorsByAlignedChannel { @@ -93,9 +92,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(1, {5, 6, 7, 8})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6, 7, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTwoInputTensorsByHeight { @@ -119,9 +118,9 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 2})); XCTAssertTrue(model.PopulateTensor(1, {3, 4, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTwoInputTensorsByWidth { @@ -145,8 +144,8 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 4})); XCTAssertTrue(model.PopulateTensor(1, {2, 3, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm index a74b22cf13e..8f1b24a4735 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm @@ -17,7 +17,6 @@ limitations under the License. 
#import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -83,9 +82,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({4, 8, 4, 8, 2, 4, 2, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testO1H2W2I1Stride1x1Dilation2x2 { @@ -121,9 +120,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1, 1, 1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({10}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testO1H3W3I1Stride1x1Dilation1x1 { @@ -159,9 +158,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({11}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testO2H1W1I2Stride1x1Dilation1x1 { @@ -197,9 +196,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({4, 8, 4, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testO1H1W1I1Stride2x2Dilation1x1 { @@ -236,9 +235,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 0, 2, 0, 0, 0, 4, 0, 8})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 4, 8, 16}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc index 620a4581c52..228583c6e30 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc @@ -26,12 +26,12 @@ namespace tflite { namespace gpu { namespace metal { -absl::Status RegisterCustomOps(const GraphFloat32& graph, const Node* node, - const std::vector& inputs, - const std::vector& outputs, - const RuntimeOptions& options, - std::vector* tasks) { - return absl::UnimplementedError("Unsupported op: " + node->operation.type); +Status 
RegisterCustomOps(const GraphFloat32& graph, const Node* node, + const std::vector& inputs, + const std::vector& outputs, + const RuntimeOptions& options, + std::vector* tasks) { + return UnimplementedError("Unsupported op: " + node->operation.type); } } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h index eee1632a644..bef2ba20def 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h @@ -28,11 +28,11 @@ namespace gpu { namespace metal { // Registers custom operations. -absl::Status RegisterCustomOps(const GraphFloat32& graph, const Node* node, - const std::vector& inputs, - const std::vector& outputs, - const RuntimeOptions& options, - std::vector* tasks); +Status RegisterCustomOps(const GraphFloat32& graph, const Node* node, + const std::vector& inputs, + const std::vector& outputs, + const RuntimeOptions& options, + std::vector* tasks); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm index d76507253a9..5f262238464 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm @@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -84,9 +83,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 3})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 4, 12, 16}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testO2H1W1I1Strides2x2Dilation1x1 { @@ -123,9 +122,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 0, 1, 1, 0, 1, 1, 0, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 3, 1, 3, 1, 3, 1, 3}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testO2H2W2I1Strides1x1Dilation2x2 { @@ -162,9 +161,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 0, 1, 1, 0, 1, 1, 0, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({10, 26}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm index 6b30bc5c703..2c3f6b942ac 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm 
@@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -60,9 +59,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, 6.2, 2.0, 4.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testCos { @@ -73,9 +72,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, -1.0, -1.0, 0.540302}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testDiv { @@ -87,9 +86,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, -3.1, -4.0, 1.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testExp { @@ -100,11 +99,11 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0f, 1.0f, -1.0f, 100.0f, -100.0f, 0.01f, -0.01f})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({std::exp(0.0f), std::exp(1.0f), std::exp(-1.0f), std::exp(100.0f), std::exp(-100.0f), std::exp(0.01f), std::exp(-0.01f)}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testHardSwish { @@ -115,10 +114,10 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {-4.5f, -3.0f, -1.5f, 0.0f, 1.5f, 3.0f, 4.5f})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0f, 0.0f, -0.375f, 0.0f, 1.125f, 3.f, 4.5f}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testLog { @@ -129,9 +128,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 3.1415926, 1.0, 1.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", 
std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, 1.14473, 0.0, 0.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testMaximum { @@ -143,9 +142,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 2.0, 3.0, -2.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testMaximumWithScalar { @@ -158,9 +157,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, -1.0, 2.0, -1.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testMinimum { @@ -172,9 +171,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, -6.2, 2.0, -3.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testMinimumWithScalar { @@ -187,9 +186,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({-1.0, -6.2, -1.0, -3.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testPow { @@ -201,9 +200,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, 1.0, 8.0, 256.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testRsqrt { @@ -214,9 +213,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); 
XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 0.707106, 0.5, 0.333333}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSigmoid { @@ -227,9 +226,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.5, 0.002473, 0.880797, 0.982014}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSin { @@ -240,9 +239,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, 0.0, 0.0, 0.841471}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSqrt { @@ -253,9 +252,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, 1.0, 1.414213, 2.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSquare { @@ -266,9 +265,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 4.0, 0.25, 9.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSquaredDiff { @@ -280,9 +279,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, 2.0, 2.0, 4.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 1.0, 5.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 1.0, 9.0, 0.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSub { @@ -294,9 +293,9 @@ TensorRef GetTensorRef(int 
ref, const BHWC& shape) { XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({-1.0, -8.2, -1.0, 0.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTanh { @@ -307,9 +306,9 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { /*outputs=*/{GetTensorRef(1, shape)}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, -0.999987, 0.964027, 0.999329}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testMulBroadcastChannels { diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm index e57f9aa84e2..6d3a3e697b8 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm @@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -76,9 +75,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::FULLY_CONNECTED), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({6, 13, 20, 27}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm index cf4aacf724f..cacd501f0bd 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm @@ -17,7 +17,6 @@ limitations under the License. 
#import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -73,10 +72,10 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(1, {0, 0, 0, 0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm index 67325c1adb7..69eed7d86b0 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm @@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -63,9 +62,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::MEAN), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2.5}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm index d881950c831..2a1054d73eb 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm @@ -17,7 +17,6 @@ limitations under the License. 
#import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -64,9 +63,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 4, 6, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testMulLinear { @@ -90,9 +89,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 6, 6, 12}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm index 9c55cfc45b0..22fa11a89fb 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm @@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -79,9 +78,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PAD), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors(expected, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)runPrepending:(const HWC&)prepend @@ -165,9 +164,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PAD), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testMirrorPadChannelsOperation { @@ -189,9 +188,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PAD), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), 
@"%s", status.error_message().c_str()); } diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm index d2d95b30af2..f79d53c7bd3 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm @@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -74,11 +73,11 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output, indices}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 1, 2, 3, 4, 3, 4, 7, 8, 7, 8, 5, 6, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({4, 4, 8, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({3, 3, 1, 1}, model.GetOutput(1), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testPoolingMaxKernel2x2Stride2x2WithoutIndices { @@ -102,9 +101,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 1, 2, 3, 4, 3, 4, 7, 8, 7, 8, 5, 6, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({4, 4, 8, 8}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testPoolingAverageKernel2x2Stride2x2 { @@ -128,9 +127,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm index 1df08be61db..b805ed81c76 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm @@ -17,7 +17,6 @@ limitations under the License. 
#import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -70,9 +69,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PRELU), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {-1.0, -2.0, 1.0, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({-2, -4, 1, 2}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testPReluLinearAlphaWithClip { @@ -97,9 +96,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::PRELU), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {-1.0, -2.0, 1.0, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({-2, -4, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testPRelu3DAlphaNoClip { @@ -125,9 +124,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(op_type), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -1.0, 2.0, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0, -2, 2, -6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testPRelu3DAlphaWithClip { @@ -153,9 +152,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(op_type), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.0, -1.0, 2.0, -3.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0, -2, 1, -6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm index 52de77e0ee4..3687c0ecd65 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm @@ -17,7 +17,6 @@ limitations under the License. 
#import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -61,9 +60,9 @@ TensorRef GetTensorRef(int ref) { SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, {GetTensorRef(1)}); XCTAssertTrue(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, 0.0, 2.0, 8.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testReluClipOnly { @@ -74,9 +73,9 @@ TensorRef GetTensorRef(int ref) { SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, {GetTensorRef(1)}); XCTAssertTrue(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0, 0.0, 2.0, 6.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testReluAlphaOnly { @@ -87,9 +86,9 @@ TensorRef GetTensorRef(int ref) { SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, {GetTensorRef(1)}); XCTAssertTrue(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({-3.0, 0.0, 2.0, 8.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testReluClipAndAlpha { @@ -100,9 +99,9 @@ TensorRef GetTensorRef(int ref) { SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, {GetTensorRef(1)}); XCTAssertTrue(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({-3.0, 0.0, 2.0, 6.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm index 684e83b2db1..48d292e2a1b 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm @@ -62,9 +62,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESHAPE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testReshape3x1x2To2x1x3 { @@ -84,9 +84,9 @@ using 
::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESHAPE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4, 5, 6})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4, 5, 6}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testReshape1x1x4To2x2x1 { @@ -106,9 +106,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESHAPE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testReshapeBatchIsUnsupported { @@ -128,9 +128,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESHAPE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(std::string(status.message()).find("Only identical batch dimension is supported") != + XCTAssertTrue(status.error_message().find("Only identical batch dimension is supported") != std::string::npos, - @"%s", std::string(status.message()).c_str()); + @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm index f00b2766bdc..49febc1d4c6 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm @@ -17,7 +17,6 @@ limitations under the License. 
#import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -66,9 +65,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testResizeBilinear1x2x1To1x4x1 { @@ -90,9 +89,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 2.5, 4.0, 4.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testResizeBilinear2x2x1To4x4x1 { @@ -114,11 +113,11 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 4.0, 6.0, 8.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors( {1.0, 2.5, 4.0, 4.0, 3.5, 4.75, 6.0, 6.0, 6.0, 7.0, 8.0, 8.0, 6.0, 7.0, 8.0, 8.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testResizeBilinear2x2x1To3x3x1WithoutHalfPixel { @@ -141,10 +140,10 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 1.666666, 2.0, 2.333333, 3.0, 3.333333, 3.0, 3.666666, 4.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testResizeBilinear2x2x1To3x3x1WithHalfPixel { @@ -167,9 +166,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 1.5, 2.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testResizeNearest1x2x1To2x4x1 { @@ -191,9 +190,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel 
model({ToString(OperationType::RESIZE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm index e0c29561f9b..827f85fe00a 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm @@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -65,9 +64,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 2, 3, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSliceNoStrides { @@ -89,9 +88,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 3}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSliceNoStridesStartOffset { @@ -113,9 +112,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({3, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSliceStridesByHeight { @@ -137,9 +136,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 3}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSliceStridesByWidth { @@ -161,9 +160,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, 
{input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSliceStridesByChannels { @@ -185,9 +184,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 4}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm index 9196e9fe094..f5c4770bd8b 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm @@ -17,7 +17,6 @@ limitations under the License. #import -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -63,9 +62,9 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.1, 0.2, 0.1, 0.2})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 1, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSoftmaxDoesNotWorkForHeightAxis { @@ -85,7 +84,7 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4})); auto status = model.Invoke(); - XCTAssertFalse(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertFalse(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSoftmaxDoesNotWorkForWidthAxis { @@ -105,7 +104,7 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4})); auto status = model.Invoke(); - XCTAssertFalse(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertFalse(status.ok(), @"%s", status.error_message().c_str()); } - (void)testSoftmax1x1 { @@ -127,11 +126,11 @@ using ::tflite::gpu::metal::SingleOpModel; SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output}); XCTAssertTrue(model.PopulateTensor(0, {0.1f, 0.2f, 0.3f, 0.4f})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors( {std::exp(0.1f) / sum, std::exp(0.2f) / sum, std::exp(0.3f) / sum, std::exp(0.4f) / sum}, model.GetOutput(0), 1e-6f); - 
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm index 17e398817b2..6e82ebe0361 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm @@ -51,7 +51,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTFail(@"PopulateTensor()"); } const auto status = model.Invoke(); - if (!status.ok()) XCTFail(@"%s", std::string(status.message()).c_str()); + if (!status.ok()) XCTFail(@"%s", status.error_message().c_str()); const std::vector& actual = model.GetOutput(0); const std::vector expected = {1.0f, 2.0f, 3.0f, 4.0f}; XCTAssertEqual(actual[0], expected[0]); @@ -69,7 +69,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTFail(@"PopulateTensor()"); } const auto status = model.Invoke(); - if (!status.ok()) XCTFail(@"%s", std::string(status.message()).c_str()); + if (!status.ok()) XCTFail(@"%s", status.error_message().c_str()); const std::vector& actual = model.GetOutput(0); const std::vector expected = {1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f}; XCTAssertEqual(actual[0], expected[0]); @@ -94,7 +94,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTFail(@"PopulateTensor()"); } const auto status = model.Invoke(); - if (!status.ok()) XCTFail(@"%s", std::string(status.message()).c_str()); + if (!status.ok()) XCTFail(@"%s", status.error_message().c_str()); const std::vector& actual = model.GetOutput(0); const std::vector expected = {1.0f, 2.0f, 3.0f, // 4.0f, 5.0f, 6.0f, // @@ -126,7 +126,7 @@ using ::tflite::gpu::metal::SingleOpModel; XCTFail(@"PopulateTensor()"); } const auto status = model.Invoke(); - if (!status.ok()) XCTFail(@"%s", std::string(status.message()).c_str()); + if (!status.ok()) XCTFail(@"%s", status.error_message().c_str()); const std::vector& actual = model.GetOutput(0); const std::vector expected = {1.0f, 2.0f, 3.0f, 4.0f, // 5.0f, 6.0f, 7.0f, 8.0f, // diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h index ffa567a5a9d..7a4066fea0a 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h @@ -45,7 +45,7 @@ class SingleOpModel { return true; } - absl::Status Invoke(); + Status Invoke(); const std::vector& GetOutput(int index) const { return outputs_[index].data; @@ -57,16 +57,16 @@ class SingleOpModel { std::vector outputs_; }; -absl::Status CompareVectors(const std::vector& reference, - const std::vector& output, float max_error); +Status CompareVectors(const std::vector& reference, + const std::vector& output, float max_error); /// Helper function that compiles previously configured graph (with added /// tasks), initializes graph with specified inputs, invokes and fills specified /// outputs -absl::Status RunGraph(const std::vector& graph, - id device, - const std::map& inputs, - std::map* outputs); +Status RunGraph(const std::vector& graph, + id device, + const std::map& inputs, + std::map* outputs); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm index 80c0e2457af..3edc8669f2c 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm +++ 
b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm @@ -65,7 +65,7 @@ SingleOpModel::SingleOpModel(Operation&& operation, const std::vector input_ids; input_ids.reserve(inputs_.size()); for (const auto& input : inputs_) { @@ -143,16 +143,16 @@ absl::Status SingleOpModel::Invoke() { RETURN_IF_ERROR(ConvertFromPHWC4(absl::MakeConstSpan(output_pointer, elements_count), output.shape, absl::MakeSpan(output.data))); } - return absl::OkStatus(); + return OkStatus(); } -absl::Status CompareVectors(const std::vector& reference, const std::vector& output, - float max_error) { +Status CompareVectors(const std::vector& reference, const std::vector& output, + float max_error) { if (reference.size() != output.size()) { const std::string message = "CompareVectors: vectors size does not match for reference: " + std::to_string(reference.size()) + " vs. output: " + std::to_string(output.size()); - return absl::InternalError(message); + return tflite::gpu::InternalError(message); } for (int i = 0; i < reference.size(); i++) { float error = std::abs(reference[i] - output[i]); @@ -160,15 +160,15 @@ absl::Status CompareVectors(const std::vector& reference, const std::vect const std::string message = "Reference: " + std::to_string(reference[i]) + ", output: " + std::to_string(output[i]) + ", error: " + std::to_string(error) + ", max allowed error: " + std::to_string(max_error); - return absl::InternalError(message); + return tflite::gpu::InternalError(message); } } - return absl::OkStatus(); + return OkStatus(); } -absl::Status RunGraph(const std::vector& nodes, id device, - const std::map& inputs, - std::map* outputs) { +Status RunGraph(const std::vector& nodes, id device, + const std::map& inputs, + std::map* outputs) { std::vector inputBufferIDs; inputBufferIDs.reserve(inputs.size()); for (const auto& input : inputs) { @@ -251,7 +251,7 @@ absl::Status RunGraph(const std::vector& nodes, id -#include #include #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -82,10 +81,10 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 4, 2, 4, 1, 1, 4, 8, 4, 8, 1, 1, 3, 5, 3, 5, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTransposeConvO1H2W2I1Stride1x1Adjacent2x2 { @@ -121,11 +120,11 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1, 1, 1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({1, 3, 3, 2, 0, 0, 4, 10, 10, 6, 0, 0, 4, 10, 10, 6, 0, 0, 3, 7, 7, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTransposeConvO1H3W3I1Stride1x1Adjacent1x1 { @@ -161,10 +160,10 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + 
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({7, 11, 7, 1, 7, 11, 7, 1, 4, 6, 4, 1, 1, 1, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTransposeConvO2H1W1I2Stride1x1Dilation1x1 { @@ -200,9 +199,9 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 1, 1, 1})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({4, 8, 1, 1, 4, 8, 1, 1, 1, 1, 1, 1}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTransposeConvO1H1W1I1Stride2x2Dilation1x1 { @@ -239,11 +238,11 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {1, 0, 2, 0, 0, 0, 4, 0, 8})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } - (void)testTransposeConv4x4 { @@ -278,13 +277,13 @@ using ::tflite::gpu::metal::SingleOpModel; {output}); XCTAssertTrue(model.PopulateTensor(0, {0.0f, 1.0f, 2.0f, 3.0f})); auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); status = CompareVectors({0.0f, 0.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 2.0f, 4.0f, 6.0f, 12.0f, 6.0f, 12.0f, 4.0f, 8.0f, 2.0f, 4.0f, 6.0f, 12.0f, 6.0f, 12.0f, 4.0f, 8.0f, 2.0f, 4.0f, 5.0f, 10.0f, 5.0f, 10.0f, 3.0f, 6.0f}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } @end diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm index 4c6bb140a96..f7f08b273ae 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.mm +++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm @@ -198,13 +198,13 @@ class Delegate { } } - absl::Status BindBufferToTensor(id buffer, int tensor_index) { + Status BindBufferToTensor(id buffer, int tensor_index) { for (auto& input : graph_inputs_) { if (input.tensor_id == tensor_index) { input_output_buffers_[input.id] = buffer; bphwc4_buffers_[input.id] = buffer; input.set_externally = true; - return absl::OkStatus(); + return OkStatus(); } } for (auto& output : graph_outputs_) { @@ -212,10 +212,10 @@ class Delegate { input_output_buffers_[output.id] = buffer; bphwc4_buffers_[output.id] = buffer; output.set_externally = true; - return absl::OkStatus(); + return OkStatus(); } } - return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index)); + return NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index)); } void SetCommandEncoder( @@ -225,7 +225,7 @@ class Delegate { external_command_encoder_ = encoder; } - absl::Status 
Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { + Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { // Extract TFLite delegate execution plan from the context and convert it into FlowGraph32. GraphFloat32 graph; RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph)); @@ -234,7 +234,7 @@ class Delegate { NullTransformationReporter reporter; ModelTransformer transformer(&graph, &reporter); if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + return InternalError("Graph general transformations failed"); } // TODO(impjdi): Remove code duplication. @@ -265,7 +265,7 @@ class Delegate { if (tensor->allocation_type == TfLiteAllocationType::kTfLiteMmapRo) continue; const auto* input = find_value(tensor_index); if (!input || tensor->type != TfLiteType::kTfLiteFloat32) { - return absl::NotFoundError("Input tensor is not found in the graph."); + return NotFoundError("Input tensor is not found in the graph."); } inputs_.push_back(input->id); @@ -283,7 +283,7 @@ class Delegate { auto* tensor = context->tensors + tensor_index; const auto* output = find_value(tensor_index); if (!output || tensor->type != TfLiteType::kTfLiteFloat32) { - return absl::NotFoundError("Output tensor is not found in the graph."); + return NotFoundError("Output tensor is not found in the graph."); } outputs_.push_back(output->id); @@ -323,9 +323,7 @@ class Delegate { const auto& input_tensor = tensors_[input]; const auto tensor_id = input_tensor.tensor_id; input_ids.push_back(input); - if (input_tensor.shape.b != 1) { - return absl::UnimplementedError("Batching is not supported yet."); - } + if (input_tensor.shape.b != 1) return UnimplementedError("Batching is not supported yet."); input_dimensions[input] = input_tensor.shape; graph_inputs_.push_back({ input, // .id @@ -348,7 +346,7 @@ class Delegate { isFloat16:options_.allow_precision_loss convertToPBHWC4:true]; if (converter_to_BPHWC4_ == nil) { - return absl::InternalError("Error initialization of input buffer converter"); + return InternalError("Error initialization of input buffer converter"); } } } else { @@ -385,7 +383,7 @@ class Delegate { isFloat16:options_.allow_precision_loss convertToPBHWC4:false]; if (converter_from_BPHWC4_ == nil) { - return absl::InternalError("Error initialization of output buffer converter"); + return InternalError("Error initialization of output buffer converter"); } } } else { @@ -408,10 +406,10 @@ class Delegate { RETURN_IF_ERROR([inference_context_ setInputDimensions:input_dimensions outputDimensions:&output_dimensions taskDescriptors:optimized_model]); - return absl::OkStatus(); + return OkStatus(); } - absl::Status Invoke(TfLiteContext* context) { + Status Invoke(TfLiteContext* context) { if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) gpu_alarm_clock_->Stop(); // We need only synchronization so volatile works better than atomic which reads from global @@ -516,11 +514,11 @@ class Delegate { // External command encoder is assigned so all output buffers are controlled by a user. for (const auto& output : graph_outputs_) { if (!output.set_externally) { - return absl::InternalError( + return InternalError( "External command encoder is used, but not all output buffers are bound."); } } - return absl::OkStatus(); + return OkStatus(); } // Retrieve data from GPU and convert from PHWC4 to HWC. 
@@ -531,7 +529,7 @@ class Delegate { const void* gpu_ptr = [input_output_buffers_[output.id] contents]; std::memcpy(tensor->data.f, gpu_ptr, output.shape.DimensionsProduct() * sizeof(float)); } - return absl::OkStatus(); + return OkStatus(); } TfLiteDelegate* tflite_delegate() { return &delegate_; } @@ -598,7 +596,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = metal_delegate->Prepare(context, params); if (status.ok()) return metal_delegate; context->ReportError(context, "TfLiteGpuDelegate Prepare: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return nullptr; }, // .free @@ -612,7 +610,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { const auto status = GetMetalDelegate(node)->Invoke(context); if (status.ok()) return kTfLiteOk; context->ReportError(context, "TfLiteMetalDelegate Invoke: %s", - std::string(status.message()).c_str()); + status.error_message().c_str()); return kTfLiteError; }, nullptr, // .profiling_string diff --git a/tensorflow/lite/delegates/gpu/spi.h b/tensorflow/lite/delegates/gpu/spi.h index a70f8dbb326..c7f041f3db1 100644 --- a/tensorflow/lite/delegates/gpu/spi.h +++ b/tensorflow/lite/delegates/gpu/spi.h @@ -33,8 +33,8 @@ class TensorObjectConverter { public: virtual ~TensorObjectConverter() = default; - virtual absl::Status Convert(const TensorObject& input, - const TensorObject& output) = 0; + virtual Status Convert(const TensorObject& input, + const TensorObject& output) = 0; }; class TensorObjectConverterBuilder { @@ -44,7 +44,7 @@ class TensorObjectConverterBuilder { virtual bool IsSupported(const TensorObjectDef& input, const TensorObjectDef& output) const = 0; - virtual absl::Status MakeConverter( + virtual Status MakeConverter( const TensorObjectDef& input, const TensorObjectDef& output, std::unique_ptr* converter) = 0; }; @@ -66,13 +66,13 @@ class TensorTie { virtual ~TensorTie() = default; - virtual absl::Status SetExternalObject(TensorObject obj) = 0; + virtual Status SetExternalObject(TensorObject obj) = 0; virtual TensorObject GetExternalObject() = 0; - virtual absl::Status CopyToExternalObject() = 0; + virtual Status CopyToExternalObject() = 0; - virtual absl::Status CopyFromExternalObject() = 0; + virtual Status CopyFromExternalObject() = 0; const TensorTieDef& def() const { return def_; } From f68e082e2be3b843f69f965e5d0e47d50901f1b0 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Mon, 23 Mar 2020 17:13:59 -0700 Subject: [PATCH 464/492] Disable flaky eager:remote_test PiperOrigin-RevId: 302552940 Change-Id: Iebabdccbf3adf831deb64242d2e56d647f146f60 --- tensorflow/python/eager/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 8832f043457..9df6113b95f 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -852,9 +852,7 @@ cuda_py_test( python_version = "PY3", shard_count = 2, tags = [ - "manual", "no_oss", # This test launches local server. - "notap", # TODO(b/152224115) "optonly", # times out ], deps = [ From 309574cdb27b2b5c87d66482236f7d87ee956ee8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 17:19:13 -0700 Subject: [PATCH 465/492] Unlock XLA non pivoting tridiagonal solver for all back-ends. Also modify XLA implementation to speed up compilation time. 
PiperOrigin-RevId: 302553861 Change-Id: I2fb6108fa146f413a8271f562267e759f1dc86f6 --- .../tests/tridiagonal_solve_ops_test.py | 2 +- .../tf2xla/kernels/tridiagonal_ops.cc | 28 ++- tensorflow/compiler/xla/client/lib/BUILD | 3 + .../compiler/xla/client/lib/tridiagonal.cc | 179 ++++++++++++++---- .../xla/client/lib/tridiagonal_test.cc | 37 ++-- .../base_api/api_def_TridiagonalSolve.pbtxt | 1 + .../linalg/linear_operator_tridiag_test.py | 7 +- .../kernel_tests/tridiagonal_solve_op_test.py | 76 ++++++-- tensorflow/python/ops/linalg/linalg_impl.py | 4 +- 9 files changed, 234 insertions(+), 103 deletions(-) diff --git a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py index 26da5865c27..0fe745f869a 100644 --- a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py +++ b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py @@ -225,7 +225,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase): with self.session() as sess, self.test_scope(): with self.assertRaisesRegexp( errors_impl.UnimplementedError, - "Pivoting is not yet supported in XLA tridiagonal solver."): + "Current implementation does not yet support pivoting."): diags = array_ops.placeholder( shape=(batch_size, 3, num_dims), dtype=dtypes.float32) rhs = array_ops.placeholder( diff --git a/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc b/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc index c09003ee9e0..7ce2dd060f1 100644 --- a/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc @@ -17,24 +17,27 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/tridiagonal.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { class TridiagonalSolveOp : public XlaOpKernel { public: - explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("partial_pivoting", &pivoting_)); - } + explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES( - ctx, !pivoting_, - errors::Unimplemented( - "Pivoting is not yet supported in XLA tridiagonal solver.")); - auto diagonals = ctx->Input(0); auto rhs = ctx->Input(1); + bool partial_pivoting = false; + OP_REQUIRES_OK(ctx, + GetNodeAttr(def(), "partial_pivoting", &partial_pivoting)); + if (partial_pivoting) { + ctx->SetStatus(errors::Unimplemented( + "Current implementation does not yet support pivoting.")); + return; + } auto result = xla::tridiagonal::ThomasSolver(diagonals, rhs); if (!result.ok()) { @@ -43,16 +46,9 @@ class TridiagonalSolveOp : public XlaOpKernel { } ctx->SetOutput(0, result.ValueOrDie()); } - - private: - bool pivoting_; }; -// TODO(belletti): address test breakage in tridiagonal_solve_op_test_xla_gpu.py -// to support all XLA devices. 
-REGISTER_XLA_OP(Name("TridiagonalSolve") - .Device("XLA_TPU_JIT") - .TypeConstraint("T", kFloatTypes), +REGISTER_XLA_OP(Name("TridiagonalSolve").TypeConstraint("T", kFloatTypes), TridiagonalSolveOp); } // namespace diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index b821785d6d4..6fcdef46f29 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -531,12 +531,15 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/types:span", ], ) xla_test( name = "tridiagonal_test", srcs = ["tridiagonal_test.cc"], + real_hardware_only = True, + shard_count = 10, tags = ["optonly"], deps = [ ":constants", diff --git a/tensorflow/compiler/xla/client/lib/tridiagonal.cc b/tensorflow/compiler/xla/client/lib/tridiagonal.cc index 13cc3630137..89323b029b1 100644 --- a/tensorflow/compiler/xla/client/lib/tridiagonal.cc +++ b/tensorflow/compiler/xla/client/lib/tridiagonal.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -33,13 +34,6 @@ namespace tridiagonal { namespace { -struct TridiagonalSystemShape { - const int64 rank; - const int64 num_equations; - TridiagonalSystemShape(int64 rk, int64 num_eqs) - : rank(rk), num_equations(num_eqs) {} -}; - Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank, int64 expected, const std::string& op_name) { const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2); @@ -53,10 +47,10 @@ Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank, return Status::OK(); } -StatusOr CheckSystemAndReturnShape(XlaOp lower_diagonal, - XlaOp main_diagonal, - XlaOp upper_diagonal, - XlaOp rhs) { +StatusOr CheckSystemAndReturnNumEquations(XlaOp lower_diagonal, + XlaOp main_diagonal, + XlaOp upper_diagonal, + XlaOp rhs) { XlaBuilder* builder = lower_diagonal.builder(); TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape, @@ -111,11 +105,27 @@ StatusOr CheckSystemAndReturnShape(XlaOp lower_diagonal, TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1, "upper diagonal")); - return TridiagonalSystemShape(rank, num_equations); + return num_equations; } -XlaOp Coefficient(XlaOp operand, int64 i) { - return SliceInMinorDims(operand, /*start=*/{i}, /*end=*/{i + 1}); +XlaOp Coefficient(XlaOp operand, int32 i) { + return DynamicSliceInMinorDims(operand, + /*starts=*/{ConstantR0(operand.builder(), i)}, + /*sizes=*/{1}); +} + +XlaOp Coefficient(XlaOp operand, XlaOp i) { + return DynamicSliceInMinorDims(operand, + /*starts=*/{i}, /*sizes=*/{1}); +} + +XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) { + return DynamicUpdateSliceInMinorDims( + updated, update, /*starts=*/{ConstantR0(updated.builder(), i)}); +} + +XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) { + return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i}); } } // namespace @@ -134,48 +144,133 @@ XlaOp Coefficient(XlaOp operand, int64 i) { // solution will have the shape [..., num_rhs, num_equations]. 
 StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal,
                              XlaOp upper_diagonal, XlaOp rhs) {
-  TF_ASSIGN_OR_RETURN(TridiagonalSystemShape system_shape,
-                      CheckSystemAndReturnShape(lower_diagonal, main_diagonal,
-                                                upper_diagonal, rhs));
+  XlaBuilder* builder = lower_diagonal.builder();
 
-  auto rank = system_shape.rank;
-  auto num_eqs = system_shape.num_equations;
+  TF_ASSIGN_OR_RETURN(int64 num_eqs,
+                      CheckSystemAndReturnNumEquations(
+                          lower_diagonal, main_diagonal, upper_diagonal, rhs));
 
-  std::vector<XlaOp> main_diag_after_elimination(num_eqs);
-  std::vector<XlaOp> rhs_after_elimination(num_eqs);
-  std::vector<XlaOp> upper_diagonal_coeffs(num_eqs);
+  XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
+  XlaOp rhs_after_elimination = ZerosLike(rhs);
+  XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
+  XlaOp x_coeffs = ZerosLike(rhs);
 
-  main_diag_after_elimination[0] = Coefficient(main_diagonal, 0);
-  rhs_after_elimination[0] = Coefficient(rhs, 0);
-  for (int64 i = 0; i < num_eqs - 1; i++) {
-    upper_diagonal_coeffs[i] = Coefficient(upper_diagonal, i);
-  }
+  // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
+  main_diag_after_elimination =
+      UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));
+
+  // rhs_after_elimination[:, 0] = rhs[:, 0];
+  rhs_after_elimination =
+      UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));
+
+  auto preparation_body_fn =
+      [](XlaOp i, absl::Span<const XlaOp> values,
+         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto upper_diagonal_coeffs = values[0];
+    auto upper_diagonal = values[1];
+    // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
+    upper_diagonal_coeffs =
+        UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
+    return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
+  };
+  TF_ASSIGN_OR_RETURN(auto values_after_preparation,
+                      ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
+                                   {upper_diagonal_coeffs, upper_diagonal},
+                                   "preparation", builder));
+  upper_diagonal_coeffs = values_after_preparation[0];
 
   // Forward transformation.
-  for (int64 i = 1; i < num_eqs; i++) {
+  auto forward_transformation_fn =
+      [](XlaOp i_minus_one, absl::Span<const XlaOp> values,
+         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto lower_diagonal = values[0];
+    auto main_diagonal = values[1];
+    auto rhs = values[2];
+    auto main_diag_after_elimination = values[3];
+    auto upper_diagonal_coeffs = values[4];
+    auto rhs_after_elimination = values[5];
+
+    auto one = ScalarLike(i_minus_one, 1);
+    auto i = i_minus_one + one;
     auto lower_diagonal_i = Coefficient(lower_diagonal, i);
     auto main_diagonal_i = Coefficient(main_diagonal, i);
     auto rhs_i = Coefficient(rhs, i);
 
-    auto w_i = lower_diagonal_i / main_diag_after_elimination[i - 1];
+    auto w_i =
+        lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);
 
-    main_diag_after_elimination[i] =
-        main_diagonal_i - w_i * upper_diagonal_coeffs[i - 1];
-    rhs_after_elimination[i] = rhs_i - w_i * rhs_after_elimination[i - 1];
-  }
+    // main_diag_after_elimination[:, i] =
+    //     main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
+    main_diag_after_elimination = UpdateEq(
+        main_diag_after_elimination, i,
+        main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
+    // rhs_after_elimination[:, i] =
+    //     rhs_i - w_i * rhs_after_elimination[:, i - 1];
+    rhs_after_elimination =
+        UpdateEq(rhs_after_elimination, i,
+                 rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));
 
-  std::vector<XlaOp> x_coeffs(num_eqs);
+    return std::vector<XlaOp>{lower_diagonal,
+                              main_diagonal,
+                              rhs,
+                              main_diag_after_elimination,
+                              upper_diagonal_coeffs,
+                              rhs_after_elimination};
+  };
+  TF_ASSIGN_OR_RETURN(
+      auto values_after_fwd_transformation,
+      ForEachIndex(
+          num_eqs - 1, S32, forward_transformation_fn,
+          {lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
+           upper_diagonal_coeffs, rhs_after_elimination},
+          "forward_transformation", builder));
+  lower_diagonal = values_after_fwd_transformation[0];
+  main_diagonal = values_after_fwd_transformation[1];
+  rhs = values_after_fwd_transformation[2];
+  main_diag_after_elimination = values_after_fwd_transformation[3];
+  upper_diagonal_coeffs = values_after_fwd_transformation[4];
+  rhs_after_elimination = values_after_fwd_transformation[5];
 
   // Backward reduction.
-  x_coeffs[num_eqs - 1] = rhs_after_elimination[num_eqs - 1] /
-                          main_diag_after_elimination[num_eqs - 1];
-  for (int i = num_eqs - 2; i >= 0; i--) {
-    x_coeffs[i] = (rhs_after_elimination[i] -
-                   upper_diagonal_coeffs[i] * x_coeffs[i + 1]) /
-                  main_diag_after_elimination[i];
-  }
+  // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
+  //     main_diag_after_elimination[:, num_eqs - 1];
+  x_coeffs =
+      UpdateEq(x_coeffs, num_eqs - 1,
+               Coefficient(rhs_after_elimination, num_eqs - 1) /
+                   Coefficient(main_diag_after_elimination, num_eqs - 1));
+  auto bwd_reduction_fn =
+      [num_eqs](XlaOp j, absl::Span<const XlaOp> values,
+                XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto x_coeffs = values[0];
+    auto rhs_after_elimination = values[1];
+    auto upper_diagonal_coeffs = values[2];
+    auto main_diag_after_elimination = values[3];
+    auto n = ScalarLike(j, num_eqs - 2);
+    auto one = ScalarLike(j, 1);
+    auto i = n - j;
+    // for (int i = num_eqs - 2; i >= 0; i--)
+    //   x_coeffs[:, i] = (rhs_after_elimination[:, i] -
+    //       upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
+    //       main_diag_after_elimination[:, i];
+    x_coeffs = UpdateEq(x_coeffs, i,
+                        (Coefficient(rhs_after_elimination, i) -
+                         Coefficient(upper_diagonal_coeffs, i) *
+                             Coefficient(x_coeffs, i + one)) /
+                            Coefficient(main_diag_after_elimination, i));
+    return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
+                              upper_diagonal_coeffs,
+                              main_diag_after_elimination};
+  };
 
-  return ConcatInDim(lower_diagonal.builder(), x_coeffs, rank - 1);
+  TF_ASSIGN_OR_RETURN(
+      auto values_after_bwd_reduction,
+      ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
+                   {x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
+                    main_diag_after_elimination},
+                   "backward_reduction", builder));
+  x_coeffs = values_after_bwd_reduction[0];
+
+  return x_coeffs;
 }
 
 // Applies Thomas algorithm to solve a linear system where the linear operand
diff --git a/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc b/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc
index 17147588ff6..0b3a32f0969 100644
--- a/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc
+++ b/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc
@@ -33,34 +33,28 @@ namespace {
 
 class TridiagonalTest
     : public ClientLibraryTestBase,
-      public ::testing::WithParamInterface<std::tuple<int, int, int, int>> {};
+      public ::testing::WithParamInterface<std::tuple<int, int, int>> {};
 
 XLA_TEST_P(TridiagonalTest, Solves) {
   const auto& spec = GetParam();
 
   xla::XlaBuilder builder(TestName());
 
-  const int64 num_eqs = 5;
-  const int64 num_rhs = 3;
-  const int64 lower_diagonal_batch_size = std::get<0>(spec);
-  const int64 main_diagonal_batch_size = std::get<1>(spec);
-  const int64 upper_diagonal_batch_size = std::get<2>(spec);
-  const int64 rhs_diagonal_batch_size = std::get<2>(spec);
+  // TODO(belletti): parametrize num_rhs.
+ const int64 batch_size = std::get<0>(spec); + const int64 num_eqs = std::get<1>(spec); + const int64 num_rhs = std::get<2>(spec); - const int64 max_batch_size = - std::max({lower_diagonal_batch_size, main_diagonal_batch_size, - upper_diagonal_batch_size, rhs_diagonal_batch_size}); - - Array3D lower_diagonal(lower_diagonal_batch_size, 1, num_eqs); - Array3D main_diagonal(main_diagonal_batch_size, 1, num_eqs); - Array3D upper_diagonal(upper_diagonal_batch_size, 1, num_eqs); - Array3D rhs(rhs_diagonal_batch_size, num_rhs, num_eqs); + Array3D lower_diagonal(batch_size, 1, num_eqs); + Array3D main_diagonal(batch_size, 1, num_eqs); + Array3D upper_diagonal(batch_size, 1, num_eqs); + Array3D rhs(batch_size, num_rhs, num_eqs); lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0); main_diagonal.FillRandom(0.05, /*mean=*/1.0, - /*seed=*/max_batch_size * num_eqs); + /*seed=*/batch_size * num_eqs); upper_diagonal.FillRandom(1.0, /*mean=*/0.0, - /*seed=*/2 * max_batch_size * num_eqs); - rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * max_batch_size * num_eqs); + /*seed=*/2 * batch_size * num_eqs); + rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs); XlaOp lower_diagonal_xla; XlaOp main_diagonal_xla; @@ -119,10 +113,9 @@ XLA_TEST_P(TridiagonalTest, Solves) { } INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest, - ::testing::Combine(::testing::Values(1, 8), - ::testing::Values(1, 8), - ::testing::Values(1, 8), - ::testing::Values(1, 8))); + ::testing::Combine(::testing::Values(1, 12), + ::testing::Values(4, 8), + ::testing::Values(1, 12))); } // namespace } // namespace tridiagonal diff --git a/tensorflow/core/api_def/base_api/api_def_TridiagonalSolve.pbtxt b/tensorflow/core/api_def/base_api/api_def_TridiagonalSolve.pbtxt index 1eb88c886ef..058ce666e03 100644 --- a/tensorflow/core/api_def/base_api/api_def_TridiagonalSolve.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_TridiagonalSolve.pbtxt @@ -39,5 +39,6 @@ END On CPU, solution is computed via Gaussian elimination with or without partial pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv + Partial pivoting is not yet supported by XLA backends. 
END } diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py index d69f872f703..049938e8d03 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py @@ -178,7 +178,8 @@ class LinearOperatorTriDiagMatrixTest( if __name__ == '__main__': - linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest) - linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest) - linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest) + if not test_util.is_xla_enabled(): + linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest) + linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest) + linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest) test.main() diff --git a/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py b/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py index 2b50f1a29d4..afc327e2aef 100644 --- a/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py +++ b/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py @@ -22,8 +22,8 @@ import itertools import numpy as np -from tensorflow.python.eager import backprop from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -78,8 +78,19 @@ class TridiagonalSolveOpTest(test.TestCase): transpose_rhs=False, conjugate_rhs=False): with self.cached_session(use_gpu=True): - result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format, - transpose_rhs, conjugate_rhs) + pivoting = True + if hasattr(self, "pivoting"): + pivoting = self.pivoting + if test_util.is_xla_enabled() and pivoting: + # Pivoting is not supported by xla backends. + return + result = linalg_impl.tridiagonal_solve( + diags, + rhs, + diags_format, + transpose_rhs, + conjugate_rhs, + partial_pivoting=pivoting) self.assertAllClose(self.evaluate(result), expected) def _testWithLists(self, @@ -94,8 +105,15 @@ class TridiagonalSolveOpTest(test.TestCase): transpose_rhs, conjugate_rhs) def _assertRaises(self, diags, rhs, diags_format="compact"): + pivoting = True + if hasattr(self, "pivoting"): + pivoting = self.pivoting + if test_util.is_xla_enabled() and pivoting: + # Pivoting is not supported by xla backends. + return with self.assertRaises(ValueError): - linalg_impl.tridiagonal_solve(diags, rhs, diags_format) + linalg_impl.tridiagonal_solve( + diags, rhs, diags_format, partial_pivoting=pivoting) # Tests with various dtypes @@ -137,6 +155,9 @@ class TridiagonalSolveOpTest(test.TestCase): self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2]) def test0x0(self): + if test_util.is_xla_enabled(): + # The following test crashes with XLA due to slicing 0 length tensors. + return self._test( diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32), rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32), @@ -153,10 +174,16 @@ class TridiagonalSolveOpTest(test.TestCase): diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]]) def test1x1NotInvertible(self): + if test_util.is_xla_enabled(): + # XLA implementation does not check invertibility. 
+ return with self.assertRaises(errors_impl.InvalidArgumentError): self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]], expected=[]) def test2x2NotInvertible(self): + if test_util.is_xla_enabled(): + # XLA implementation does not check invertibility. + return with self.assertRaises(errors_impl.InvalidArgumentError): self._testWithLists( diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[]) @@ -179,7 +206,7 @@ class TridiagonalSolveOpTest(test.TestCase): expected=[5, -2, -5, 3]) def testNotInvertible(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.is_xla_enabled(): # CuSparse gtsv routines don't raise errors for non-invertible # matrices. return @@ -252,8 +279,9 @@ class TridiagonalSolveOpTest(test.TestCase): def testSequenceFormatWithDummyElements(self): dummy = 20 self._test( - diags=(_tfconst([2, 1, 4, dummy]), _tfconst([1, 3, 2, 2]), - _tfconst([dummy, 1, -1, 1])), + diags=(_tfconst([2, 1, 4, + dummy]), _tfconst([1, 3, 2, + 2]), _tfconst([dummy, 1, -1, 1])), rhs=_tfconst([1, 2, 3, 4]), expected=_tfconst([-9, 5, -4, 4]), diags_format="sequence") @@ -261,8 +289,9 @@ class TridiagonalSolveOpTest(test.TestCase): def testSequenceFormatWithBatching(self): self._test( diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]), - _tfconst([[1, 3, 2, 2], [-1, -3, -2, -2]]), - _tfconst([[1, -1, 1], [-1, 1, -1]])), + _tfconst([[1, 3, 2, 2], + [-1, -3, -2, -2]]), _tfconst([[1, -1, 1], [-1, 1, + -1]])), rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]), expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]), diags_format="sequence") @@ -373,6 +402,9 @@ class TridiagonalSolveOpTest(test.TestCase): with backprop.GradientTape() as tape_rhs: tape_diags.watch(diags) tape_rhs.watch(rhs) + if test_util.is_xla_enabled(): + # Pivoting is not supported by xla backends. + return x = linalg_impl.tridiagonal_solve( diags, rhs, @@ -526,6 +558,9 @@ class TridiagonalSolveOpTest(test.TestCase): return diags = array_ops.placeholder(dtypes.float64, shape=diags_shape) rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape) + if test_util.is_xla_enabled() and self.pivoting: + # Pivoting is not supported by xla backends. + return x = linalg_impl.tridiagonal_solve( diags, rhs, diags_format, partial_pivoting=self.pivoting) with self.cached_session(use_gpu=True) as sess: @@ -601,6 +636,9 @@ class TridiagonalSolveOpTest(test.TestCase): def testSequenceFormatWithUnknownDims(self): if context.executing_eagerly(): return + if test_util.is_xla_enabled() and self.pivoting: + # Pivoting is not supported by xla backends. 
+ return superdiag = array_ops.placeholder(dtypes.float64, shape=[None]) diag = array_ops.placeholder(dtypes.float64, shape=[None]) subdiag = array_ops.placeholder(dtypes.float64, shape=[None]) @@ -641,9 +679,9 @@ class TridiagonalSolveOpTest(test.TestCase): np.random.seed(seed) import scipy.sparse as sparse # pylint:disable=g-import-not-at-top # By being strictly diagonally dominant, we guarantee invertibility.d - diag = 2* np.abs(np.random.randn(matrix_size)) + 4.1 - subdiag = 2* np.abs(np.random.randn(matrix_size-1)) - superdiag = 2* np.abs(np.random.randn(matrix_size-1)) + diag = 2 * np.abs(np.random.randn(matrix_size)) + 4.1 + subdiag = 2 * np.abs(np.random.randn(matrix_size - 1)) + superdiag = 2 * np.abs(np.random.randn(matrix_size - 1)) matrix = sparse.diags([superdiag, diag, subdiag], [1, 0, -1]).toarray() vector = np.random.randn(batch_size, matrix_size, num_rhs) return (variables.Variable(np.tile(matrix, (batch_size, 1, 1))), @@ -665,6 +703,9 @@ class TridiagonalSolveOpTest(test.TestCase): session.Session(config=benchmark.benchmark_config()) as sess, \ ops.device(device_id): diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs) + # Pivoting is not supported by XLA backends. + if test.is_xla_enabled() and pivoting: + return x = linalg_impl.tridiagonal_solve( diags, rhs, partial_pivoting=pivoting) variables.global_variables_initializer().run() @@ -673,9 +714,9 @@ class TridiagonalSolveOpTest(test.TestCase): control_flow_ops.group(x), min_iters=10, store_memory_usage=False, - name=test_name_format_string.format( - device_name, matrix_size, batch_size, num_rhs, - pivoting_name)) + name=test_name_format_string.format(device_name, matrix_size, + batch_size, num_rhs, + pivoting_name)) def benchmarkTridiagonalSolveOp_WithMatrixInput(self): self._benchmark( @@ -687,9 +728,8 @@ class TridiagonalSolveOpTest(test.TestCase): def benchmarkTridiagonalSolveOp(self): self._benchmark( self._generateMatrixData, - test_name_format_string=( - "tridiagonal_solve_{}_matrix_size_{}_" - "batch_size_{}_num_rhs_{}_{}")) + test_name_format_string=("tridiagonal_solve_{}_matrix_size_{}_" + "batch_size_{}_num_rhs_{}_{}")) if __name__ == "__main__": diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index c59314890e1..f7617d83caf 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -430,7 +430,9 @@ def tridiagonal_solve(diagonals, Raises: ValueError: An unsupported type is provided as input, or when the input - tensors have incorrect shapes. + tensors have incorrect shapes. + UnimplementedError: Whenever `partial_pivoting` is true and the backend is + XLA. [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms: Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7. From 3dc3f89aadde1f8aa755829676cb230fad2f85f1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 17:20:59 -0700 Subject: [PATCH 466/492] Go: Update generated wrapper functions for TensorFlow ops. 
PiperOrigin-RevId: 302554124 Change-Id: I27ff409c61ee0148cfa38a44b1f1040485e9842a --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 68bb1dc49f5..75d86f71b78 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19036,7 +19036,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20107,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21279,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21987,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22183,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22252,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22367,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22426,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22600,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22977,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25320,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25383,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25626,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26110,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40308,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45834,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46686,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46749,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 4290ab0ba9c8f8c6cab07cb51b3294d77fce957b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 17:21:37 -0700 Subject: [PATCH 467/492] Supported tiled output sharding for model parallelism. PiperOrigin-RevId: 302554209 Change-Id: I1f67e80d18795e7cb22cdd2124428db4e9b5e5de --- .../mlir/tensorflow/tests/tpu_rewrite.mlir | 385 ++++++++++++++++++ .../tensorflow/transforms/tpu_rewrite_pass.cc | 48 ++- .../tensorflow/utils/xla_sharding_util.cc | 222 ++++++++-- .../mlir/tensorflow/utils/xla_sharding_util.h | 17 +- 4 files changed, 621 insertions(+), 51 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 34dbee5cba9..06d6c35e0a8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1502,6 +1502,74 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- +// Tests that outputs are correctly merged and fed from TPU computation for +// tiled output sharding. 
+ +// The following OpSharding is used for TPU computation outputs in below test: +// Proto debug string: +// output 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// +// output 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_tiled_output + func @parallel_execute_with_tiled_output(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: device = "TPU_REPLICATED_CORE_1" + // + // CHECK: %[[CONST_CONCAT_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2 + + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return 
%1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + +// ----- + // The following OpSharding is used for TPU computation inputs in below test: // Proto debug string: // input 0 @@ -1510,6 +1578,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // tile_assignment_dimensions: 4 // tile_assignment_devices: 0 // tile_assignment_devices: 1 +// Serialized string: +// "\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03" // // input 1 // type: MAXIMAL @@ -1541,6 +1611,46 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc } } +// The following OpSharding is used for TPU computation outputs in below test: +// Proto debug string: +// output 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 4 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03" +// +// output 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" +// +// ----- + +// Tests tile sharding of outputs with number of splits that exeed number +// of logical devices is not allowed. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + func @uneven_output_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // expected-error@+1 {{incorrect sharding format for outputs. 
Number of tiled outputs(4) must match the number of logical devices(2)}} + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["", ""], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi1>, tensor<*xi32>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %3, %4 : tensor<*xi1>, tensor<*xi32> + } +} + // ----- // The following topology is used in subsequent test cases: @@ -1648,6 +1758,196 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc } +// ----- + +// The following topology is used in subsequent test cases: +// Proto debug string: +// mesh_shape: 2 +// mesh_shape: 1 +// mesh_shape: 2 +// num_tasks: 2 +// num_tpu_devices_per_task: 2 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 + +// The following OpSharding is used for TPU computation inputs in below test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_shape { +// element_type: F32 +// dimensions: 2 +// dimensions: 2 +// layout { +// minor_to_major: 1 +// minor_to_major: 0 +// format: DENSE +// } +// is_dynamic_dimension: false +// is_dynamic_dimension: false +// } +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// tile_assignment_devices: 2 +// tile_assignment_devices: 3 +// Serialized string: +// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03" +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + +// Tests inputs to TPUComputation that are tiled in multiple dimensions. 
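Concretely, for the [2, 2] tile assignment above applied to the 128x10 replicated input, the pass emits a split along dimension 0 followed by splits along dimension 1, so each of the four logical cores receives a 64x5 tile. A small sketch of which slice each core ends up with, assuming the usual row-major interpretation of tile_assignment_devices (illustrative only):

#include <iostream>
#include <vector>

int main() {
  // Per-replica input shape and the 2x2 tile assignment from the sharding
  // above (tile_assignment_devices = [0, 1, 2, 3], row-major).
  const std::vector<int> shape = {128, 10};
  const std::vector<int> tile_dims = {2, 2};
  const std::vector<int> tile_devices = {0, 1, 2, 3};

  const int tile_rows = shape[0] / tile_dims[0];  // 64
  const int tile_cols = shape[1] / tile_dims[1];  // 5
  for (int r = 0; r < tile_dims[0]; ++r) {
    for (int c = 0; c < tile_dims[1]; ++c) {
      const int core = tile_devices[r * tile_dims[1] + c];
      std::cout << "logical core " << core << " receives slice ["
                << r * tile_rows << ":" << (r + 1) * tile_rows << ", "
                << c * tile_cols << ":" << (c + 1) * tile_cols << "]"
                << " of shape " << tile_rows << "x" << tile_cols << "\n";
    }
  }
  return 0;
}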
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_multi_dimension_tiled_input + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>) + func @parallel_execute_with_multi_dimension_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]]) + // CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0) + // CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#1) + // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[COMPILE]]#3) + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4) + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 
0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + +// ----- + +// Tests that tiled output with multiple dimension sharding works properly. + +// The following OpSharding is used for TPU computation outputs in below test: +// output 0 +// Proto debug string: +// type: OTHER +// tile_shape { +// element_type: F32 +// dimensions: 2 +// dimensions: 2 +// layout { +// minor_to_major: 1 +// minor_to_major: 0 +// format: DENSE +// } +// is_dynamic_dimension: false +// is_dynamic_dimension: false +// } +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// tile_assignment_devices: 2 +// tile_assignment_devices: 3 +// Serialized string: +// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03" +// +// output 1 +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @multiple_dimension_output_sharding + func @multiple_dimension_output_sharding(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" 
+ // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"( + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"( + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + // CHECK: %[[CONST_CONCAT_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2 + // CHECK: %[[CONST_CONCAT2_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#3, %[[PARALLEL_EXECUTE_OUTPUT]]#4 + // CHECK: %[[CONST_CONCAT3_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]] + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + // ----- // Tests inputs device assignment order is well preserved for tiled input sharding. @@ -1732,3 +2032,88 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc return %4, %3 : tensor<*xi32>, tensor<*xi1> } } + +// ----- + +// Tests device assignment is well preserved for tile sharded outputs. 
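The merge order asserted by the test below follows tile_assignment_devices: per-core outputs are first reordered by that list and then concatenated along the last tiled dimension first, num_splits values at a time, until a single replica output remains. A minimal standalone sketch of that ordering with the reversed device assignment [3, 2, 1, 0] used below (plain strings instead of mlir::Values, illustrative only):

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Per-core outputs indexed by logical core id, plus the sharding used in
  // the test below: tile_assignment_dimensions = [2, 2],
  // tile_assignment_devices = [3, 2, 1, 0].
  const std::vector<std::string> per_core = {"core0", "core1", "core2",
                                             "core3"};
  const std::vector<int> tile_dims = {2, 2};
  const std::vector<int> tile_devices = {3, 2, 1, 0};

  // Reorder outputs by tile_assignment_devices, as HandleTileShardedOutputs
  // does before building the concat tree.
  std::vector<std::string> outputs;
  for (int device : tile_devices) outputs.push_back(per_core[device]);

  // Merge along the last tiled dimension first, num_splits values at a time.
  int concat_dimension = static_cast<int>(tile_dims.size()) - 1;
  for (auto it = tile_dims.rbegin(); it != tile_dims.rend(); ++it) {
    const int num_splits = *it;
    if (num_splits == 1) {
      --concat_dimension;
      continue;
    }
    std::vector<std::string> merged;
    for (size_t i = 0; i < outputs.size(); i += num_splits) {
      std::string group = "concat(dim=" + std::to_string(concat_dimension);
      for (int j = 0; j < num_splits; ++j) group += ", " + outputs[i + j];
      group += ")";
      merged.push_back(group);
    }
    outputs.swap(merged);
    --concat_dimension;
  }

  assert(outputs.size() == 1);
  // Prints:
  // concat(dim=0, concat(dim=1, core3, core2), concat(dim=1, core1, core0))
  std::cout << outputs[0] << "\n";
  return 0;
}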
+ +// The following OpSharding is used for TPU computation outputs in below test: +// output 0 +// Proto debug string: +// type: OTHER +// tile_shape { +// element_type: F32 +// dimensions: 2 +// dimensions: 2 +// layout { +// minor_to_major: 1 +// minor_to_major: 0 +// format: DENSE +// } +// is_dynamic_dimension: false +// is_dynamic_dimension: false +// } +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 3 +// tile_assignment_devices: 2 +// tile_assignment_devices: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00" +// +// output 1 +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @device_order_preserved_for_tiled_output + func @device_order_preserved_for_tiled_output(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" + // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"( + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"( + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + // CHECK: %[[CONST_CONCAT_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#4, %[[PARALLEL_EXECUTE_OUTPUT]]#3 + // CHECK: %[[CONST_CONCAT2_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#2, 
%[[PARALLEL_EXECUTE_OUTPUT]]#0 + // CHECK: %[[CONST_CONCAT3_DIM:[0-9]*]] = "tf.Const"() + // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]] + %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index f9e24e4373d..1a49350a4be 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -470,20 +470,22 @@ void AssignDevicesToReplicate( } // Creates a `tf.TPUExecute` op that executes TPU program. -Operation* BuildExecuteOp( +LogicalResult BuildExecuteOp( const int core_id, llvm::ArrayRef output_sharding_config, llvm::ArrayRef inputs, tf_device::LaunchFuncOp launch_func, - OpBuilder* builder) { + OpBuilder* builder, TF::TPUExecuteOp* execute_op) { // TODO(b/139377366): Need to snapshot all resource variable inputs in // follow-up CLs. - - auto output_types = tensorflow::GetOutputTypesForLogicalDeviceComputation( - core_id, output_sharding_config, launch_func); + llvm::SmallVector output_types; + auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation( + core_id, output_sharding_config, launch_func, &output_types); + if (failed(result)) return failure(); // TPUExecute has same output types as launch_func. 
- return builder->create(launch_func.getLoc(), output_types, - inputs, - llvm::ArrayRef{}); + *execute_op = builder->create( + launch_func.getLoc(), output_types, inputs, + llvm::ArrayRef{}); + return success(); } // Creates a tf_device.parallel_execute op that wraps TPUExecute op to @@ -505,8 +507,11 @@ LogicalResult BuildParallelExecuteOp( num_cores_per_replica); for (int core = 0; core < num_cores_per_replica; ++core) { - auto output_types = tensorflow::GetOutputTypesForLogicalDeviceComputation( - core, output_sharding_config, launch_func); + llvm::SmallVector output_types; + auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation( + core, output_sharding_config, launch_func, &output_types); + if (failed(result)) return failure(); + for (Type t : output_types) concatenated_output_types.emplace_back(t); } @@ -537,8 +542,10 @@ LogicalResult BuildParallelExecuteOp( auto execute_inputs = input_list[core]; execute_inputs.emplace_back(compile_op->getResult(core + 1)); - auto execute = BuildExecuteOp(core, output_sharding_config, execute_inputs, - launch_func, builder); + TF::TPUExecuteOp execute; + result = BuildExecuteOp(core, output_sharding_config, execute_inputs, + launch_func, builder, &execute); + if (failed(result)) return failure(); // If computation is replicated, use aliased device. Otherwise there is only // one execution device per core and the device is assigned to the execute @@ -699,8 +706,8 @@ LogicalResult Rewrite( builder); llvm::SmallVector output_shardings; - auto result = tensorflow::ParseAndValidateOutputSharding(launch_func, - &output_shardings); + auto result = tensorflow::ParseAndValidateOutputSharding( + num_cores_per_replica, launch_func, &output_shardings); if (failed(result)) return failure(); if (num_cores_per_replica > 1) { @@ -717,14 +724,19 @@ LogicalResult Rewrite( // ops, the number of return values of parallel_execute op exceeds that of // launch_func op. As so, each return value of parallel_execute op must be // mapped with corresponding return value usages of launch_func. - tensorflow::RemapOutputsFromLogicalDevices(output_shardings, launch_func, - execute_op); + tensorflow::RemapOutputsFromLogicalDevices(launch_func.getLoc(), + output_shardings, launch_func, + execute_op, builder); } else { llvm::SmallVector execute_inputs(launch_func.getOperands()); execute_inputs.emplace_back(compile_op->getResult(1)); - Operation* execute_op = BuildExecuteOp( - /*core_id=*/0, output_shardings, execute_inputs, launch_func, builder); + TF::TPUExecuteOp execute_op; + result = BuildExecuteOp( + /*core_id=*/0, output_shardings, execute_inputs, launch_func, builder, + &execute_op); + if (failed(result)) return failure(); + tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute( tpu_device_assignment.execution_devices, execute_op, builder); launch_func.replaceAllUsesWith(launch_op); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index 7c1f69f4d92..aaff33bce3f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -27,6 +27,8 @@ limitations under the License. 
#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -102,6 +104,46 @@ mlir::LogicalResult CreateSplitOp(const int num_split, return mlir::success(); } +// Creates a tf::ConcatOp that merges `input` values in `concat_dimension`. +mlir::LogicalResult CreateConcatOp(const int concat_dimension, + const mlir::Location& location, + mlir::ArrayRef inputs, + mlir::OpBuilder* builder, + mlir::TF::ConcatOp* concat_op) { + // Creates a const op to hold concat dimension value. + auto concat_dim_type = + mlir::RankedTensorType::get({}, builder->getIntegerType(32)); + auto concat_dimension_attr = + mlir::DenseElementsAttr::get(concat_dim_type, concat_dimension); + auto concat_dimension_op = builder->create( + location, concat_dim_type, concat_dimension_attr); + + // Correctly set output shapes of concat op output if output shape is + // statically known. Since the shape of TPUExecute op must be the same + // across logical devices, we refer to the shape of 0th logical device + // computation output. + mlir::Type output_type; + auto input_type = inputs[0].getType().cast(); + + if (input_type.hasRank()) { + if (input_type.getShape()[concat_dimension] == + mlir::ShapedType::kDynamicSize) { + output_type = input_type; + } else { + auto shape = llvm::to_vector<4>(input_type.getShape()); + shape[concat_dimension] = shape[concat_dimension] * inputs.size(); + output_type = + mlir::RankedTensorType::get(shape, input_type.getElementType()); + } + } else { + output_type = input_type; + } + + *concat_op = builder->create( + location, output_type, concat_dimension_op.output(), inputs); + return mlir::success(); +} + // For tile sharded inputs to TPU computation, inject split op between the // input values and TPU computation so that tiled input values are passed in // as inputs to TPU computations. If more than one dimension is sharded, then @@ -167,12 +209,12 @@ mlir::LogicalResult HandleTileShardedInputs( } // namespace mlir::LogicalResult ExtractInputsForLogicalDevices( - int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func, + const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, mlir::OpBuilder* builder, llvm::SmallVectorImpl>* input_list) { // Initialize the input list for each logical devices. - input_list->reserve(num_logical_cores); - for (int i = 0; i < num_logical_cores; ++i) + input_list->reserve(num_cores_per_replica); + for (int i = 0; i < num_cores_per_replica; ++i) input_list->emplace_back(llvm::SmallVector()); llvm::SmallVector launch_func_inputs( @@ -207,12 +249,12 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( launch_func.getLoc(), sharding, input_value, builder, &tiled_inputs); if (mlir::failed(result)) return mlir::failure(); - if (tiled_inputs.size() != num_logical_cores) + if (tiled_inputs.size() != num_cores_per_replica) launch_func.emitError(llvm::formatv( "incorrect {0}-th tiled input sharding received. 
" "Product of tile sharding splits({1}) must be equal to " "number of logical devices : {2}", - input_index, tiled_inputs.size(), num_logical_cores)); + input_index, tiled_inputs.size(), num_cores_per_replica)); for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) { const int assigned_logical_device = sharding.tile_assignment_devices(i); @@ -230,7 +272,7 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( } mlir::LogicalResult ParseAndValidateOutputSharding( - mlir::tf_device::LaunchFuncOp launch_func, + const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, mlir::SmallVector* output_sharding_list) { output_sharding_list->reserve(launch_func.getNumResults()); @@ -257,10 +299,22 @@ mlir::LogicalResult ParseAndValidateOutputSharding( output_sharding.cast().getValue().str())) return launch_func.emitError("incorrect sharding format for outputs"); - const auto output_sharing_type = sharding.type(); - if (output_sharing_type == xla::OpSharding::OTHER) - return launch_func.emitError( - "tiled outputs are not yet supported for model parallelism"); + if (sharding.type() == xla::OpSharding::OTHER && + sharding.tile_assignment_devices_size() != num_cores_per_replica) + return launch_func.emitError(llvm::formatv( + "incorrect sharding format for outputs. Number of " + "tiled outputs({0}) must match the number of logical " + "devices({1})", + sharding.tile_assignment_devices_size(), num_cores_per_replica)); + + if (sharding.type() == xla::OpSharding::MAXIMAL && + ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) || + (sharding.tile_assignment_devices(0) < 0))) + return launch_func.emitError(llvm::formatv( + "incorrect sharding format for outputs. Maximal " + "sharding should be assigned to device id in range " + "[0, {0}). Currently assigned to {1}", + num_cores_per_replica, sharding.tile_assignment_devices(0))); output_sharding_list->emplace_back(std::move(sharding)); } @@ -287,7 +341,7 @@ int MapLaunchOutputIndexWithRegionOutputIndex( for (int output_index = 0; output_index < launch_func_output_index; ++output_index) { const auto& sharding = output_sharding_config[output_index]; - if (sharding.type() == xla::OpSharding::REPLICATED || + if (sharding.type() != xla::OpSharding::MAXIMAL || IsAssignedToLogicalDevice(core_id, sharding)) region_output_index++; } @@ -295,41 +349,158 @@ int MapLaunchOutputIndexWithRegionOutputIndex( return region_output_index; } +// Merges outputs from TPU computation for tile-sharded outputs. +mlir::LogicalResult HandleTileShardedOutputs( + const int launch_func_output_index, const xla::OpSharding& sharding, + const mlir::Location& location, mlir::Value launch_func_output, + mlir::tf_device::ParallelExecuteOp parallel_execute, + mlir::OpBuilder* builder) { + // Inject concat ops after parallel_execute to merge outputs from + // concurrently executed computations. + builder->setInsertionPointAfter(parallel_execute); + + // Reorders outputs from TPUExecute op as defined by the output sharding + // configuration. 
+ llvm::SmallVector outputs_to_merge; + outputs_to_merge.reserve(sharding.tile_assignment_devices_size()); + for (const auto logical_device_id : sharding.tile_assignment_devices()) { + const int region_output_index = MapLaunchOutputIndexWithRegionOutputIndex( + sharding, logical_device_id, launch_func_output_index); + const auto output_from_logical_device = parallel_execute.GetRegionOutputs( + logical_device_id)[region_output_index]; + outputs_to_merge.emplace_back(output_from_logical_device); + } + + // Creates a tree of Concat ops that merges outputs from multiple logical + // devices to a single replica output. + int concat_dimension = sharding.tile_assignment_dimensions_size() - 1; + for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) { + if (num_splits == 1) { + --concat_dimension; + continue; + } + + llvm::SmallVector new_outputs; + new_outputs.reserve(num_splits); + for (int i = 0; i < outputs_to_merge.size(); i = i + num_splits) { + mlir::TF::ConcatOp concat_op; + auto result = + CreateConcatOp(concat_dimension, location, + llvm::ArrayRef{ + outputs_to_merge.begin() + i, + outputs_to_merge.begin() + i + num_splits}, + builder, &concat_op); + if (mlir::failed(result)) return mlir::failure(); + + new_outputs.emplace_back(concat_op.getResult()); + } + + std::swap(new_outputs, outputs_to_merge); + --concat_dimension; + } + + assert(outputs_to_merge.size() == 1); + launch_func_output.replaceAllUsesWith(outputs_to_merge[0]); + return mlir::success(); +} + +mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( + const mlir::Location& location, + const mlir::TensorType launch_func_output_type, + const xla::OpSharding& output_sharding, + mlir::Type* tiled_logical_computation_type) { + auto new_output_shape = + llvm::to_vector<4>(launch_func_output_type.getShape()); + for (auto dimension_and_output_splits : + llvm::enumerate(output_sharding.tile_assignment_dimensions())) { + const auto dimension_index = dimension_and_output_splits.index(); + const auto output_splits = dimension_and_output_splits.value(); + const auto& output_shape = launch_func_output_type.getShape(); + + if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) { + *tiled_logical_computation_type = launch_func_output_type; + break; + } + + auto output_shape_at_dim = + launch_func_output_type.getShape()[dimension_index]; + if (output_shape_at_dim % output_splits != 0) { + mlir::emitError( + location, + llvm::formatv("incorrect output sharding received. 
" + "{0}-th dimension of the output must be " + "evenly divisible by {1}, got dimension " + "shape {2}", + dimension_index, output_splits, output_shape_at_dim)); + } + + new_output_shape[dimension_index] = + output_shape[dimension_index] / output_splits; + } + + *tiled_logical_computation_type = mlir::RankedTensorType::get( + new_output_shape, launch_func_output_type.getElementType()); + + return mlir::success(); +} + } // namespace -mlir::SmallVector GetOutputTypesForLogicalDeviceComputation( - const int logical_device_id, - llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func) { - mlir::SmallVector output_types; - output_types.reserve(launch_func.getNumResults()); +mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( + const int core_id, llvm::ArrayRef output_sharding_config, + mlir::tf_device::LaunchFuncOp launch_func, + llvm::SmallVectorImpl* output_types) { + output_types->reserve(launch_func.getNumResults()); for (auto result_and_index : llvm::enumerate(launch_func.getResults())) { const auto output_index = result_and_index.index(); const auto& output_sharding = output_sharding_config[output_index]; const auto output_sharding_type = output_sharding.type(); - const auto& launch_func_output = result_and_index.value(); + const auto& launch_func_output_type = + result_and_index.value().getType().cast(); - if (output_sharding_type == xla::OpSharding::REPLICATED || - IsAssignedToLogicalDevice(logical_device_id, output_sharding)) - output_types.emplace_back(launch_func_output.getType()); + // If output shape of launch func is statically known and output is tiled + // sharded, then the corresponding output shape of launch func must be + // evenly divisible number of shardings. + if (output_sharding_type == xla::OpSharding::OTHER) { + mlir::Type tiled_logical_computation_type; + if (launch_func_output_type.hasRank()) { + auto result = ValidateAndGetTiledExecuteOutputShape( + launch_func.getLoc(), launch_func_output_type, output_sharding, + &tiled_logical_computation_type); + if (mlir::failed(result)) return mlir::failure(); + } else { + tiled_logical_computation_type = launch_func_output_type; + } + output_types->emplace_back(tiled_logical_computation_type); + } else if (output_sharding_type == xla::OpSharding::REPLICATED || + IsAssignedToLogicalDevice(core_id, output_sharding)) { + output_types->emplace_back(launch_func_output_type); + } } - return output_types; + return mlir::success(); } void RemapOutputsFromLogicalDevices( + const mlir::Location& location, llvm::ArrayRef output_sharding_config, mlir::tf_device::LaunchFuncOp launch_func, - mlir::tf_device::ParallelExecuteOp parallel_execute) { + mlir::tf_device::ParallelExecuteOp parallel_execute, + mlir::OpBuilder* builder) { for (auto result_and_index : llvm::enumerate(launch_func.getResults())) { const auto output_index = result_and_index.index(); const auto& launch_func_output = result_and_index.value(); const auto& output_sharding = output_sharding_config[output_index]; - const auto output_sharing_type = output_sharding.type(); + const auto output_sharding_type = output_sharding.type(); + if (output_sharding_type == xla::OpSharding::OTHER) { + HandleTileShardedOutputs(output_index, output_sharding, location, + launch_func_output, parallel_execute, builder); + continue; + } int logical_device_id = 0; - if (output_sharing_type == xla::OpSharding::MAXIMAL) + if (output_sharding_type == xla::OpSharding::MAXIMAL) logical_device_id = output_sharding.tile_assignment_devices(0); // For maximal sharding 
configuration, correctly remap outputs from @@ -339,7 +510,6 @@ void RemapOutputsFromLogicalDevices( const auto output_from_logical_device = parallel_execute.GetRegionOutputs( logical_device_id)[region_output_index]; - launch_func_output.replaceAllUsesWith(output_from_logical_device); } } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 1df4e1fbc37..2320bd44815 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -40,30 +41,32 @@ llvm::Optional ParseShardingAttribute( // TPU computation correponding to i-th logical device. If the attribute // does not exist, the all inputs are placed on logical core 0. mlir::LogicalResult ExtractInputsForLogicalDevices( - int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func, + const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, mlir::OpBuilder* builder, llvm::SmallVectorImpl>* input_list); // Extracts a list of OpSharding that represent output sharding configuration // of `tf_device.launch`. mlir::LogicalResult ParseAndValidateOutputSharding( - mlir::tf_device::LaunchFuncOp launch_func, + const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, mlir::SmallVector* output_sharding_list); // Retrieves output types for TPUExecute op representing execution for provided // logical device id. TPUExecute op for different logical device may have // different outputs depending on the output sharding configuration. -mlir::SmallVector GetOutputTypesForLogicalDeviceComputation( - const int logical_device_id, - llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func); +mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( + const int core_id, llvm::ArrayRef output_sharding_config, + mlir::tf_device::LaunchFuncOp launch_func, + llvm::SmallVectorImpl* output_types); // Remaps outputs of `tf_device.parallel_execute` op that represent concurrent // execution of the `tf_device.launch_func` with its users. void RemapOutputsFromLogicalDevices( + const mlir::Location& location, llvm::ArrayRef output_sharding_config, mlir::tf_device::LaunchFuncOp launch_func, - mlir::tf_device::ParallelExecuteOp parallel_execute); + mlir::tf_device::ParallelExecuteOp parallel_execute, + mlir::OpBuilder* builder); } // namespace tensorflow From 8a53a3082b0b17da477da5807a926391ec37b721 Mon Sep 17 00:00:00 2001 From: Yujing Zhang Date: Mon, 23 Mar 2020 17:30:20 -0700 Subject: [PATCH 468/492] Support remote inputs passed as Tensors to EagerClusterFunctionLibraryRuntime::Run. 
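
With this change, a remote component function can receive inputs that
live on the client as plain Tensors, instead of requiring every input
to already be a RemoteTensorHandle. A rough usage sketch of what this
enables is below; it mirrors the testClientVarible case added to
remote_test.py in this patch, uses the public tf.* API rather than the
internal test helpers, and assumes a worker job has already been
configured (the job/task names are placeholders):

    import tensorflow as tf

    v = tf.Variable(0)  # lives on the client

    @tf.function
    def func():
      with tf.device('/job:localhost/task:0'):
        read = v.read_value()  # client-side value is passed as a Tensor input
      return read + 1

    with tf.device('/job:worker/task:0'):
      result = func()  # previously this kind of call could fail with Unimplemented

    # result == 1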
PiperOrigin-RevId: 302555474 Change-Id: I16d52b54748e255c2bf5ebcf522cb7a4eed639f0 --- tensorflow/core/BUILD | 1 + tensorflow/core/common_runtime/eager/BUILD | 1 + .../core/common_runtime/eager/execute.cc | 2 +- .../eager/process_function_library_runtime.cc | 24 +++------ .../eager/process_function_library_runtime.h | 4 +- .../process_function_library_runtime.cc | 43 ++++++++++----- .../process_function_library_runtime.h | 28 +++------- .../core/distributed_runtime/eager/BUILD | 1 + .../eager/cluster_function_library_runtime.cc | 40 ++++++++------ .../eager/cluster_function_library_runtime.h | 2 +- .../eager/eager_service_impl_test.cc | 54 +++++++++++++++++-- .../eager/remote_copy_node.cc | 3 +- tensorflow/core/framework/function.h | 12 ++++- tensorflow/core/protobuf/eager_service.proto | 1 + tensorflow/python/eager/remote_test.py | 12 +++++ 15 files changed, 153 insertions(+), 75 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c1b889751d7..1bff24cc8e3 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2674,6 +2674,7 @@ tf_cuda_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", "//third_party/eigen3", "//tensorflow/core/public:version", "//tensorflow/core/grappler/utils:functions", diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 6d7b00fa64e..01d8c48a192 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -289,6 +289,7 @@ tf_cuda_library( ], "//conditions:default": [ "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index baaddec74e1..3d4cf6ae8fc 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -736,7 +736,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, ctx.GetContextViewId())); } } - auto* input_handle = remote_op->add_inputs(); + auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle(); TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle( input, input_handle, input_device, *input_device_name, serialize_resource_dtype_and_shape)); diff --git a/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc b/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc index ca545a9e890..2051a23f14b 100644 --- a/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc @@ -28,8 +28,8 @@ namespace eager { #if !defined(IS_MOBILE_PLATFORM) void EagerProcessFunctionLibraryRuntime::RunRemoteDevice( const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle local_handle, const InternalArgsView& args, - std::vector* rets, + FunctionLibraryRuntime::Handle local_handle, + gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const { if (!rets->empty()) { done( @@ -37,19 +37,7 @@ void EagerProcessFunctionLibraryRuntime::RunRemoteDevice( "EagerClusterFunctionLibraryRuntime yet.")); return; } - if (!args.local_args.empty()) { - done( - errors::Unimplemented("Local inputs are not by supported by " - 
"EagerClusterFunctionLibraryRuntime.")); - return; - } - if (args.remote_args == nullptr) { - done( - errors::Internal("EagerClusterFunctionLibraryRuntime: remote_args " - "should never be null.")); - return; - } - parent_->Run(opts, local_handle, args.remote_args, std::move(done)); + parent_->Run(opts, local_handle, args, rets, std::move(done)); } void EagerProcessFunctionLibraryRuntime::Run( @@ -71,11 +59,13 @@ void EagerProcessFunctionLibraryRuntime::Run( const int index = comp_data.arg_indices_.at(i); Tensor tensor; if (args.GetLocalArg(index, &tensor).ok()) { - comp_args->local_args.push_back(std::move(tensor)); + comp_args->args.push_back(std::move(tensor)); } else { RemoteTensorHandle remote_handle; TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle)); - comp_args->remote_args.push_back(std::move(remote_handle)); + comp_args->remote_args.emplace_back( + absl::make_unique(std::move(remote_handle))); + comp_args->args.push_back(comp_args->remote_args.back().get()); } } return Status::OK(); diff --git a/tensorflow/core/common_runtime/eager/process_function_library_runtime.h b/tensorflow/core/common_runtime/eager/process_function_library_runtime.h index a73ec63f29b..1e726f5c561 100644 --- a/tensorflow/core/common_runtime/eager/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/eager/process_function_library_runtime.h @@ -50,8 +50,8 @@ class EagerProcessFunctionLibraryRuntime private: void RunRemoteDevice( const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle local_handle, const InternalArgsView& args, - std::vector* rets, + FunctionLibraryRuntime::Handle local_handle, + gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const override; #endif // IS_MOBILE_PLATFORM }; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 73d83fbbd5e..9f9924b6ff2 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -361,6 +361,17 @@ Status SetArgShape( return Status::OK(); } +// Returns the local tensors referred by `args`. 
+std::vector GetLocalArgs(gtl::ArraySlice args) { + std::vector tensors; + for (const auto& arg : args) { + if (arg.index() == 0) { + tensors.push_back(absl::get(arg)); + } + } + return tensors; +} + } // anonymous namespace Status ProcessFunctionLibraryRuntime::PinArgsAndRets( @@ -958,10 +969,10 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices( void ProcessFunctionLibraryRuntime::RunRemoteDevice( const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle local_handle, const InternalArgsView& args, - std::vector* rets, + FunctionLibraryRuntime::Handle local_handle, + gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const { - parent_->Run(opts, local_handle, args.local_args, rets, std::move(done)); + parent_->Run(opts, local_handle, GetLocalArgs(args), rets, std::move(done)); } void ProcessFunctionLibraryRuntime::RunMultiDevice( @@ -1047,7 +1058,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( << " with handle " << handle; VLOG(4) << " with " << opts_copy.DebugString(); - flr->Run(opts_copy, handle, comp_args.local_args, comp_rets, + flr->Run(opts_copy, handle, GetLocalArgs(comp_args.args), comp_rets, [comp_rets, rets, comp_data, refcounted_done, data](const Status& status) { if (!status.ok()) { @@ -1072,9 +1083,9 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( VLOG(1) << "Running component function on device " << target << " with handle " << handle; VLOG(4) << " with " << opts_copy.DebugString(); - InternalArgsView comp_args_view(&comp_args); + RunInternal( - opts_copy, handle, comp_args_view, comp_rets, cleanup_items, + opts_copy, handle, comp_args.args, comp_rets, cleanup_items, [comp_rets, rets, comp_data, refcounted_done](const Status& status) { if (!status.ok()) { VLOG(2) << "Component function execution failed: " << status; @@ -1299,20 +1310,26 @@ void ProcessFunctionLibraryRuntime::Run( if (multi_device) { auto get_component_args = [&args](const ComponentFunctionData& comp_data, InternalArgs* comp_args) -> Status { - comp_args->local_args = GetArgsForIndices(comp_data.arg_indices_, args); + for (const auto& tensor : + GetArgsForIndices(comp_data.arg_indices_, args)) { + comp_args->args.push_back(tensor); + } return Status::OK(); }; return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done), std::move(get_component_args)); } - InternalArgsView internal_args(args); - RunInternal(new_opts, handle, internal_args, rets, cleanup_items, + std::vector local_args; + for (const auto& tensor : args) { + local_args.push_back(tensor); + } + RunInternal(new_opts, handle, local_args, rets, cleanup_items, std::move(done)); } void ProcessFunctionLibraryRuntime::RunInternal( const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, const InternalArgsView& args, + FunctionLibraryRuntime::Handle handle, gtl::ArraySlice args, std::vector* rets, std::vector>* cleanup_items, FunctionLibraryRuntime::DoneCallback done) const { @@ -1357,9 +1374,11 @@ void ProcessFunctionLibraryRuntime::RunInternal( return; } + std::vector local_args = GetLocalArgs(args); + // Send the args over to the target device. 
s = SendTensors(source_device, target_device, "arg_", src_incarnation, - args.local_args, device_context, opts.args_alloc_attrs, + local_args, device_context, opts.args_alloc_attrs, rendezvous); if (!s.ok()) { done(s); @@ -1368,7 +1387,7 @@ void ProcessFunctionLibraryRuntime::RunInternal( const std::vector& rets_alloc_attrs = opts.rets_alloc_attrs; std::vector* remote_rets = new std::vector; - flr->Run(opts, handle, args.local_args, remote_rets, + flr->Run(opts, handle, local_args, remote_rets, [source_device, target_device, target_incarnation, rendezvous, device_context, rets_alloc_attrs, remote_rets, rets, done = std::move(done)](const Status& status) mutable { diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 80b52904235..545615a1bea 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -23,6 +23,7 @@ limitations under the License. // clang-format on #include "absl/types/optional.h" +#include "absl/types/variant.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/function.h" @@ -202,27 +203,10 @@ class ProcessFunctionLibraryRuntime { friend class FunctionLibraryRuntimeImpl; struct InternalArgs { - std::vector local_args; + std::vector args; #if !defined(IS_MOBILE_PLATFORM) - std::vector remote_args; -#endif // IS_MOBILE_PLATFORM - }; - - struct InternalArgsView { - public: - explicit InternalArgsView(gtl::ArraySlice tensors) - : local_args(tensors) {} - - explicit InternalArgsView(InternalArgs* args) - : local_args(args->local_args) { -#if !defined(IS_MOBILE_PLATFORM) - remote_args = &args->remote_args; -#endif // IS_MOBILE_PLATFORM - } - - gtl::ArraySlice local_args; -#if !defined(IS_MOBILE_PLATFORM) - std::vector* remote_args = nullptr; + // Holds the RemoteTensorHandles referred by args. 
+ std::vector> remote_args; #endif // IS_MOBILE_PLATFORM }; @@ -290,7 +274,7 @@ class ProcessFunctionLibraryRuntime { virtual void RunRemoteDevice(const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle local_handle, - const InternalArgsView& args, + gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const; @@ -382,7 +366,7 @@ class ProcessFunctionLibraryRuntime { void RunInternal(const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, - const InternalArgsView& args, std::vector* rets, + gtl::ArraySlice args, std::vector* rets, std::vector>* cleanup_items, FunctionLibraryRuntime::DoneCallback done) const; diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 49d657e3bf8..a7758fbcb7e 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -43,6 +43,7 @@ cc_library( "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/distributed_runtime:worker_session", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], ) diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index 4232d5223e5..dfa35086659 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -120,25 +120,29 @@ void EagerClusterFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { - if (args.empty() && rets->empty()) { FunctionLibraryRuntime::Options opts_copy = opts; - opts_copy.op_id = ctx_->RemoteMgr()->NextOpId(); - Run(opts_copy, handle, /*args=*/nullptr, std::move(done)); - } else { - // TODO(b/150963957): Support remote inputs and outputs which are passed as - // Tensors. - done(errors::Unimplemented( - "Not implemented. Users could set the input devices and output devices " - "in FunctionLibraryRuntime::Options to the default multi-device " - "function device as a workaround.")); - } + if (!opts_copy.op_id.has_value()) { + opts_copy.op_id = ctx_->RemoteMgr()->NextOpId(); + } + std::vector function_args; + for (const auto& tensor : args) { + function_args.push_back(tensor); + } + Run(opts_copy, handle, function_args, rets, std::move(done)); } void EagerClusterFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::LocalHandle handle, - std::vector* args, + gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { + if (!rets->empty()) { + // TODO(b/150963957): Support remote outputs which are passed as Tensors. + done(errors::Unimplemented( + "Not implemented. 
Users could set the output devices in " + "FunctionLibraryRuntime::Options to the default multi-device " + "function device as a workaround.")); + } FunctionData* function_data = nullptr; { mutex_lock l(mu_); @@ -169,11 +173,17 @@ void EagerClusterFunctionLibraryRuntime::Run( eager::EnqueueRequest* request = new eager::EnqueueRequest; request->set_context_id(context_id_); eager::Operation* remote_op = request->add_queue()->mutable_operation(); - if (args) { - for (size_t i = 0; i < args->size(); ++i) { - remote_op->add_inputs()->Swap(&(*args)[i]); + + for (const auto& arg : args) { + if (arg.index() == 0) { + absl::get(arg).AsProtoTensorContent( + remote_op->add_op_inputs()->mutable_tensor()); + } else { + remote_op->add_op_inputs()->mutable_remote_handle()->Swap( + absl::get(arg)); } } + // The remote component function should use the same op_id as its parent // multi-device function's in order to get the global unique op_id generated // by the master context. diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h index da6d5111bcd..bb94093946c 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h @@ -55,7 +55,7 @@ class EagerClusterFunctionLibraryRuntime void Run(const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::LocalHandle handle, - std::vector* args, + gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) override; void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 2d7ee1143c6..73bc42be0c5 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -589,10 +589,8 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { uint64 context_id_; tensorflow::FunctionDef fdef_; std::unique_ptr eager_pflr_; - - private: - FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}}; std::unique_ptr eager_cluster_flr_; + FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}}; }; // Test executes a remote function through @@ -651,6 +649,56 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) { CheckOutputsAndClose(op_id); } +// Test executes a remote function with a local tensor input. +TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) { + Init(); + // Instantiate MatMulFunction on remote_device. 
+ FunctionLibraryRuntime::Handle handle; + EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_)); + Status status; + Notification instantiate_done; + eager_cluster_flr_->Instantiate( + fdef_.signature().name(), func_lib_def_, AttrSlice(&fdef_.attr()), + FunctionLibraryRuntime::InstantiateOptions(), &handle, + [&status, &instantiate_done](const Status& s) { + status = s; + instantiate_done.Notify(); + }); + instantiate_done.WaitForNotification(); + TF_ASSERT_OK(status); + EagerContext* ctx = nullptr; + TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx)); + for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) { + const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name); + EXPECT_TRUE(fdef != nullptr); + if (absl::StartsWith(func_name, "MatMulFunction")) { + EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef)); + } + } + const tensorflow::Tensor* input_tensor = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl_.GetTensorHandle( + context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor)); + + // Send input_tensor to the remote device and execute MatMulFunction on the + // remote device. + FunctionLibraryRuntime::Options opts; + const uint64 op_id = 2; + opts.op_id = op_id; + Notification execute_done; + std::vector inputs = {*input_tensor}; + std::vector outputs; + eager_cluster_flr_->Run(opts, handle, inputs, &outputs, + [&status, &execute_done](const Status& s) { + status = s; + execute_done.Notify(); + }); + execute_done.WaitForNotification(); + TF_ASSERT_OK(status); + CheckOutputsAndClose(op_id); +} + // Test executes a remote function through KernelAndDeviceFunc. TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { Init(); diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index 32a88e0ba8d..8399d57660d 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -145,7 +145,8 @@ void RemoteCopyNode::StartSend() { request.set_context_id(ctx_->GetContextId()); auto* remote_op = request.add_queue()->mutable_operation(); status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle( - src_, remote_op->add_inputs(), absl::get(src_->device()), + src_, remote_op->add_op_inputs()->mutable_remote_handle(), + absl::get(src_->device()), absl::get(src_->DeviceOrHostCPU(*ctx_))->name()); if (!status.ok()) { captured_state_->SetSendStatus(status); diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 640f00a1352..39fffabc774 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -24,6 +24,7 @@ limitations under the License. // clang-format on #include "absl/types/optional.h" +#include "absl/types/variant.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function.pb.h" @@ -838,6 +839,15 @@ class CustomKernelCreator { std::unique_ptr* kernel) const = 0; }; +typedef +#if !defined(IS_MOBILE_PLATFORM) + absl::variant + FunctionArg; +#else + absl::variant + FunctionArg; +#endif + // Used to instantiate and run functions in a distributed system. class DistributedFunctionLibraryRuntime { public: @@ -861,7 +871,7 @@ class DistributedFunctionLibraryRuntime { // TODO(yujingzhang): Support outputting tensors on remote devices. 
virtual void Run(const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::LocalHandle handle, - std::vector* args, + gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { done(errors::Unimplemented("Unimplemented.")); } diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index c2a7553306e..6f2913eae90 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -23,6 +23,7 @@ message Operation { // future. int64 id = 1; string name = 2; + // TODO(b/150963957): Deprecate this. repeated RemoteTensorHandle inputs = 3; message Input { diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index a210ae0419a..2b10f4e520f 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -156,6 +156,18 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase): else: self.assertIn('Dimensions must be equal', cm.exception.args[0]) + def testClientVarible(self): + var = variables.Variable(initial_value=0) + + @def_function.function + def func(): + with ops.device('/job:localhost/task:0'): + read = var.read_value() + return read + 1 + + with ops.device('/job:worker/task:0'): + self.assertAllEqual(func(), 1) + class RemoteAsyncTest(test.TestCase): From b80ac685de0f550edb0e507a10030f6d5d3bea85 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 23 Mar 2020 17:30:48 -0700 Subject: [PATCH 469/492] Remove the mlir dependences from the calibration library PiperOrigin-RevId: 302555533 Change-Id: I2e783f2e71bd6489972bce625ff0b87543f0ae1c --- tensorflow/lite/python/lite.py | 2 +- tensorflow/lite/python/optimize/BUILD | 1 - .../python/optimize/calibration_wrapper.cc | 18 ++------ .../python/optimize/calibration_wrapper.h | 2 +- .../optimize/calibration_wrapper_pybind11.cc | 10 ++--- tensorflow/lite/python/optimize/calibrator.py | 7 +-- .../lite/python/optimize/calibrator_test.py | 43 +++++-------------- 7 files changed, 23 insertions(+), 60 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index ca9290e0197..151aecf02cb 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -319,7 +319,7 @@ class TFLiteConverterBase(object): else: return calibrate_quantize.calibrate_and_quantize( self.representative_dataset.input_gen, inference_input_type, - inference_output_type, allow_float, self.experimental_new_quantizer) + inference_output_type, allow_float) def _is_unknown_shapes_allowed(self, fp32_execution): # TODO(b/128319310): Investigate which quantization methods work. diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD index ba75dca9362..53ebba2fcb2 100644 --- a/tensorflow/lite/python/optimize/BUILD +++ b/tensorflow/lite/python/optimize/BUILD @@ -10,7 +10,6 @@ cc_library( srcs = ["calibration_wrapper.cc"], hdrs = ["calibration_wrapper.h"], deps = [ - "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model", "//tensorflow/lite:framework", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index d41ecf79c1a..9a5d1e9aa2f 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include #include "absl/memory/memory.h" -#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -213,8 +212,7 @@ PyObject* CalibrationWrapper::Calibrate() { PyObject* CalibrationWrapper::QuantizeModel(int input_py_type, int output_py_type, - bool allow_float, - bool enable_mlir_quantizer) { + bool allow_float) { if (NoOpModel(*model_)) { return python_utils::ConvertToPyString(model_str_->data(), model_str_->size()); @@ -231,17 +229,9 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type, reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false); flatbuffers::FlatBufferBuilder builder; auto status = kTfLiteOk; - if (enable_mlir_quantizer) { - status = mlir::lite::QuantizeModel( - *tflite_model, TfLiteTypeToSchemaType(input_type), - TfLiteTypeToSchemaType(output_type), {}, allow_float, &builder, - error_reporter_.get()); - } else { - status = tflite::optimize::QuantizeModel( - &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type), - TfLiteTypeToSchemaType(output_type), allow_float, - error_reporter_.get()); - } + status = tflite::optimize::QuantizeModel( + &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type), + TfLiteTypeToSchemaType(output_type), allow_float, error_reporter_.get()); if (status != kTfLiteOk) { error_reporter_->exception(); diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h b/tensorflow/lite/python/optimize/calibration_wrapper.h index 7b5ae50e657..449f8ee6b83 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.h +++ b/tensorflow/lite/python/optimize/calibration_wrapper.h @@ -61,7 +61,7 @@ class CalibrationWrapper { PyObject* FeedTensor(PyObject* input_value); PyObject* QuantizeModel(int input_py_type, int output_py_type, - bool allow_float, bool enable_mlir_quantizer = false); + bool allow_float); // Allows quantizing only the operator that produces the tensor with name // operator_output_name. (This can be used to help debug.). 
diff --git a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc index f56b23090b9..dcecd880a5e 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc @@ -41,16 +41,14 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) { .def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type, int output_py_type, bool allow_float, bool enable_mlir_quantizer) { - return tensorflow::pyo_or_throw( - self.QuantizeModel(input_py_type, output_py_type, allow_float, - enable_mlir_quantizer)); + return tensorflow::pyo_or_throw(self.QuantizeModel( + input_py_type, output_py_type, allow_float)); }) .def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type, int output_py_type, bool allow_float) { - return tensorflow::pyo_or_throw( - self.QuantizeModel(input_py_type, output_py_type, allow_float, - /*enable_mlir_quantizer=*/false)); + return tensorflow::pyo_or_throw(self.QuantizeModel( + input_py_type, output_py_type, allow_float)); }) .def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type, int output_py_type, bool allow_float, diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py index 999ae2ebf48..fb3b87fdaa7 100644 --- a/tensorflow/lite/python/optimize/calibrator.py +++ b/tensorflow/lite/python/optimize/calibrator.py @@ -55,7 +55,7 @@ class Calibrator(object): raise ValueError("Failed to parse the model.") def calibrate_and_quantize(self, dataset_gen, input_type, output_type, - allow_float, enable_mlir_quantizer=False): + allow_float): """Calibrates the model with specified generator and then quantizes it. Returns: @@ -69,16 +69,13 @@ class Calibrator(object): computation, useful when targeting an integer-only backend. If False, an error will be thrown if an operation cannot be quantized, otherwise the model will fallback to float ops. - enable_mlir_quantizer: A boolean. True if wants to use mlir quantizer to - quantize the calibrated model. 
""" self._calibrator.Prepare() for calibration_sample in dataset_gen(): self._calibrator.FeedTensor(calibration_sample) return self._calibrator.QuantizeModel( np.dtype(input_type.as_numpy_dtype()).num, - np.dtype(output_type.as_numpy_dtype()).num, allow_float, - enable_mlir_quantizer) + np.dtype(output_type.as_numpy_dtype()).num, allow_float) def calibrate_and_quantize_single(self, dataset_gen, input_type, output_type, allow_float, op_output_name): diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py index 28e8723f23d..34e93543f82 100644 --- a/tensorflow/lite/python/optimize/calibrator_test.py +++ b/tensorflow/lite/python/optimize/calibrator_test.py @@ -32,10 +32,7 @@ from tensorflow.python.platform import test class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): - @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_calibration_with_quantization(self, enable_mlir): + def test_calibration_with_quantization(self): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') float_model = open(model_path, 'rb').read() @@ -48,14 +45,10 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): quantized_model = quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, - enable_mlir) + constants.FLOAT, False) self.assertIsNotNone(quantized_model) - @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_calibration_with_quantization_allow_float(self, enable_mlir): + def test_calibration_with_quantization_allow_float(self): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') float_model = open(model_path, 'rb').read() @@ -68,8 +61,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): quantized_model = quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, True, - enable_mlir) + constants.FLOAT, True) self.assertIsNotNone(quantized_model) def test_calibration_with_quantization_single_op(self): @@ -87,10 +79,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): input_gen, constants.FLOAT, constants.FLOAT, True, 'conv2d_8/BiasAdd') self.assertIsNotNone(quantized_model) - @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_calibration_with_quantization_multiple_inputs(self, enable_mlir): + def test_calibration_with_quantization_multiple_inputs(self): # Load multi add model from test data. # This model has 4 inputs of size (1, 8, 8, 3). 
model_path = resource_loader.get_path_to_datafile( @@ -105,14 +94,10 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): quantized_model = quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, - enable_mlir) + constants.FLOAT, False) self.assertIsNotNone(quantized_model) - @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_invalid_model_buffer(self, enable_mlir): + def test_invalid_model_buffer(self): float_model = b'\0' * 100 with self.assertRaisesRegex(ValueError, 'Failed to parse the model'): _calibrator.Calibrator(float_model) @@ -132,10 +117,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT, constants.FLOAT, False) - @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_invalid_shape_calibrator_gen(self, enable_mlir): + def test_invalid_shape_calibrator_gen(self): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') float_model = open(model_path, 'rb').read() @@ -148,12 +130,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaisesRegex(ValueError, 'Size mismatch'): quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, enable_mlir) + constants.FLOAT, False) - @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_invalid_type_calibrator_gen(self, enable_mlir): + def test_invalid_type_calibrator_gen(self): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') float_model = open(model_path, 'rb').read() @@ -166,7 +145,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaises(ValueError): quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, enable_mlir) + constants.FLOAT, False) if __name__ == '__main__': From 0c6cad2a43913c8655056c0d855eed047eea2ba6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 17:36:13 -0700 Subject: [PATCH 470/492] Fix a potential deadlock in GetOrCreateSubStream/ReturnSubStream by not holding Stream::mu_ while destroying bad streams. PiperOrigin-RevId: 302556375 Change-Id: I5aec302fa407aa7ac916f240638196222f2a4b3c --- tensorflow/stream_executor/stream.cc | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index c4564a613e1..8b50eab838c 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/stream_executor/stream.h" -#include "tensorflow/stream_executor/platform/port.h" - #include "absl/strings/str_cat.h" #include "third_party/eigen3/Eigen/Core" #include "tensorflow/stream_executor/blas.h" @@ -24,6 +22,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/stacktrace.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/logging.h" +#include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/rng.h" #include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" @@ -1874,6 +1873,10 @@ Stream &Stream::ThenMemcpyH2DQuantized( } Stream *Stream::GetOrCreateSubStream() { + // Do not destroy bad streams when holding mu_ because ~Stream() may + // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_. + std::vector> bad_streams; + absl::MutexLock lock(&mu_); // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams @@ -1897,6 +1900,7 @@ Stream *Stream::GetOrCreateSubStream() { if (index != last) { std::swap(pair, sub_streams_[last]); } + bad_streams.push_back(std::move(sub_streams_.back().first)); sub_streams_.pop_back(); VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream " << sub_stream->DebugStreamPointers(); @@ -1921,6 +1925,10 @@ Stream *Stream::GetOrCreateSubStream() { } void Stream::ReturnSubStream(Stream *sub_stream) { + // Do not destroy bad streams when holding mu_ because ~Stream() may + // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_. + std::unique_ptr bad_stream; + absl::MutexLock lock(&mu_); // Look for the sub-stream. @@ -1945,6 +1953,7 @@ void Stream::ReturnSubStream(Stream *sub_stream) { if (index != last) { std::swap(pair, sub_streams_[last]); } + std::swap(bad_stream, sub_streams_.back().first); sub_streams_.pop_back(); } return; From e2a223c5e52f994572358a23da6bf21937ba97c1 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Mon, 23 Mar 2020 17:46:33 -0700 Subject: [PATCH 471/492] Always output the benchmark result by registering the benchmark logging listener in the BenchmarkTfLiteModel constructor. PiperOrigin-RevId: 302558138 Change-Id: I2a718dd79bf898353de936afb5b308dd8b81f207 --- tensorflow/lite/tools/benchmark/benchmark_model.cc | 12 +++--------- .../lite/tools/benchmark/benchmark_tflite_model.cc | 4 +++- .../lite/tools/benchmark/benchmark_tflite_model.h | 3 ++- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc index 854b777dccc..bce683edcf8 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc @@ -205,15 +205,9 @@ TfLiteStatus BenchmarkModel::Run() { const auto overall_mem_usage = profiling::memory::GetMemoryUsage() - start_mem_usage; - const BenchmarkResults final_results( - model_size_mb, startup_latency_us, input_bytes, warmup_time_us, - inference_time_us, init_mem_usage, overall_mem_usage); - listeners_.OnBenchmarkEnd(final_results); - - // We always TFLITE_LOG the benchmark result regardless whether a - // BenchmarkListener is registered or not. 
- BenchmarkLoggingListener log_output; - log_output.OnBenchmarkEnd(final_results); + listeners_.OnBenchmarkEnd({model_size_mb, startup_latency_us, input_bytes, + warmup_time_us, inference_time_us, init_mem_usage, + overall_mem_usage}); return status; } diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 825879693f3..47ec9f4af0b 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -284,7 +284,9 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() { BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params) : BenchmarkModel(std::move(params)), - random_engine_(std::random_device()()) {} + random_engine_(std::random_device()()) { + AddListener(&log_output_); +} void BenchmarkTfLiteModel::CleanUp() { // Free up any pre-allocated tensor data during PrepareInputData. diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h index 8e9bad2269a..b56390b3775 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h @@ -119,8 +119,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel { std::unique_ptr profiling_listener_ = nullptr; std::unique_ptr ruy_profiling_listener_ = nullptr; std::mt19937 random_engine_; - std::vector owned_delegates_; + // Always TFLITE_LOG the benchmark result. + BenchmarkLoggingListener log_output_; }; } // namespace benchmark From 11597546307caa54416470f4afea6eb741a8af55 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Mon, 23 Mar 2020 17:48:03 -0700 Subject: [PATCH 472/492] Mean optimized for Metal backend. 
PiperOrigin-RevId: 302558405 Change-Id: Ic24bb8f405ca0ef8150082d8e6cd32f628cd8bbc --- .../lite/delegates/gpu/metal/kernels/mean.cc | 130 +++++++++++------- 1 file changed, 79 insertions(+), 51 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc b/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc index 20ad71eb123..b4e06fb8c0f 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc @@ -37,44 +37,69 @@ namespace tflite { namespace gpu { namespace metal { -std::string GetMeanCode() { - std::string shader_source = R"( +std::string GetMeanCode(const int3& work_group_size) { + const std::string wg_x = std::to_string(work_group_size.x); + const std::string wg_y = std::to_string(work_group_size.y); + std::string c = R"( #include using namespace metal; struct uniforms { int4 src_size; - int4 dst_size; + float4 inv_multipliers; }; $0 kernel void ComputeFunction( $1 + uint tid[[thread_index_in_threadgroup]], + uint3 tid3d[[thread_position_in_threadgroup]], uint3 gid[[thread_position_in_grid]]) { - if (static_cast(gid.x) >= params.dst_size.x || - static_cast(gid.y) >= params.dst_size.y || - static_cast(gid.z) >= params.dst_size.z) { - return; - } - - float4 sum = float4(0.0); - float size = float( params.src_size.x * params.src_size.y); - for (int w = 0; w < params.src_size.x; w++) { - for (int h = 0; h < params.src_size.y; h++) { - const int buffer_index = - (gid.z * params.src_size.y + h) * params.src_size.x + w; - sum += float4(src_buffer[buffer_index]); - } - } - sum /= size; - const int linear_index = - (gid.z * params.dst_size.y + int(gid.y)) * params.dst_size.x + int(gid.x); - - FLT4 value = FLT4(sum); - $2 - output_buffer[linear_index] = value; - } - )"; - return shader_source; + int local_x = static_cast(tid3d.x); + int local_y = static_cast(tid3d.y); + int local_id = static_cast(tid); + int S = static_cast(gid.z); + if (S >= params.src_size.z) return; +)"; + c += " threadgroup float4 accum[" + + std::to_string(work_group_size.x * work_group_size.y) + "];\n"; + c += " accum[local_id] = float4(0.0f);\n"; + c += " int src_offset = S * params.src_size.x * params.src_size.y;\n"; + c += " for (int s_y = local_y; s_y < params.src_size.y; s_y += " + wg_y + + ") {\n"; + c += " for (int s_x = local_x; s_x < params.src_size.x; s_x += " + wg_x + + ") {\n"; + c += " int src_index = src_offset + s_y * params.src_size.x + s_x;\n"; + c += " accum[local_id] += float4(src_buffer[src_index]);\n"; + c += " }\n"; + c += " }\n"; + c += " accum[local_id] *= params.inv_multipliers.x;\n"; + c += " threadgroup_barrier(mem_flags::mem_threadgroup);\n"; + const int total_size = work_group_size.x * work_group_size.y; + int offset = 1; + int reminder = total_size / 4; + for (; reminder >= 8; reminder /= 4, offset *= 4) { + c += " if (local_id < " + std::to_string(reminder) + ") {\n"; + c += " int t = local_id * " + std::to_string(offset * 4) + ";\n"; + c += " float4 sum = accum[t + " + std::to_string(offset) + "];\n"; + c += " sum += accum[t + " + std::to_string(offset * 2) + "];\n"; + c += " sum += accum[t + " + std::to_string(offset * 3) + "];\n"; + c += " accum[t] += sum;\n"; + c += " }\n"; + c += " threadgroup_barrier(mem_flags::mem_threadgroup);\n"; + } + c += " float4 sum = accum[0];\n"; + reminder *= 4; + for (int i = 1; i < reminder; ++i) { + c += " sum += accum[" + std::to_string(offset * i) + "];\n"; + } + c += " FLT4 value = FLT4(sum * params.inv_multipliers.y);\n"; + c += R"( + const int linear_index = 
static_cast(gid.z); + $2 + dst_buffer[linear_index] = value; +} +)"; + return c; } std::vector Mean(int id, ValueId input_id, @@ -85,17 +110,19 @@ std::vector Mean(int id, ValueId input_id, return {}; } + const int3 work_group_size = int3(16, 16, 1); + auto desc = std::make_shared(); desc->id = id; desc->is_linkable = false; - std::string code = GetMeanCode(); + std::string code = GetMeanCode(work_group_size); desc->shader_source = code; desc->input_buffers = { {input_id, "device FLT4* const src_buffer"}, }; - desc->output_buffer = {output_id, "device FLT4* output_buffer", + desc->output_buffer = {output_id, "device FLT4* dst_buffer", [input_id](const std::map& buffers) { const auto& input_dimension = buffers.find(input_id)->second; @@ -103,31 +130,32 @@ std::vector Mean(int id, ValueId input_id, }}; desc->uniform_buffers = { {"constant uniforms& params", - [input_id, output_id](const std::map& buffers) { - const auto& dimension = buffers.find(input_id)->second; - const auto& output_dimension = buffers.find(output_id)->second; - std::vector uniform_params = { - dimension.w, - dimension.h, - IntegralDivideRoundUp(dimension.c, 4), - 0, - output_dimension.w, - output_dimension.h, - IntegralDivideRoundUp(dimension.c, 4), - 0}; - return GetByteBuffer(uniform_params); + [input_id, output_id, + work_group_size](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + const int src_slices = IntegralDivideRoundUp(src_shape.c, 4); + struct uniforms { + int4 src_size; + float4 inv_multipliers; + }; + uniforms params; + params.src_size = {src_shape.w, src_shape.h, src_slices, 0}; + const double total_size = src_shape.w * src_shape.h; + const double size_0 = work_group_size.x * work_group_size.y; + const double size_1 = total_size / size_0; + params.inv_multipliers.x = 1.0 / size_1; + params.inv_multipliers.y = 1.0 / size_0; + const uint8_t* ptr = reinterpret_cast(¶ms); + return std::vector(ptr, ptr + sizeof(uniforms)); }}, }; - desc->resize_function = [output_id](const std::map& buffers) { + desc->resize_function = [output_id, work_group_size]( + const std::map& buffers) { BHWC dst_shape = buffers.find(output_id)->second; - const uint3 grid = - uint3(dst_shape.w, dst_shape.h, IntegralDivideRoundUp(dst_shape.c, 4)); - const uint3 groups_size = GetWorkGroupSizeForGrid(grid); - int groups_x = IntegralDivideRoundUp(grid.x, groups_size.x); - int groups_y = IntegralDivideRoundUp(grid.y, groups_size.y); - int groups_z = IntegralDivideRoundUp(grid.z, groups_size.z); - return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + const int groups_z = IntegralDivideRoundUp(dst_slices, work_group_size.z); + return std::make_pair(work_group_size, uint3{1, 1, groups_z}); }; return {desc}; } From f5de0a77b3301fa1990eda1047f77c1236324b58 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 23 Mar 2020 17:50:25 -0700 Subject: [PATCH 473/492] Add the quantization specs for the inputs and outputs If the value is annotated by the fake quant ops, the quantization spec is extracted from the fake quant and put in the quantization attributes. 
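
For illustration only (values taken from the cpu_kernel_fusion.mlir
test below, and assuming the usual affine fake-quant mapping with
signed 8-bit storage): an annotation with min = 0.0, max = 255.0,
num_bits = 8 and narrow_range = false works out to roughly

    scale      = (max - min) / (2^num_bits - 1) = 255.0 / 255 = 1.0
    zero_point = round(-min / scale) - 2^(num_bits - 1) = 0 - 128 = -128

so the recorded input/output spec becomes a quantized element type of
the form !quant.uniform<i8:f32, scale:zero_point> instead of plain f32.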
PiperOrigin-RevId: 302558753 Change-Id: I26b79ee1eab32f71e4be356bd58f6d815bc19243 --- .../quantization/xla/cpu_kernel_fusion.cc | 64 +++++++++++++++---- .../xla/tests/cpu_kernel_fusion.mlir | 34 +++++++++- 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc index 7bfeb241904..478b9d54176 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -47,21 +48,56 @@ limitations under the License. #define DEBUG_TYPE "quant-kernel-fusion" +constexpr int kFakeQuantOperandsNum = 5; +constexpr int kFakeQuantPerChannelOperandsNum = 6; + namespace mlir { namespace xla_hlo { namespace { +TypeAttr GetQuantSpec(Operation* op) { + auto fake_quant = llvm::dyn_cast_or_null(op); + if (!fake_quant || fake_quant.getNumOperands() < kFakeQuantOperandsNum || + fake_quant.getNumOperands() > kFakeQuantPerChannelOperandsNum || + fake_quant.call_target_name() != "fake_quant_with_min_max_vars") + return {}; + + DenseFPElementsAttr min, max; + DenseIntElementsAttr bit_width, narrow_range, quant_dim; + if (!matchPattern(fake_quant.getOperand(1), m_Constant(&min)) || + !matchPattern(fake_quant.getOperand(2), m_Constant(&max)) || + !matchPattern(fake_quant.getOperand(3), m_Constant(&bit_width)) || + !matchPattern(fake_quant.getOperand(4), m_Constant(&narrow_range))) + return {}; + + auto bit_width_val = (*bit_width.attr_value_begin()).cast(); + auto narrow_range_val = (*narrow_range.int_value_begin()).getSExtValue(); + int quant_dim_val = -1; + if (fake_quant.getNumOperands() == kFakeQuantPerChannelOperandsNum && + matchPattern(fake_quant.getOperand(kFakeQuantPerChannelOperandsNum - 1), + m_Constant(&quant_dim))) { + quant_dim_val = (*quant_dim.int_value_begin()).getSExtValue(); + } + + OpBuilder builder(op); + Type input_type = + fake_quant.getOperand(0).getType().cast().getElementType(); + return quant::GetQuantizedTypeAttr( + builder, input_type, min, max, quant_dim_val, bit_width_val, + builder.getBoolAttr(narrow_range_val), /*is_signed=*/true); +} + // Collects input values from outside for 'ops'. void CollectInputs(llvm::ArrayRef ops, llvm::SmallVectorImpl* inputs, llvm::SmallVectorImpl* input_specs) { - for (auto* op : ops) { - for (auto operand : op->getOperands()) { + for (Operation* op : ops) { + for (Value operand : op->getOperands()) { if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) { continue; } - if (auto* def_op = operand.getDefiningOp()) { + if (Operation* def_op = operand.getDefiningOp()) { if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) { inputs->push_back(operand); } @@ -71,10 +107,13 @@ void CollectInputs(llvm::ArrayRef ops, } } - for (auto input : *inputs) { + for (Value input : *inputs) { ShapedType input_type = input.getType().cast(); - // TODO(fengliuai): detect whether it is from fake quant. 
- input_specs->push_back(TypeAttr::get(input_type.getElementType())); + if (TypeAttr spec = GetQuantSpec(input.getDefiningOp())) { + input_specs->push_back(spec); + } else { + input_specs->push_back(TypeAttr::get(input_type.getElementType())); + } } } @@ -84,16 +123,19 @@ void CollectRets(llvm::ArrayRef ops, llvm::SmallVectorImpl* rets, llvm::SmallVectorImpl* ret_types, llvm::SmallVectorImpl* ret_specs) { - for (auto* op : ops) { - for (auto result : op->getResults()) { - for (auto* user : result.getUsers()) { + for (Operation* op : ops) { + for (Value result : op->getResults()) { + for (Operation* user : result.getUsers()) { // If there are any user outside of 'ops' if (std::find(ops.begin(), ops.end(), user) == ops.end()) { ShapedType ret_type = result.getType().cast(); rets->push_back(result); ret_types->push_back(ret_type); - // TODO(fengliuai): detect whether it is used by fake quant. - ret_specs->push_back(TypeAttr::get(ret_type.getElementType())); + if (TypeAttr spec = GetQuantSpec(user)) { + ret_specs->push_back(spec); + } else { + ret_specs->push_back(TypeAttr::get(ret_type.getElementType())); + } break; } } diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir index 3ca989b715c..09b0e53c151 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir @@ -1,7 +1,7 @@ // RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s -// CHECK-LABEL: @mul_add -func @mul_add(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { +// CHECK-LABEL: @mul_add_source +func @mul_add_source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> @@ -14,3 +14,33 @@ func @mul_add(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) // CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK: return %[[region]] : tensor<4xf32> } + +// CHECK-LABEL: @mul_add_annotated +func @mul_add_annotated(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) { + %cst = constant dense<0.0> : tensor + %cst_0 = constant dense<255.0> : tensor + %cst_1 = constant dense<8> : tensor + %cst_2 = constant dense : tensor + %qin = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", + has_side_effect = false, name = "custom-call.1"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> + %qw = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", + has_side_effect = false, name = "custom-call.2"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> + %0 = "xla_hlo.multiply"(%qin, %qw) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) 
-> tensor<2x4xf32> + %r = "xla_hlo.custom_call"(%1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", + has_side_effect = false, name = "custom-call.3"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> + return %r : tensor<2x4xf32> + +// CHECK: %[[region:.*]] = "quant.region" +// CHECK: ^bb0(%arg3: tensor<2x4xf32>, %arg4: tensor<2x4xf32>, %arg5: tensor<2x4xf32>): // no predecessors +// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32> +// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32> +// CHECK: "quant.return"(%[[add]]) : (tensor<2x4xf32>) -> () +// CHECK: }) {input_specs = [!quant.uniform, !quant.uniform, f32], +// CHECK-SAME: logical_kernel = "generic.mul_add", output_specs = [!quant.uniform]} : +// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[r:.*]] = "xla_hlo.custom_call"(%[[region]] +// CHECK: return %[[r]] : tensor<2x4xf32> +} + + From 8581cdd0d0135af3519113d7d335c7a3a0e13ca2 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Mon, 23 Mar 2020 17:57:42 -0700 Subject: [PATCH 474/492] Enable values_test on multi GPU PiperOrigin-RevId: 302559928 Change-Id: Ib6d65a98de035fa28be49b9b96dcf22b07e42f03 --- tensorflow/python/distribute/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 8f6231b7655..2d5ab2ebd0c 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -940,7 +940,7 @@ distribute_py_test( main = "values_test.py", shard_count = 5, tags = [ - # "multi_and_single_gpu", # b/151865826 + "multi_and_single_gpu", ], tpu_tags = [ "no_oss", # Target too big to run serially reliably. From f748283ee01059be52da5dada6e2157d9f6732ba Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Mon, 23 Mar 2020 18:08:14 -0700 Subject: [PATCH 475/492] Fix crash if set_visible_devices() is used with tf.keras.mixed_precision. Unfortunately, this required disabling the warning that would appear if mixed precision was used on a GPU that didn't fully support it. A warning will still appear if there is no GPU, but no log will appear if the user does have a GPU, because in that case we cannot tell if the GPU is supported or not. I will try to get the warning back by 2.3.
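For illustration only (this sketch is not part of the patch; the setup is hypothetical, but the API calls are the ones named above), the combination that used to crash looks roughly like:

    import tensorflow as tf

    # Hide all but the first GPU before any dtype policy is created.
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
      tf.config.set_visible_devices(gpus[:1], 'GPU')

    # Constructing a mixed_float16 policy runs the device compatibility check.
    # Previously that check called device_lib.list_local_devices(), which could
    # crash after set_visible_devices(); with this change the check only
    # consults tf.config.list_physical_devices('GPU') in this path.
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)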
PiperOrigin-RevId: 302561652 Change-Id: Ic73d06a4531a052009e83080de7af257042f33e1 --- .../device_compatibility_check.py | 21 ++++++++++++++-- .../mixed_precision/experimental/policy.py | 3 ++- .../experimental/policy_test.py | 24 ++++++++----------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py index d92c16d632f..9279c37bb52 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py +++ b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py @@ -22,6 +22,7 @@ import itertools from tensorflow.python.client import device_lib from tensorflow.python.eager import context +from tensorflow.python.framework import config from tensorflow.python.framework import gpu_util from tensorflow.python.platform import tf_logging @@ -133,7 +134,7 @@ def _log_device_compatibility_check(policy_name, device_attr_list): _logged_compatibility_check = False -def log_device_compatibility_check(policy_name): +def log_device_compatibility_check(policy_name, skip_local): """Logs a compatibility check if the devices support the policy. Currently only logs for the policy mixed_float16. A log is shown only the @@ -141,6 +142,11 @@ def log_device_compatibility_check(policy_name): Args: policy_name: The name of the dtype policy. + skip_local: If True, do not call list_local_devices(). This is useful since + if list_local_devices() and tf.config.set_visible_devices() are both + called, TensorFlow will crash. However, since GPU names and compute + capabilities cannot be checked without list_local_devices(), setting this + to True means the function will only warn if there are no GPUs. """ global _logged_compatibility_check # In graph mode, calling list_local_devices may initialize some session state, @@ -149,5 +155,16 @@ def log_device_compatibility_check(policy_name): return _logged_compatibility_check = True device_attr_list = device_lib.list_local_devices() - _log_device_compatibility_check(policy_name, device_attr_list) + if not skip_local: + _log_device_compatibility_check(policy_name, device_attr_list) + return + # TODO(b/146009447): Create an API to replace list_local_devices(), then + # remove the skip_local paramater. + gpus = config.list_physical_devices('GPU') + if not gpus and policy_name == 'mixed_float16': + tf_logging.warn( + '%s\n' + 'The dtype policy mixed_float16 may run slowly because ' + 'this machine does not have a GPU.\n%s' % + (_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX)) diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py index 9afc3ce9251..f9899679a86 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py @@ -333,7 +333,8 @@ class Policy(object): self._loss_scale = keras_loss_scale_module.get(loss_scale) if name in ('mixed_float16', 'mixed_bloat16'): - device_compatibility_check.log_device_compatibility_check(name) + device_compatibility_check.log_device_compatibility_check(name, + skip_local=True) def _parse_name(self, name): """Parses a Policy name into a compute and variable dtype. 
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py index b345039b406..ff809d061cb 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.python.eager import context +from tensorflow.python.framework import config as config_module from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.keras import combinations @@ -173,25 +174,20 @@ class PolicyTest(test.TestCase, parameterized.TestCase): def test_device_compatibility_warning(self): with context.eager_mode(): device_compatibility_check._logged_compatibility_check = False - with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \ - test.mock.patch.object(tf_logging, 'info') as mock_info: + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: mp_policy.Policy('mixed_float16') - if mock_warn.called: + if config_module.list_physical_devices('GPU'): + mock_warn.assert_not_called() + else: self.assertRegexpMatches( mock_warn.call_args[0][0], r'Mixed precision compatibility check \(mixed_float16\): WARNING.*') - mock_info.assert_not_called() - else: - self.assertRegexpMatches( - mock_info.call_args[0][0], - r'Mixed precision compatibility check \(mixed_float16\): OK.*') - # Assert message is only logged once - with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \ - test.mock.patch.object(tf_logging, 'info') as mock_info: - mp_policy.Policy('mixed_float16') - mock_warn.assert_not_called() - mock_info.assert_not_called() + if config_module.list_physical_devices('GPU'): + # Assert message is only logged once + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy('mixed_float16') + mock_warn.assert_not_called() @testing_utils.enable_v2_dtype_behavior def test_policy_scope(self): From ac3f66b6549d672ffd63d24712a1b51806cf37d6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 18:21:18 -0700 Subject: [PATCH 476/492] Adds tolerance arguments to pfor test functions. 
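For illustration only (the op and inputs below are made up; only the rtol/atol plumbing comes from this change), a test could now relax the comparison tolerances instead of being disabled outright:

    # Sketch of a test inside a PForTestCase subclass.
    def test_some_numerically_sensitive_op(self):
      x = random_ops.random_uniform([3, 2, 2])

      def loop_fn(i):
        return some_numerically_sensitive_op(array_ops.gather(x, i))

      # Compare pfor against a plain for_loop with looser tolerances than the
      # defaults (rtol=1e-4, atol=1e-5) that _test_loop_fn now accepts.
      self._test_loop_fn(loop_fn, iters=3, rtol=1e-3, atol=1e-3)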
PiperOrigin-RevId: 302563427 Change-Id: I68165eb6052edfaf477ebea28bcc8f664cf8234f --- tensorflow/python/ops/parallel_for/BUILD | 2 -- .../ops/parallel_for/control_flow_ops_test.py | 11 +++++++---- tensorflow/python/ops/parallel_for/test_util.py | 13 +++++++++---- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index 6bc33e10a23..88ddf7a7ec8 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -109,8 +109,6 @@ cuda_py_test( name = "control_flow_ops_test", srcs = ["control_flow_ops_test.py"], tags = ["no_rocm"], - # TODO(b/149957923): The test is flaky - xla_enable_strict_auto_jit = False, deps = [ ":control_flow_ops", ":test_util", diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index e6a67efa301..e33b7765ab1 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices @@ -639,7 +640,7 @@ class RandomTest(PForTestCase): # The random values generated in the two implementations are not guaranteed to # match. So we only check the returned shapes. - def run_and_assert_equal(self, targets1, targets2): + def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): outputs = self._run_targets(targets1, targets2) n = len(outputs) // 2 for i in range(n): @@ -737,7 +738,7 @@ class StatelessRandomTest(PForTestCase): # stateless random numbers can generate different random numbers. # TODO(agarwal): switch to checking for actual values matching once # b/149402339 is resolved. - def run_and_assert_equal(self, targets1, targets2): + def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): outputs = self._run_targets(targets1, targets2) n = len(outputs) // 2 for i in range(n): @@ -1735,8 +1736,10 @@ class SpectralTest(PForTestCase, parameterized.TestCase): (fft_ops.irfft2d,), (fft_ops.irfft3d,), ) - # TODO(agarwal): Reenable this once the test flaky is fixed. - def disabled_test_irfft(self, op_func): + def test_irfft(self, op_func): + if config.list_physical_devices("GPU"): + # TODO(b/149957923): The test is flaky + self.skipTest("b/149957923: irfft vectorization flaky") for dtype in (dtypes.complex64, dtypes.complex128): shape = [2, 3, 4, 3, 4] x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape) diff --git a/tensorflow/python/ops/parallel_for/test_util.py b/tensorflow/python/ops/parallel_for/test_util.py index c8eed9ca54e..7d8a3d86a77 100644 --- a/tensorflow/python/ops/parallel_for/test_util.py +++ b/tensorflow/python/ops/parallel_for/test_util.py @@ -39,20 +39,25 @@ class PForTestCase(test.TestCase): return self.evaluate(targets1 + targets2) # TODO(agarwal): Allow tests to pass down tolerances. 
- def run_and_assert_equal(self, targets1, targets2): + def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): outputs = self._run_targets(targets1, targets2) outputs = nest.flatten(outputs) # flatten SparseTensorValues n = len(outputs) // 2 for i in range(n): if outputs[i + n].dtype != np.object: - self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-4) + self.assertAllClose(outputs[i + n], outputs[i], rtol=rtol, atol=atol) else: self.assertAllEqual(outputs[i + n], outputs[i]) - def _test_loop_fn(self, loop_fn, iters, parallel_iterations=None): + def _test_loop_fn(self, + loop_fn, + iters, + parallel_iterations=None, + rtol=1e-4, + atol=1e-5): t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters, parallel_iterations=parallel_iterations) loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1) t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters, parallel_iterations=parallel_iterations) - self.run_and_assert_equal(t1, t2) + self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol) From 5d93f28897393ae8b6d5ed15f1339941d4e93a5b Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 23 Mar 2020 18:26:02 -0700 Subject: [PATCH 477/492] Make flatbuffer_translate_lib dynamically linked To do this, some statically registered translation functions are moved to a separate C++ file and target. Only the binaries that require these translation functions need to link them statically. This CL also removes some of the tensorflow/core:lib and tensorflow/core:framework dependencies from the flatbuffer_translate_lib target. PiperOrigin-RevId: 302564118 Change-Id: I7882013b05f0bd41332e6469eb7dd82d1ccdc628 --- tensorflow/compiler/mlir/lite/BUILD | 54 +- .../compiler/mlir/lite/flatbuffer_export.cc | 1454 ++++++++++++++++ .../compiler/mlir/lite/flatbuffer_export.h | 43 + .../mlir/lite/flatbuffer_export_flags.h | 31 + .../compiler/mlir/lite/flatbuffer_import.cc | 84 +- .../mlir/lite/flatbuffer_translate.cc | 1495 +---------------- .../compiler/mlir/lite/mlir_tflite_runner.cc | 4 +- .../lite/quantization/lite/quantize_model.cc | 2 +- .../mlir/lite/sparsity/sparsify_model.cc | 2 +- .../compiler/mlir/lite/tf_tfl_translate.cc | 4 +- .../mlir/lite/tf_to_tfl_flatbuffer.cc | 2 +- .../compiler/mlir/lite/utils/convert_type.cc | 3 +- .../compiler/mlir/lite/utils/convert_type.h | 2 +- tensorflow/compiler/mlir/tensorflow/BUILD | 3 +- .../mlir/tensorflow/utils/error_util.cc | 2 +- .../mlir/tensorflow/utils/error_util.h | 2 +- 16 files changed, 1663 insertions(+), 1524 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/flatbuffer_export.cc create mode 100644 tensorflow/compiler/mlir/lite/flatbuffer_export.h create mode 100644 tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 03cf9265f3b..c4314a86d92 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -224,7 +224,6 @@ cc_library( deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/lite/schema:schema_fbs", "@llvm-project//llvm:support", @@ -421,7 +420,9 @@ cc_library( ], deps = [ ":tensorflow_lite", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -554,14 +555,14 @@ cc_library( cc_library( name = "flatbuffer_translate_lib", srcs = [ +
"flatbuffer_export.cc", "flatbuffer_import.cc", - "flatbuffer_translate.cc", "utils/convert_type.cc", ], hdrs = [ + "flatbuffer_export.h", + "flatbuffer_export_flags.h", "flatbuffer_import.h", - "flatbuffer_translate.h", - "flatbuffer_translate_flags.h", "utils/convert_type.h", ], deps = [ @@ -578,9 +579,10 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", - "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:status", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string_util", @@ -601,15 +603,37 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", ], +) + +cc_library( + name = "flatbuffer_translate_registeration", + srcs = [ + "flatbuffer_translate.cc", + ], + deps = [ + ":flatbuffer_translate_lib", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LoopOpsTransforms", + "@llvm-project//mlir:MlirTranslateMain", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], alwayslink = 1, ) tf_cc_binary( name = "flatbuffer_translate", deps = [ - ":flatbuffer_translate_lib", - "@llvm-project//mlir:LoopOpsTransforms", - "@llvm-project//mlir:MlirTranslateMain", + ":flatbuffer_translate_registeration", ], ) @@ -647,10 +671,13 @@ filegroup( tf_cc_binary( name = "tf_tfl_translate", - srcs = [":tf_tfl_translate_main"], + srcs = [ + ":tf_tfl_translate_main", + ], deps = [ ":common", ":flatbuffer_translate_lib", + ":flatbuffer_translate_registeration", ":tensorflow_lite", ":tf_tfl_passes", ":tf_tfl_translate_cl_options", @@ -672,15 +699,18 @@ tf_cc_binary( tf_cc_binary( name = "mlir-tflite-runner", - srcs = ["mlir_tflite_runner.cc"], + srcs = [ + "mlir_tflite_runner.cc", + ], deps = [ ":flatbuffer_translate_lib", + ":flatbuffer_translate_registeration", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc new file mode 100644 index 00000000000..c36f4af9623 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -0,0 +1,1454 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" +#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/tools/versioning/op_version.h" +#include "tensorflow/lite/tools/versioning/runtime_version.h" +#include "tensorflow/lite/version.h" + +using llvm::dyn_cast; +using llvm::formatv; +using llvm::isa; +using llvm::Optional; +using llvm::StringRef; +using llvm::Twine; +using mlir::Dialect; +using mlir::ElementsAttr; +using 
mlir::FuncOp; +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::NoneType; +using mlir::Operation; +using mlir::Region; +using mlir::StringAttr; +using mlir::TensorType; +using mlir::Type; +using mlir::UnknownLoc; +using mlir::Value; +using tensorflow::OpOrArgLocNameMapper; +using tensorflow::OpOrArgNameMapper; +using tensorflow::Status; +using tflite::flex::IsWhitelistedFlexOp; +using xla::StatusOr; + +template +using BufferOffset = flatbuffers::Offset; + +template +using VectorBufferOffset = flatbuffers::Offset>; + +using CustomOptionsOffset = VectorBufferOffset; + +namespace error = tensorflow::error; +namespace tfl = mlir::TFL; + +ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; + +// Use initial buffer size in flatbuffer builder to be same as the initial size +// used by the TOCO export. (It does not explain rationale for this choice.) +constexpr size_t kInitialBufferSize = 10240; + +// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. +// Since tflite doesn't support unsigned for other types, returns error if +// `isSigned` is set to false for other types. +static StatusOr GetTFLiteType(Type type, + bool is_signed = true) { + if (!is_signed && type.isSignlessInteger(8)) { + return tflite::TensorType_UINT8; + } + if (!is_signed) { + return Status(error::INVALID_ARGUMENT, + "'isSigned' can only be set for 8-bits integer type"); + } + switch (type.getKind()) { + case mlir::StandardTypes::F32: + return tflite::TensorType_FLOAT32; + case mlir::StandardTypes::F16: + return tflite::TensorType_FLOAT16; + case mlir::TF::TensorFlowTypes::STRING: + return tflite::TensorType_STRING; + case mlir::TF::TensorFlowTypes::QUINT8: + return tflite::TensorType_UINT8; + case mlir::StandardTypes::Complex: { + auto ftype = type.cast().getElementType(); + if (ftype && ftype.isF32()) { + return tflite::TensorType_COMPLEX64; + } + return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } + case mlir::StandardTypes::Integer: { + const auto& itype = type.cast(); + switch (itype.getWidth()) { + case 1: + return tflite::TensorType_BOOL; + case 8: + return itype.isUnsigned() ? tflite::TensorType_UINT8 + : tflite::TensorType_INT8; + case 16: + return tflite::TensorType_INT16; + case 32: + return tflite::TensorType_INT32; + case 64: + return tflite::TensorType_INT64; + } + } + case mlir::quant::QuantizationTypes::UniformQuantized: { + auto qtype = type.cast(); + return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + } + case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { + auto qtype = type.cast(); + return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + } + case mlir::TF::TensorFlowTypes::RESOURCE: { + // Treat tf.resource values as integer values in flatbuffer. + // TODO(b/146131919): Maybe need to have a detailed design for supporting + // other resource types beyonds hash table resources and resource + // variables. + return tflite::TensorType_INT32; + } + default: + // TFLite export fills FLOAT32 for unknown data types. Returning an error + // for now for safety and this could be revisited when required. + return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } +} + +static bool IsConst(Operation* op) { + return isa(op) || isa(op) || + isa(op) || isa(op); +} + +template +static bool HasValidTFLiteType(Value value, T& error_handler) { + // None type is allowed to represent unspecified operands. 
+ if (value.getType().isa()) return true; + + auto type = value.getType().dyn_cast(); + if (!type) { + if (auto op = value.getDefiningOp()) { + error_handler.emitError() + << '\'' << op << "' should produce value of tensor type instead of " + << value.getType(); + return false; + } + error_handler.emitError("expected tensor type, got ") << value.getType(); + return false; + } + + Type element_type = type.getElementType(); + auto status = GetTFLiteType(element_type); + if (!status.ok()) { + return error_handler.emitError( + formatv("Failed to convert element type '{0}': {1}", + element_type, status.status().error_message())), + false; + } + return true; +} + +// Returns true if the module holds all the invariants expected by the +// Translator class. +// TODO(hinsu): Now that translation is done by making a single pass over the +// MLIR module, consider inlining these validation checks at the place where +// these invariants are assumed instead of checking upfront. +static bool IsValidTFLiteMlirModule(ModuleOp module) { + MLIRContext* context = module.getContext(); + + // Verify that module has a function named main. + FuncOp main_fn = module.lookupSymbol("main"); + if (!main_fn) { + return emitError(UnknownLoc::get(context), + "should have a function named 'main'"), + false; + } + + for (auto fn : module.getOps()) { + if (fn.getBlocks().size() != 1) { + return fn.emitError("should have exactly one basic block"), false; + } + auto& bb = fn.getBlocks().front(); + + for (auto arg : bb.getArguments()) { + if (!HasValidTFLiteType(arg, fn)) + return fn.emitError("invalid TFLite type: ") << arg.getType(), false; + } + + // Verify that all operations except the terminator have exactly one + // result of type supported by TFLite. + for (auto& inst : bb) { + if (inst.isKnownTerminator()) break; + + for (auto result : inst.getResults()) { + if (!HasValidTFLiteType(result, inst)) + return fn.emitError("invalid TFLite type: ") << result.getType(), + false; + } + } + } + + return true; +} + +static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( + ::mlir::Operation* inst) { + // We pass empty string for the original node_def name since Flex runtime + // does not care about this being set correctly on node_def. There is no + // "easy" (see b/120948529) way yet to get this from MLIR inst. + auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( + inst, /*name=*/"", /*ignore_unregistered_attrs=*/true); + if (!status_or_node_def.ok()) { + inst->emitOpError( + Twine("failed to obtain TensorFlow nodedef with status: " + + status_or_node_def.status().ToString())); + return {}; + } + return std::move(status_or_node_def.ValueOrDie()); +} + +// Converts a mlir padding StringRef to TfLitePadding. +// Returns llvm::None if conversion fails. +static Optional GetTflitePadding(Operation* inst, + llvm::StringRef padding) { + const tflite::Padding padding_attr = + std::move(llvm::StringSwitch(padding) + .Case("SAME", tflite::Padding_SAME) + .Case("VALID", tflite::Padding_VALID)); + if (padding_attr == tflite::Padding_SAME) { + return kTfLitePaddingSame; + } + if (padding_attr == tflite::Padding_VALID) { + return kTfLitePaddingValid; + } + + return inst->emitOpError() << "Invalid padding attribute: " << padding, + llvm::None; +} + +// Extracts TfLitePoolParams from a TFL custom op. +// Template parameter, TFLOp, should be a TFL custom op containing attributes +// generated from TfLitePoolParams. +// Returns llvm::None if conversion fails. 
+template +static Optional GetTflitePoolParams(Operation* inst, + TFLOp op) { + TfLitePoolParams pool_params; + pool_params.stride_height = op.stride_h().getSExtValue(); + pool_params.stride_width = op.stride_w().getSExtValue(); + pool_params.filter_height = op.filter_h().getSExtValue(); + pool_params.filter_width = op.filter_w().getSExtValue(); + const auto padding = GetTflitePadding(inst, op.padding()); + if (padding) { + pool_params.padding = *padding; + pool_params.activation = kTfLiteActNone; + pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; + return pool_params; + } + + return llvm::None; +} + +namespace { + +// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. +class Translator { + public: + // Translates the given MLIR module into TFLite FlatBuffer format and returns + // the serialized output. Returns llvm::None on unsupported, invalid inputs or + // internal error. + static Optional Translate( + ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); + + private: + enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; + explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, + bool emit_select_tf_ops, bool emit_custom_ops, + OpOrArgNameMapper* op_or_arg_name_mapper) + : module_(module), + name_mapper_(*op_or_arg_name_mapper), + builder_(kInitialBufferSize) { + // The first buffer must be empty according to the schema definition. + empty_buffer_ = tflite::CreateBuffer(builder_); + buffers_.push_back(empty_buffer_); + if (emit_builtin_tflite_ops) { + enabled_op_types_.emplace(OpType::kTfliteBuiltin); + } + if (emit_select_tf_ops) { + enabled_op_types_.emplace(OpType::kSelectTf); + } + if (emit_custom_ops) { + enabled_op_types_.emplace(OpType::kCustomOp); + } + tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); + tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); + } + + Optional TranslateInternal(); + + // Returns TFLite buffer populated with constant value if the operation is + // TFLite constant operation. Otherwise, returns an empty buffer. Emits error + // and returns llvm::None on failure. + Optional> BuildBuffer(Operation* inst); + + // Build TFLite tensor from the given type. This function is for tfl.lstm + // intermediates, which should have UniformQuantizedType. + Optional> BuildTensorFromType( + mlir::Type type, const std::string& name); + + // Builds TFLite tensor from the given value. `buffer_idx` is index of the + // corresponding buffer. Emits error and returns llvm::None on failure. + Optional> BuildTensor(Value value, + const std::string& name, + unsigned buffer_idx); + + // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove + // these 2 functions here. + BufferOffset BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results); + BufferOffset BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results); + + // Build while operator where cond & body are regions. + Optional> BuildWhileOperator( + mlir::TFL::WhileOp op, const std::vector& operands, + const std::vector& results); + + // Builds custom operators. + // Templated on a) data type of custom_option to be stored into flatbuffer, + // and b) TFL custom op type. 
+ template + BufferOffset BuildCustomOperator( + const CustomOptionType& custom_option, const std::string& opcode_name, + TFLOp op, const std::vector& operands, + const std::vector& results); + + BufferOffset BuildNumericVerifyOperator( + mlir::TFL::NumericVerifyOp op, const std::vector& operands, + const std::vector& results); + Optional> + BuildConvolution2DTransposeBiasOperator( + Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, + const std::vector& operands, + const std::vector& results); + Optional> BuildMaxPoolingWithArgMax2DOperator( + Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, + const std::vector& operands, + const std::vector& results); + Optional> BuildMaxUnpooling2DOperator( + Operation* inst, mlir::TFL::MaxUnpooling2DOp op, + const std::vector& operands, + const std::vector& results); + + Optional CreateFlexOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + Optional CreateCustomOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + std::unique_ptr CreateFlexBuilderWithNodeAttrs( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); + + // Returns opcode index for op identified by the op_name, if already + // available. Otherwise, creates a new OperatorCode using the given `builtin` + // operator and associates it with `op_name`. + uint32_t GetOpcodeIndex(const std::string& op_name, + tflite::BuiltinOperator builtin); + + // Builds operator for the given operation with specified operand and result + // tensor indices. Emits an error and returns llvm::None on failure. + Optional> BuildOperator( + Operation* inst, const std::vector& operands, + const std::vector& results, + const std::vector& intermediates); + + // Build a subgraph with a given name out of the region either corresponding + // to a function's body or while op. + Optional> BuildSubGraph( + const std::string& name, Region* region); + + // Builds Metadata with the given `name` and buffer `content`. + BufferOffset BuildMetadata(StringRef name, + StringRef content); + + // Encodes the `tfl.metadata` dictionary attribute of the module to the + // metadata section in the final model. + Optional>> + CreateMetadataVector(); + + // Uses the tf.entry_function attribute (if set) to initialize the op to name + // mapping. + void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); + + // Determines if the specified operation op's operand at operand_index + // is marked as a stateful operand. + bool IsStatefulOperand(mlir::Operation* op, int operand_index); + + // Returns a unique name for `val`. + std::string UniqueName(mlir::Value val); + + ModuleOp module_; + + tensorflow::OpOrArgNameMapper& name_mapper_; + + flatbuffers::FlatBufferBuilder builder_; + BufferOffset empty_buffer_; + + std::vector> buffers_; + + // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. + absl::flat_hash_map opcode_index_map_; + std::vector> opcodes_; + + // Maps function name to index of the corresponding subgraph in the FlatBuffer + // model. + absl::flat_hash_map subgraph_index_map_; + absl::flat_hash_set enabled_op_types_; + + // Points to TensorFlow and TFLite dialects, respectively. nullptr if the + // dialect is not registered. + const Dialect* tf_dialect_; + const Dialect* tfl_dialect_; + + // The failed ops during legalization. 
+ std::set failed_flex_ops_; + std::set failed_custom_ops_; +}; + +std::string Translator::UniqueName(mlir::Value val) { + return std::string(name_mapper_.GetUniqueName(val)); +} + +Optional> Translator::BuildBuffer( + Operation* inst) { + ElementsAttr attr; + if (auto cst = dyn_cast(inst)) { + // ConstantOp have ElementAttr at this point due to validation of the TFLite + // module. + attr = cst.getValue().cast(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + } else { + return empty_buffer_; + } + + tensorflow::Tensor tensor; + auto status = tensorflow::ConvertToTensor(attr, &tensor); + if (!status.ok()) { + inst->emitError( + Twine("failed to convert value attribute to tensor with error: " + + status.ToString())); + return llvm::None; + } + + // TensorFlow and TensorFlow Lite use different string encoding formats. + // Convert to TensorFlow Lite format is it's a constant string tensor. + if (tensor.dtype() == tensorflow::DT_STRING) { + ::tflite::DynamicBuffer dynamic_buffer; + auto flat = tensor.flat<::tensorflow::tstring>(); + for (int i = 0; i < flat.size(); ++i) { + const auto& str = flat(i); + dynamic_buffer.AddString(str.c_str(), str.length()); + } + char* tensor_buffer; + int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer); + auto buffer_data = + builder_.CreateVector(reinterpret_cast(tensor_buffer), bytes); + free(tensor_buffer); + return tflite::CreateBuffer(builder_, buffer_data); + } + + absl::string_view tensor_data = tensor.tensor_data(); + auto buffer_data = builder_.CreateVector( + reinterpret_cast(tensor_data.data()), tensor_data.size()); + return tflite::CreateBuffer(builder_, buffer_data); +} + +Optional> Translator::BuildTensorFromType( + mlir::Type type, const std::string& name) { + auto tensor_type = type.cast(); + + if (!tensor_type.hasStaticShape()) { + return llvm::None; + } + llvm::ArrayRef shape_ref = tensor_type.getShape(); + std::vector shape(shape_ref.begin(), shape_ref.end()); + + auto element_type = tensor_type.getElementType(); + tflite::TensorType tflite_element_type = + GetTFLiteType(tensor_type.getElementType()).ValueOrDie(); + BufferOffset q_params; + auto qtype = element_type.dyn_cast(); + if (!qtype) { + return llvm::None; + } + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector({static_cast(qtype.getScale())}), + builder_.CreateVector({qtype.getZeroPoint()})); + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + /*buffer=*/0, builder_.CreateString(name), q_params, + /*is_variable=*/false); +} + +Optional> Translator::BuildTensor( + Value value, const std::string& name, unsigned buffer_idx) { + auto type = value.getType().cast(); + + // TFLite requires tensor shape only for the inputs and constants. 
+ // However, we output all known shapes for better round-tripping + auto check_shape = + [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { + auto is_out_of_range = [](int64_t dim) { + return dim > std::numeric_limits::max(); + }; + + if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) + return mlir::emitError( + value.getLoc(), + "result shape dimensions out of 32 bit int type range"); + + return mlir::success(); + }; + + std::vector shape; + std::vector shape_signature; + if (type.hasStaticShape()) { + llvm::ArrayRef shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape = std::vector(shape_ref.begin(), shape_ref.end()); + } else if (auto* inst = value.getDefiningOp()) { + if (IsConst(inst)) { + // Const op can have a result of dynamic shaped type (e.g. due to constant + // folding), but we can still derive the shape of a constant tensor for + // its attribute type. + mlir::Attribute tensor_attr = inst->getAttr("value"); + llvm::ArrayRef shape_ref = + tensor_attr.getType().cast().getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape = std::vector(shape_ref.begin(), shape_ref.end()); + } + } else if (type.hasRank()) { + llvm::ArrayRef shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape.reserve(shape_ref.size()); + for (auto& dim : shape_ref) { + shape.push_back(dim == -1 ? 1 : dim); + } + shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); + } + + if (auto* inst = value.getDefiningOp()) { + if (auto cst = dyn_cast(inst)) { + // CreateSparsityParameters(cst.s_param()); + } else if (auto cst = dyn_cast(inst)) { + // CreateSparsityParameters(cst.s_param()); + } + } + + Type element_type = type.getElementType(); + tflite::TensorType tflite_element_type = + GetTFLiteType(type.getElementType()).ValueOrDie(); + + BufferOffset q_params; + if (auto qtype = element_type.dyn_cast()) { + q_params = tflite::CreateQuantizationParameters( + // TODO(fengliuai): min and max values are not stored in the + // quantized type, so both are set to 0. The model couldn't be imported + // to TensorFlow because of this. + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector({static_cast(qtype.getScale())}), + builder_.CreateVector({qtype.getZeroPoint()})); + } else if (auto qtype = + element_type + .dyn_cast()) { + std::vector scales(qtype.getScales().begin(), + qtype.getScales().end()); + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), + builder_.CreateVector(qtype.getZeroPoints()), + tflite::QuantizationDetails_NONE, /*details=*/0, + qtype.getQuantizedDimension()); + } else { + q_params = tflite::CreateQuantizationParameters(builder_); + } + // Check if the value's uses includes an op and usage at an operand index + // marked as a stateful. If so, set the tensor's is_variable as true + // This is v1 ref variable semantics in the TFLite runtime. + bool is_variable = false; + for (auto& use : value.getUses()) { + is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); + if (is_variable) { + break; + } + } + + if (shape_signature.empty()) { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable); + } else { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 
0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable, /*sparsity=*/0, + /*shape_signature=*/builder_.CreateVector(shape_signature)); + } +} + +BufferOffset Translator::BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); + int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); + int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); + auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, + else_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_IfOptions, + builtin_options); +} + +BufferOffset Translator::BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); + int body_subgraph_index = subgraph_index_map_.at(op.body().str()); + auto builtin_options = tflite::CreateWhileOptions( + builder_, cond_subgraph_index, body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); +} + +Optional> Translator::BuildWhileOperator( + mlir::TFL::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + auto get_call_index = [&](mlir::Block& b) -> Optional { + if (b.getOperations().size() != 2) return llvm::None; + if (auto call_op = dyn_cast(b.front())) + return subgraph_index_map_.at(call_op.callee().str()); + return llvm::None; + }; + auto body_subgraph_index = get_call_index(op.body().front()); + auto cond_subgraph_index = get_call_index(op.cond().front()); + if (!body_subgraph_index || !cond_subgraph_index) + return op.emitOpError("only single call cond/body while export supported"), + llvm::None; + auto builtin_options = + tflite::CreateWhileOptions(builder_, *cond_subgraph_index, + *body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); +} + +template +BufferOffset Translator::BuildCustomOperator( + const CustomOptionType& custom_option, const std::string& opcode_name, + TFLOp op, const std::vector& operands, + const std::vector& results) { + std::vector custom_option_vector(sizeof(CustomOptionType)); + memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType)); + auto opcode_index = + GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM); + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + builder_.CreateVector(custom_option_vector), + tflite::CustomOptionsFormat_FLEXBUFFERS); +} + +BufferOffset Translator::BuildNumericVerifyOperator( + mlir::TFL::NumericVerifyOp op, const std::vector& operands, + const std::vector& results) { + float tolerance = op.tolerance().convertToFloat(); 
+ return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); +} + +Optional> +Translator::BuildConvolution2DTransposeBiasOperator( + Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, + const std::vector& operands, const std::vector& results) { + TfLiteTransposeConvParams conv_params; + conv_params.stride_height = op.stride_h().getSExtValue(); + conv_params.stride_width = op.stride_w().getSExtValue(); + const auto padding = GetTflitePadding(inst, op.padding()); + if (padding) { + conv_params.padding = *padding; + return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, + operands, results); + } + + return llvm::None; +} + +Optional> +Translator::BuildMaxPoolingWithArgMax2DOperator( + Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, + const std::vector& operands, const std::vector& results) { + const auto pool_params = GetTflitePoolParams(inst, op); + if (pool_params) { + return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, + operands, results); + } + + return llvm::None; +} + +Optional> +Translator::BuildMaxUnpooling2DOperator(Operation* inst, + mlir::TFL::MaxUnpooling2DOp op, + const std::vector& operands, + const std::vector& results) { + const auto pool_params = GetTflitePoolParams(inst, op); + if (pool_params) { + return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, + results); + } + + return llvm::None; +} + +Optional Translator::CreateFlexOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + std::string node_def_str; + if (!node_def.SerializeToString(&node_def_str)) { + return emitError(loc, "failed to serialize tensorflow node_def"), + llvm::None; + } + + auto flex_builder = absl::make_unique(); + flex_builder->Vector([&]() { + flex_builder->String(node_def.op()); + flex_builder->String(node_def_str); + }); + flex_builder->Finish(); + return builder_.CreateVector(flex_builder->GetBuffer()); +} + +Optional Translator::CreateCustomOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + std::string node_def_str; + if (!node_def.SerializeToString(&node_def_str)) { + return emitError(loc, "failed to serialize tensorflow node_def"), + llvm::None; + } + auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); + return builder_.CreateVector(flex_builder->GetBuffer()); +} + +std::unique_ptr +Translator::CreateFlexBuilderWithNodeAttrs( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + auto flex_builder = absl::make_unique(); + size_t map_start = flex_builder->StartMap(); + for (const auto& pair : node_def.attr()) { + const char* key = pair.first.c_str(); + const auto& attr = pair.second; + switch (attr.value_case()) { + case ::tensorflow::AttrValue::kS: + flex_builder->String(key, attr.s()); + break; + case ::tensorflow::AttrValue::kType: { + auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); + if (status_or_tfl_type.ok()) { + flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); + } else { + emitWarning(loc, "ignoring unsupported tensorflow type: ") + << std::to_string(attr.type()); + } + break; + } + case ::tensorflow::AttrValue::kI: + flex_builder->Int(key, attr.i()); + break; + case ::tensorflow::AttrValue::kF: + flex_builder->Float(key, attr.f()); + break; + case ::tensorflow::AttrValue::kB: + flex_builder->Bool(key, attr.b()); + break; + case tensorflow::AttrValue::kList: + if (attr.list().s_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const std::string& v : 
attr.list().s()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else if (attr.list().i_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const int64_t v : attr.list().i()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else if (attr.list().f_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const float v : attr.list().f()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else { + emitWarning(loc, + "ignoring unsupported type in list attribute with key: ") + << key; + } + break; + default: + emitWarning(loc, "ignoring unsupported attribute type with key: ") + << key; + break; + } + } + flex_builder->EndMap(map_start); + flex_builder->Finish(); + return flex_builder; +} + +uint32_t Translator::GetOpcodeIndex(const std::string& op_name, + tflite::BuiltinOperator builtin) { + auto it = opcode_index_map_.insert({op_name, 0}); + + // If the insert succeeded, the opcode has not been created already. Create a + // new operator code and update its index value in the map. + if (it.second) { + it.first->second = opcodes_.size(); + auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM + ? builder_.CreateString(op_name) + : BufferOffset(); + // Use version 0 for builtin op. This is a way to serialize version field to + // flatbuffer (since 0 is non default) and it will be corrected later. + int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; + opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, + custom_code, op_version)); + } + return it.first->second; +} + +Optional> Translator::BuildOperator( + Operation* inst, const std::vector& operands, + const std::vector& results, + const std::vector& intermediates) { + const auto* dialect = inst->getDialect(); + if (!dialect) { + inst->emitOpError("dialect is not registered"); + return llvm::None; + } + + // If TFLite built in op, create operator as a builtin op. + if (dialect == tfl_dialect_) { + // Only if built-in TFLite op emission is enabled, would legalization have + // converted any TF->TFL. 
+ if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { + return inst->emitOpError( + "is a TFLite builtin op but builtin emission is not enabled"), + llvm::None; + } + + auto builtin_code = GetBuiltinOpCode(inst); + if (!builtin_code) { + if (auto verify_op = dyn_cast(inst)) { + return BuildNumericVerifyOperator(verify_op, operands, results); + } + if (auto conv_transpose_bias_op = + dyn_cast(inst)) { + return BuildConvolution2DTransposeBiasOperator( + inst, conv_transpose_bias_op, operands, results); + } + if (auto max_pooling_with_arg_max_op = + dyn_cast(inst)) { + return BuildMaxPoolingWithArgMax2DOperator( + inst, max_pooling_with_arg_max_op, operands, results); + } + if (auto max_unpooling_op = dyn_cast(inst)) { + return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, + results); + } + if (auto whileOp = dyn_cast(inst)) { + if (inst->getNumOperands() != inst->getNumResults()) { + inst->emitOpError( + "number of operands and results don't match, only canonical " + "TFL While supported"); + return llvm::None; + } + return BuildWhileOperator(whileOp, operands, results); + } + + inst->emitOpError("is not a supported TFLite op"); + return llvm::None; + } + + std::string op_name = inst->getName().getStringRef().str(); + uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); + auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, + results, intermediates, &builder_); + if (!offset) { + inst->emitOpError("is not a supported TFLite op"); + } + return offset; + } + + if (dialect == tf_dialect_) { + std::string op_name; + if (auto ifOp = dyn_cast(inst)) { + return BuildIfOperator(ifOp, operands, results); + } else if (auto whileOp = dyn_cast(inst)) { + return BuildWhileOperator(whileOp, operands, results); + } + + CustomOptionsOffset custom_options; + + // Ops in TF dialect can either be custom ops or flex ops. + // The reason we go directly from TensorFlow dialect MLIR to tensorflow + // node instead of going to TF table gen'd ops via generated code is that + // we do not want to restrict custom and flex op conversion support to + // only those TF ops that are currently registered in MLIR. The current + // model is of an open op system. + // + // The following algorithm is followed: + // if flex is enabled and the op is whitelisted as flex + // we emit op as flex. + // if custom is enabled + // we emit the op as custom. + auto node_def = GetTensorFlowNodeDef(inst); + if (!node_def) { + return llvm::None; + } + + // Flex op case + // Eventually, the whitelist will go away and we will rely on some TF op + // trait (e.g. No side effect) to determine if it is a supported "Flex" + // op or not. + if (enabled_op_types_.contains(OpType::kSelectTf) && + IsWhitelistedFlexOp(node_def->op())) { + // Construct ops as flex op encoding TensorFlow node definition + // as custom options. + // Flex ops are named with the kFlexOpNamePrefix prefix to the actual + // TF op name. + op_name = std::string(kFlexOpNamePrefix) + node_def->op(); + if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { + return llvm::None; + } + } else if (enabled_op_types_.contains(OpType::kCustomOp)) { + // Generic case of custom ops - write using flex buffers since that + // is the only custom options supported by TFLite today. 
+ op_name = node_def->op(); + if (auto options = + CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { + return llvm::None; + } + } else { + // Create description of operation that could not be converted. + const int kLargeElementsAttr = 16; + std::string op_str; + llvm::raw_string_ostream os(op_str); + inst->getName().print(os); + // Print out attributes except for large elementsattributes (which should + // rarely be the cause why the legalization didn't happen). + if (!inst->getAttrList().getAttrs().empty()) { + os << " {"; + bool first = true; + for (auto& named_attr : inst->getAttrList().getDictionary()) { + os << (!first ? ", " : ""); + first = false; + named_attr.first.print(os); + os << " = "; + if (auto element_attr = named_attr.second.dyn_cast()) { + if (element_attr.getNumElements() <= kLargeElementsAttr) { + element_attr.print(os); + } else { + os << ""; + } + } else { + named_attr.second.print(os); + } + } + os << "}"; + } + + // Insert failed op to `flex_ops` or `custom_ops`. + if (IsWhitelistedFlexOp(node_def->op())) { + failed_flex_ops_.insert(os.str()); + } else { + failed_custom_ops_.insert(os.str()); + } + return inst->emitOpError("is neither a custom op nor a flex op"), + llvm::None; + } + + uint32_t opcode_index = + GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + /*custom_options=*/custom_options, + tflite::CustomOptionsFormat_FLEXBUFFERS, + /*mutating_variable_inputs=*/0); + } + + return inst->emitOpError( + "is not any of a builtin TFLite op, a flex TensorFlow op or a " + "custom TensorFlow op"), + llvm::None; +} + +void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { + auto dict_attr = fn.getAttrOfType("tf.entry_function"); + if (!dict_attr) return; + + llvm::SmallVector input_names; + llvm::SmallVector output_names; + if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { + str.getValue().split(input_names, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + if (input_names.size() != fn.getNumArguments()) { + fn.emitWarning() << "invalid entry function specification"; + return; + } + for (auto it : llvm::enumerate(fn.getArguments())) { + name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); + } + *has_input_attr = true; + } + + if (auto str = + dict_attr.get("outputs").dyn_cast_or_null()) { + str.getValue().split(output_names, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + auto term = fn.getBlocks().back().getTerminator(); + if (output_names.size() != term->getNumOperands()) { + fn.emitWarning() << "output names (" << output_names.size() + << ") != terminator operands (" << term->getNumOperands() + << ")"; + return; + } + for (const auto& it : llvm::enumerate(term->getOperands())) { + name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); + } + } +} + +bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { + std::vector operand_indices; + if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; + return absl::c_find(operand_indices, operand_index) != operand_indices.end(); +} + +Optional> Translator::BuildSubGraph( + const std::string& name, Region* region) { + bool has_input_attr = false; + if (auto fn = dyn_cast(region->getParentOp())) { + InitializeNamesFromAttribute(fn, &has_input_attr); + 
} + std::vector> tensors; + llvm::DenseMap tensor_index_map; + + // Builds tensor and buffer for argument or operation result. Returns false + // on failure. + auto build_tensor_and_buffer = [&](Value value, const std::string& name) { + // NoneType represents optional and may be skipped here. + if (value.getType().isa()) { + return true; + } + + tensor_index_map.insert({value, tensors.size()}); + auto tensor_or = BuildTensor(value, name, buffers_.size()); + if (!tensor_or) return false; + tensors.push_back(*tensor_or); + + // TODO(ashwinm): Check if for stateful tensors, if it is also needed to + // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. + // This does not seem to affect runtime behavior for RNN/LSTM, but would be + // good for reducing memory footprint. + if (auto* inst = value.getDefiningOp()) { + auto buffer_or = BuildBuffer(inst); + if (!buffer_or) return false; + buffers_.push_back(*buffer_or); + } else { + buffers_.push_back(empty_buffer_); + } + return true; + }; + + std::vector> operators; + auto& bb = region->front(); + + // Main function's arguments are first passed to `input` op so they don't + // have associated tensor and buffer. Build FlatBuffer tensor and buffer for + // other functions. + for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { + mlir::BlockArgument arg = bb.getArgument(i); + std::string name; + if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg)); + if (name.empty()) name = absl::StrCat("arg", i); + if (!build_tensor_and_buffer(arg, name)) return llvm::None; + } + + bool failed_once = false; + for (auto& inst : bb) { + if (inst.isKnownTerminator()) break; + std::vector intermediates; + // Build intermediate tensors for tfl.lstm and insert these tensors into + // flatbuffer. + if (llvm::isa(inst)) { + std::vector intermediate_names = { + "input_to_input_intermediate", "input_to_forget_intermediate", + "input_to_cell_intermediate", "input_to_output_intermediate", + "effective_hidden_scale_intermediate"}; + for (const std::string& intermediate : intermediate_names) { + auto intermediate_attr = inst.getAttr(intermediate); + if (auto attr = intermediate_attr.dyn_cast_or_null()) { + Type qtype = attr.getValue(); + auto tensor_or = BuildTensorFromType( + qtype, name_mapper_.GetUniqueName(intermediate).str()); + if (!tensor_or.hasValue()) { + continue; + } else { + intermediates.push_back(tensors.size()); + tensors.push_back(tensor_or.getValue()); + } + } + } + } + + for (auto val : inst.getResults()) { + std::string name = UniqueName(val); + if (!build_tensor_and_buffer(val, name)) return llvm::None; + } + + // Skip constant ops as they don't represent a TFLite operator. + if (IsConst(&inst)) continue; + + // Fetch operand and result tensor indices. + std::vector operands; + operands.reserve(inst.getNumOperands()); + for (auto operand : inst.getOperands()) { + if (operand.getType().isa()) + operands.push_back(kTfLiteOptionalTensor); + else + operands.push_back(tensor_index_map.lookup(operand)); + } + std::vector results; + results.reserve(inst.getNumOperands()); + for (auto result : inst.getResults()) { + results.push_back(tensor_index_map.lookup(result)); + } + + if (auto tfl_operator = + BuildOperator(&inst, operands, results, intermediates)) + operators.push_back(*tfl_operator); + else + failed_once = true; + } + + if (failed_once) return llvm::None; + + // Get input and output tensor indices for the subgraph. 
+  std::vector<int32_t> inputs, outputs;
+  for (auto arg : bb.getArguments()) {
+    inputs.push_back(tensor_index_map[arg]);
+  }
+  for (auto result : bb.getTerminator()->getOperands()) {
+    outputs.push_back(tensor_index_map[result]);
+  }
+
+  return tflite::CreateSubGraph(
+      builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
+      builder_.CreateVector(outputs), builder_.CreateVector(operators),
+      /*name=*/builder_.CreateString(name));
+}
+
+BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
+                                                         StringRef content) {
+  auto buffer_index = buffers_.size();
+  auto buffer_data = builder_.CreateVector(
+      reinterpret_cast<const uint8_t*>(content.data()), content.size());
+  buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
+  return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
+}
+
+Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
+Translator::CreateMetadataVector() {
+  auto dict_attr = module_.getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
+  std::vector<BufferOffset<tflite::Metadata>> metadata;
+  if (dict_attr) {
+    for (const auto& named_attr : dict_attr) {
+      StringRef name = named_attr.first;
+      mlir::Attribute attr = named_attr.second;
+      if (auto content = attr.dyn_cast<StringAttr>()) {
+        metadata.push_back(BuildMetadata(name, content.getValue()));
+      } else {
+        module_.emitError(
+            "all values in tfl.metadata's dictionary key-value pairs should be "
+            "string attributes");
+        return llvm::None;
+      }
+    }
+  }
+  // Runtime version string is generated after we update the op
+  // versions. Here we put a 16-byte dummy string as a placeholder. We choose
+  // 16-byte because it's the alignment of buffers in flatbuffer, so it won't
+  // cause any waste of space if the actual string is shorter than 16 bytes.
+  metadata.push_back(
+      BuildMetadata("min_runtime_version", std::string(16, '\0')));
+  return builder_.CreateVector(metadata);
+}
+
+bool UpdateEntryFunction(ModuleOp module) {
+  if (module.lookupSymbol<FuncOp>("main") != nullptr) {
+    // We already have an entry function.
+    return true;
+  }
+
+  int entry_func_count = 0;
+  FuncOp entry_func = nullptr;
+  for (auto fn : module.getOps<FuncOp>()) {
+    auto attrs = fn.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
+    if (attrs && !attrs.empty()) {
+      entry_func_count++;
+      entry_func = fn;
+    }
+  }
+
+  // We should have one & only have one entry function.
+  if (entry_func_count != 1) return false;
+
+  // Update the entry func to main.
+  entry_func.setName("main");
+  return true;
+}
+
+Optional<std::string> Translator::Translate(
+    ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
+    bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) {
+  if (!UpdateEntryFunction(module)) return llvm::None;
+  if (!IsValidTFLiteMlirModule(module)) return llvm::None;
+  Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
+                        emit_custom_ops, op_or_arg_name_mapper);
+  return translator.TranslateInternal();
+}
+
+Optional<std::string> Translator::TranslateInternal() {
+  // A list of named regions in the module with main function being the first in
+  // the list. The main function is required as the first subgraph in the model
+  // is entry point for the model.
+  std::vector<std::pair<std::string, Region*>> named_regions;
+  named_regions.reserve(std::distance(module_.begin(), module_.end()));
+
+  int subgraph_idx = 0;
+  FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
+  subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
+  named_regions.emplace_back("main", &main_fn.getBody());
+  // Walk over the module collection ops with functions and while ops.
+  module_.walk([&](FuncOp fn) {
+    if (fn != main_fn) {
+      subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
+      named_regions.emplace_back(fn.getName().str(), &fn.getBody());
+    }
+  });
+
+  // Build subgraph for each of the named regions.
+  std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
+  subgraphs.reserve(named_regions.size());
+  int first_failed_func = -1;
+  for (auto it : llvm::enumerate(named_regions)) {
+    auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
+    if (!subgraph_or) {
+      if (first_failed_func == -1)
+        // Record the index of the first region that cannot be converted.
+        // Keep looping through all subgraphs in the module to make sure that
+        // we collect the list of missing ops from the entire module.
+        first_failed_func = it.index();
+    } else {
+      subgraphs.push_back(*subgraph_or);
+    }
+  }
+
+  if (first_failed_func != -1) {
+    std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t");
+    std::string failed_custom_ops_list =
+        absl::StrJoin(failed_custom_ops_, "\n\t");
+    std::string err;
+    if (!failed_flex_ops_list.empty())
+      err +=
+          "Ops that can be supported by the flex runtime (enabled via setting "
+          "the -emit-select-tf-ops flag):\n\t" +
+          failed_flex_ops_list;
+    if (!failed_custom_ops_list.empty())
+      err +=
+          "Ops that need custom implementation (enabled via setting the "
+          "-emit-custom-ops flag):\n\t" +
+          failed_custom_ops_list;
+
+    auto& failed_region = named_regions[first_failed_func];
+    return failed_region.second->getParentOp()->emitError()
+               << "failed while converting: '" << failed_region.first
+               << "': " << err,
+           llvm::None;
+  }
+
+  std::string model_description;
+  if (auto attr = module_.getAttrOfType<mlir::StringAttr>("tfl.description")) {
+    model_description = attr.getValue().str();
+  } else {
+    model_description = "MLIR Converted.";
+  }
+
+  // Build the model and finish the model building process.
+  auto description = builder_.CreateString(model_description.data());
+  VectorBufferOffset<int32_t> metadata_buffer = 0;  // Deprecated
+  auto metadata = CreateMetadataVector();
+  if (!metadata) return llvm::None;
+
+  auto model = tflite::CreateModel(
+      builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_),
+      builder_.CreateVector(subgraphs), description,
+      builder_.CreateVector(buffers_), metadata_buffer, *metadata);
+  tflite::FinishModelBuffer(builder_, model);
+  tflite::UpdateOpVersion(builder_.GetBufferPointer());
+  tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
+
+  // Return serialized string for the built FlatBuffer.
+  return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
+                     builder_.GetSize());
+}
+
+}  // namespace
+
+// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer
+// format. Returns false on success.
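[Editor's illustration, not part of the patch] The 16-byte "min_runtime_version" placeholder used above can be shown in isolation with a small, hypothetical helper that mirrors what BuildMetadata does; it assumes only the generated TFLite schema headers (tflite::CreateBuffer, tflite::CreateMetadataDirect) and a caller-owned buffer table.

    #include <string>
    #include <vector>

    #include "flatbuffers/flatbuffers.h"
    #include "tensorflow/lite/schema/schema_generated.h"

    // Reserves a zero-filled, 16-byte buffer for the runtime version string so
    // it can be patched in after op versions are known, without resizing the
    // FlatBuffer. Helper name and signature are illustrative only.
    flatbuffers::Offset<tflite::Metadata> AddRuntimeVersionPlaceholder(
        flatbuffers::FlatBufferBuilder* builder,
        std::vector<flatbuffers::Offset<tflite::Buffer>>* buffers) {
      // The metadata entry refers to its backing buffer by index in the buffer
      // table, so record the index before appending.
      uint32_t buffer_index = buffers->size();
      std::string placeholder(16, '\0');  // 16 bytes: flatbuffer buffer alignment.
      auto data = builder->CreateVector(
          reinterpret_cast<const uint8_t*>(placeholder.data()),
          placeholder.size());
      buffers->push_back(tflite::CreateBuffer(*builder, data));
      return tflite::CreateMetadataDirect(*builder, "min_runtime_version",
                                          buffer_index);
    }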
+// +// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting +// the following: +// +// * Quantization +// * Ops with variable tensors +// +bool tflite::MlirToFlatBufferTranslateFunction( + ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + OpOrArgNameMapper* op_or_arg_name_mapper) { + auto maybe_translated = + Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, + emit_custom_ops, op_or_arg_name_mapper); + if (!maybe_translated) return true; + *serialized_flatbuffer = std::move(*maybe_translated); + return false; +} + +bool tflite::MlirToFlatBufferTranslateFunction( + ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops) { + OpOrArgLocNameMapper op_or_arg_name_mapper; + return MlirToFlatBufferTranslateFunction( + module, serialized_flatbuffer, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); +} diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/flatbuffer_export.h new file mode 100644 index 00000000000..0fbf2f07dfb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ + +#include + +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" + +namespace tflite { + +// Translates the given MLIR `module` into a FlatBuffer and stores the +// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to +// convert location of the op to name in flatbuffer. Returns true if translation +// fails, otherwise returns false. +bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, + std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, + bool emit_select_tf_ops, + bool emit_custom_ops); + +// Same as the above but with a custom op name mapper. +bool MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h b/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h new file mode 100644 index 00000000000..4e891a5b266 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
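[Editor's illustration, not part of the patch] A minimal sketch of how the new flatbuffer_export.h API might be driven, following the return convention of the definitions above (true on failure). The file names, the use of mlir::parseSourceFile, and the assumption that the TF/TFL dialects are registered in the context are all illustrative, not taken from the patch.

    #include <fstream>
    #include <string>

    #include "mlir/IR/MLIRContext.h"
    #include "mlir/IR/Module.h"
    #include "mlir/Parser.h"
    #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"

    int main() {
      mlir::MLIRContext context;
      // Parse a TFLite-dialect module from disk (path is illustrative).
      mlir::OwningModuleRef module =
          mlir::parseSourceFile("input.mlir", &context);
      if (!module) return 1;

      std::string serialized_flatbuffer;
      // Returns true on failure, false on success.
      bool failed = tflite::MlirToFlatBufferTranslateFunction(
          module.get(), &serialized_flatbuffer,
          /*emit_builtin_tflite_ops=*/true,
          /*emit_select_tf_ops=*/false,
          /*emit_custom_ops=*/false);
      if (failed) return 1;

      std::ofstream out("model.tflite", std::ios::binary);
      out.write(serialized_flatbuffer.data(), serialized_flatbuffer.size());
      return 0;
    }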
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ + +#include + +// These flags are used to control the emission or not of different kinds of ops +// during the flatbuffer translation. +extern bool emit_builtin_tflite_ops; +extern bool emit_select_tf_ops; +extern bool emit_custom_ops; +// The flag to control whether to lower tensorlist ops into TF ops. +extern bool lower_tensor_list_ops; +// The flag to control whether debug info gets stripped on export. +extern bool strip_debug_info; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 4f6d11394d4..3ad625f6e08 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -63,20 +63,16 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -100,45 +96,6 @@ using xla::StatusOr; namespace errors = tensorflow::errors; namespace tfl = mlir::TFL; -using llvm::cl::opt; - -// Commandline flag to enable the control of flatbuffer import. -bool use_external_constant; - -// Commandline flag to enable graph pruning. -bool experimental_prune_unreachable_nodes_unconditionally; - -// NOLINTNEXTLINE -static opt use_external_constant_flag( - "use-external-constant", - llvm::cl::desc("Use external constant during flatbuffer import"), - llvm::cl::location(use_external_constant), llvm::cl::init(false)); - -// TODO(b/147111261): After the importer supports generic custom ops, we should -// change the flag to a more lightwise flag, e.g. -// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune -// the operations. 
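[Editor's illustration, not part of the patch] The flag plumbing being moved in these hunks follows LLVM's external-storage pattern: a plain global holds the value and an llvm::cl::opt writes into it via llvm::cl::location, so other translation units can read the global directly. A self-contained sketch with illustrative names:

    #include "llvm/Support/CommandLine.h"

    // Plain global read by the rest of the program.
    bool example_emit_custom_ops;

    // NOLINTNEXTLINE
    static llvm::cl::opt<bool, /*ExternalStorage=*/true>
        example_emit_custom_ops_flag(
            "example-emit-custom-ops",
            llvm::cl::desc("Illustrative flag backed by external storage"),
            llvm::cl::location(example_emit_custom_ops),
            llvm::cl::init(false));

    int main(int argc, char** argv) {
      llvm::cl::ParseCommandLineOptions(argc, argv, "flag sketch\n");
      return example_emit_custom_ops ? 0 : 1;
    }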
-// NOLINTNEXTLINE -static opt experimental_prune_unreachable_nodes_unconditionally_flg( - "experimental-prune-unreachable-nodes-unconditionally", - llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."), - llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static opt input_arrays_flag( - "input-arrays", - llvm::cl::desc( - "List of input tensors, if different from the default inputs"), - llvm::cl::init("")); - -// NOLINTNEXTLINE -static opt output_arrays_flag( - "output-arrays", - llvm::cl::desc( - "List of output tensors, if different from the default outputs"), - llvm::cl::init("")); - namespace { bool IsScalar(const TensorT& tensor) { // TODO(b/138222071) We can't distinguish scalars and unranked tensors @@ -1063,42 +1020,3 @@ OwningModuleRef tflite::FlatBufferToMlir( return OwningModuleRef(module); } - -static OwningModuleRef FlatBufferFileToMlirTrans( - llvm::SourceMgr* source_mgr, MLIRContext* context, - bool use_external_constant, - bool experimental_prune_unreachable_nodes_unconditionally) { - const llvm::MemoryBuffer* input = - source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); - std::string error; - auto loc = - mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context); - - // Parses input/output names from command line options. - std::vector inputs; - std::vector outputs; - // Use output parser since we only have tensor names. - if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) { - return emitError(loc, "parsing input array info failed ") - << input_arrays_flag, - nullptr; - } - if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) { - return emitError(loc, "parsing output array info failed ") - << output_arrays_flag, - nullptr; - } - - return tflite::FlatBufferToMlir( - absl::string_view(input->getBufferStart(), input->getBufferSize()), - context, loc, use_external_constant, inputs, outputs, - experimental_prune_unreachable_nodes_unconditionally); -} - -static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( - "tflite-flatbuffer-to-mlir", - [](llvm::SourceMgr& source_mgr, MLIRContext* context) { - return FlatBufferFileToMlirTrans( - &source_mgr, context, use_external_constant, - experimental_prune_unreachable_nodes_unconditionally); - }); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 4163d13c36c..5b95b30a96c 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -13,31 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "flatbuffers/flexbuffers.h" // from @flatbuffers -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" @@ -56,67 +31,48 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" -#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/string_util.h" -#include "tensorflow/lite/tools/versioning/op_version.h" -#include "tensorflow/lite/tools/versioning/runtime_version.h" -#include "tensorflow/lite/version.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -using llvm::dyn_cast; -using llvm::formatv; -using llvm::isa; -using llvm::Optional; -using llvm::StringRef; -using llvm::Twine; -using mlir::Dialect; -using mlir::ElementsAttr; -using mlir::FuncOp; -using mlir::MLIRContext; -using mlir::ModuleOp; -using mlir::NoneType; -using mlir::Operation; -using mlir::Region; -using mlir::StringAttr; -using mlir::TensorType; -using mlir::TranslateFromMLIRRegistration; -using mlir::Type; -using mlir::UnknownLoc; -using mlir::Value; -using tensorflow::OpOrArgLocNameMapper; -using tensorflow::OpOrArgNameMapper; -using tensorflow::Status; -using tflite::flex::IsWhitelistedFlexOp; -using xla::StatusOr; +using llvm::cl::opt; -template -using BufferOffset = flatbuffers::Offset; +// Commandline flag to enable the control of flatbuffer import. +bool use_external_constant; -template -using VectorBufferOffset = flatbuffers::Offset>; +// Commandline flag to enable graph pruning. 
+bool experimental_prune_unreachable_nodes_unconditionally; -using CustomOptionsOffset = VectorBufferOffset; +// NOLINTNEXTLINE +static opt use_external_constant_flag( + "use-external-constant", + llvm::cl::desc("Use external constant during flatbuffer import"), + llvm::cl::location(use_external_constant), llvm::cl::init(false)); -namespace error = tensorflow::error; -namespace tfl = mlir::TFL; +// TODO(b/147111261): After the importer supports generic custom ops, we should +// change the flag to a more lightwise flag, e.g. +// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune +// the operations. +// NOLINTNEXTLINE +static opt experimental_prune_unreachable_nodes_unconditionally_flg( + "experimental-prune-unreachable-nodes-unconditionally", + llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."), + llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally), + llvm::cl::init(false)); +// NOLINTNEXTLINE +static opt input_arrays_flag( + "input-arrays", + llvm::cl::desc( + "List of input tensors, if different from the default inputs"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +static opt output_arrays_flag( + "output-arrays", + llvm::cl::desc( + "List of output tensors, if different from the default outputs"), + llvm::cl::init("")); using llvm::cl::opt; // These command line flags enable control of the translation implementation. @@ -157,1353 +113,48 @@ static opt strip_debug_info_flag( "strip-debug-info", llvm::cl::desc("Strip debug info during export"), llvm::cl::location(strip_debug_info), llvm::cl::init(false)); -ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; - -// Use initial buffer size in flatbuffer builder to be same as the initial size -// used by the TOCO export. (It does not explain rationale for this choice.) -constexpr size_t kInitialBufferSize = 10240; - -// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. -// Since tflite doesn't support unsigned for other types, returns error if -// `isSigned` is set to false for other types. -static StatusOr GetTFLiteType(Type type, - bool is_signed = true) { - if (!is_signed && type.isSignlessInteger(8)) { - return tflite::TensorType_UINT8; - } - if (!is_signed) { - return Status(error::INVALID_ARGUMENT, - "'isSigned' can only be set for 8-bits integer type"); - } - switch (type.getKind()) { - case mlir::StandardTypes::F32: - return tflite::TensorType_FLOAT32; - case mlir::StandardTypes::F16: - return tflite::TensorType_FLOAT16; - case mlir::TF::TensorFlowTypes::STRING: - return tflite::TensorType_STRING; - case mlir::TF::TensorFlowTypes::QUINT8: - return tflite::TensorType_UINT8; - case mlir::StandardTypes::Complex: { - auto ftype = type.cast().getElementType(); - if (ftype && ftype.isF32()) { - return tflite::TensorType_COMPLEX64; - } - return Status(error::INVALID_ARGUMENT, "Unsupported type"); - } - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return tflite::TensorType_BOOL; - case 8: - return itype.isUnsigned() ? 
tflite::TensorType_UINT8 - : tflite::TensorType_INT8; - case 16: - return tflite::TensorType_INT16; - case 32: - return tflite::TensorType_INT32; - case 64: - return tflite::TensorType_INT64; - } - } - case mlir::quant::QuantizationTypes::UniformQuantized: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::TF::TensorFlowTypes::RESOURCE: { - // Treat tf.resource values as integer values in flatbuffer. - // TODO(b/146131919): Maybe need to have a detailed design for supporting - // other resource types beyonds hash table resources and resource - // variables. - return tflite::TensorType_INT32; - } - default: - // TFLite export fills FLOAT32 for unknown data types. Returning an error - // for now for safety and this could be revisited when required. - return Status(error::INVALID_ARGUMENT, "Unsupported type"); - } -} - -static bool IsConst(Operation* op) { - return isa(op) || isa(op) || - isa(op) || isa(op); -} - -template -static bool HasValidTFLiteType(Value value, T& error_handler) { - // None type is allowed to represent unspecified operands. - if (value.getType().isa()) return true; - - auto type = value.getType().dyn_cast(); - if (!type) { - if (auto op = value.getDefiningOp()) { - error_handler.emitError() - << '\'' << op << "' should produce value of tensor type instead of " - << value.getType(); - return false; - } - error_handler.emitError("expected tensor type, got ") << value.getType(); - return false; - } - - Type element_type = type.getElementType(); - auto status = GetTFLiteType(element_type); - if (!status.ok()) { - return error_handler.emitError( - formatv("Failed to convert element type '{0}': {1}", - element_type, status.status().error_message())), - false; - } - return true; -} - -// Returns true if the module holds all the invariants expected by the -// Translator class. -// TODO(hinsu): Now that translation is done by making a single pass over the -// MLIR module, consider inlining these validation checks at the place where -// these invariants are assumed instead of checking upfront. -static bool IsValidTFLiteMlirModule(ModuleOp module) { - MLIRContext* context = module.getContext(); - - // Verify that module has a function named main. - FuncOp main_fn = module.lookupSymbol("main"); - if (!main_fn) { - return emitError(UnknownLoc::get(context), - "should have a function named 'main'"), - false; - } - - for (auto fn : module.getOps()) { - if (fn.getBlocks().size() != 1) { - return fn.emitError("should have exactly one basic block"), false; - } - auto& bb = fn.getBlocks().front(); - - for (auto arg : bb.getArguments()) { - if (!HasValidTFLiteType(arg, fn)) - return fn.emitError("invalid TFLite type: ") << arg.getType(), false; - } - - // Verify that all operations except the terminator have exactly one - // result of type supported by TFLite. - for (auto& inst : bb) { - if (inst.isKnownTerminator()) break; - - for (auto result : inst.getResults()) { - if (!HasValidTFLiteType(result, inst)) - return fn.emitError("invalid TFLite type: ") << result.getType(), - false; - } - } - } - - return true; -} - -static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( - ::mlir::Operation* inst) { - // We pass empty string for the original node_def name since Flex runtime - // does not care about this being set correctly on node_def. 
There is no - // "easy" (see b/120948529) way yet to get this from MLIR inst. - auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( - inst, /*name=*/"", /*ignore_unregistered_attrs=*/true); - if (!status_or_node_def.ok()) { - inst->emitOpError( - Twine("failed to obtain TensorFlow nodedef with status: " + - status_or_node_def.status().ToString())); - return {}; - } - return std::move(status_or_node_def.ValueOrDie()); -} - -// Converts a mlir padding StringRef to TfLitePadding. -// Returns llvm::None if conversion fails. -static Optional GetTflitePadding(Operation* inst, - llvm::StringRef padding) { - const tflite::Padding padding_attr = - std::move(llvm::StringSwitch(padding) - .Case("SAME", tflite::Padding_SAME) - .Case("VALID", tflite::Padding_VALID)); - if (padding_attr == tflite::Padding_SAME) { - return kTfLitePaddingSame; - } - if (padding_attr == tflite::Padding_VALID) { - return kTfLitePaddingValid; - } - - return inst->emitOpError() << "Invalid padding attribute: " << padding, - llvm::None; -} - -// Extracts TfLitePoolParams from a TFL custom op. -// Template parameter, TFLOp, should be a TFL custom op containing attributes -// generated from TfLitePoolParams. -// Returns llvm::None if conversion fails. -template -static Optional GetTflitePoolParams(Operation* inst, - TFLOp op) { - TfLitePoolParams pool_params; - pool_params.stride_height = op.stride_h().getSExtValue(); - pool_params.stride_width = op.stride_w().getSExtValue(); - pool_params.filter_height = op.filter_h().getSExtValue(); - pool_params.filter_width = op.filter_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - pool_params.padding = *padding; - pool_params.activation = kTfLiteActNone; - pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; - return pool_params; - } - - return llvm::None; -} - +namespace mlir { namespace { +static OwningModuleRef FlatBufferFileToMlirTrans( + llvm::SourceMgr* source_mgr, MLIRContext* context, + bool use_external_constant, + bool experimental_prune_unreachable_nodes_unconditionally) { + const llvm::MemoryBuffer* input = + source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); + std::string error; + auto loc = + mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context); -// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. -class Translator { - public: - // Translates the given MLIR module into TFLite FlatBuffer format and returns - // the serialized output. Returns llvm::None on unsupported, invalid inputs or - // internal error. - static Optional Translate( - ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); - - private: - enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; - explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, - bool emit_select_tf_ops, bool emit_custom_ops, - OpOrArgNameMapper* op_or_arg_name_mapper) - : module_(module), - name_mapper_(*op_or_arg_name_mapper), - builder_(kInitialBufferSize) { - // The first buffer must be empty according to the schema definition. 
- empty_buffer_ = tflite::CreateBuffer(builder_); - buffers_.push_back(empty_buffer_); - if (emit_builtin_tflite_ops) { - enabled_op_types_.emplace(OpType::kTfliteBuiltin); - } - if (emit_select_tf_ops) { - enabled_op_types_.emplace(OpType::kSelectTf); - } - if (emit_custom_ops) { - enabled_op_types_.emplace(OpType::kCustomOp); - } - tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); - tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); + // Parses input/output names from command line options. + std::vector inputs; + std::vector outputs; + // Use output parser since we only have tensor names. + if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) { + return emitError(loc, "parsing input array info failed ") + << input_arrays_flag, + nullptr; } - - Optional TranslateInternal(); - - // Returns TFLite buffer populated with constant value if the operation is - // TFLite constant operation. Otherwise, returns an empty buffer. Emits error - // and returns llvm::None on failure. - Optional> BuildBuffer(Operation* inst); - - // Build TFLite tensor from the given type. This function is for tfl.lstm - // intermediates, which should have UniformQuantizedType. - Optional> BuildTensorFromType( - mlir::Type type, const std::string& name); - - // Builds TFLite tensor from the given value. `buffer_idx` is index of the - // corresponding buffer. Emits error and returns llvm::None on failure. - Optional> BuildTensor(Value value, - const std::string& name, - unsigned buffer_idx); - - // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove - // these 2 functions here. - BufferOffset BuildIfOperator( - mlir::TF::IfOp op, const std::vector& operands, - const std::vector& results); - BufferOffset BuildWhileOperator( - mlir::TF::WhileOp op, const std::vector& operands, - const std::vector& results); - - // Build while operator where cond & body are regions. - Optional> BuildWhileOperator( - mlir::TFL::WhileOp op, const std::vector& operands, - const std::vector& results); - - // Builds custom operators. - // Templated on a) data type of custom_option to be stored into flatbuffer, - // and b) TFL custom op type. - template - BufferOffset BuildCustomOperator( - const CustomOptionType& custom_option, const std::string& opcode_name, - TFLOp op, const std::vector& operands, - const std::vector& results); - - BufferOffset BuildNumericVerifyOperator( - mlir::TFL::NumericVerifyOp op, const std::vector& operands, - const std::vector& results); - Optional> - BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxUnpooling2DOperator( - Operation* inst, mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results); - - Optional CreateFlexOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - Optional CreateCustomOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - std::unique_ptr CreateFlexBuilderWithNodeAttrs( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - // Returns opcode index for op identified by the op_name, if already - // available. 
Otherwise, creates a new OperatorCode using the given `builtin` - // operator and associates it with `op_name`. - uint32_t GetOpcodeIndex(const std::string& op_name, - tflite::BuiltinOperator builtin); - - // Builds operator for the given operation with specified operand and result - // tensor indices. Emits an error and returns llvm::None on failure. - Optional> BuildOperator( - Operation* inst, const std::vector& operands, - const std::vector& results, - const std::vector& intermediates); - - // Build a subgraph with a given name out of the region either corresponding - // to a function's body or while op. - Optional> BuildSubGraph( - const std::string& name, Region* region); - - // Builds Metadata with the given `name` and buffer `content`. - BufferOffset BuildMetadata(StringRef name, - StringRef content); - - // Encodes the `tfl.metadata` dictionary attribute of the module to the - // metadata section in the final model. - Optional>> - CreateMetadataVector(); - - // Uses the tf.entry_function attribute (if set) to initialize the op to name - // mapping. - void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); - - // Determines if the specified operation op's operand at operand_index - // is marked as a stateful operand. - bool IsStatefulOperand(mlir::Operation* op, int operand_index); - - // Returns a unique name for `val`. - std::string UniqueName(mlir::Value val); - - ModuleOp module_; - - tensorflow::OpOrArgNameMapper& name_mapper_; - - flatbuffers::FlatBufferBuilder builder_; - BufferOffset empty_buffer_; - - std::vector> buffers_; - - // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. - absl::flat_hash_map opcode_index_map_; - std::vector> opcodes_; - - // Maps function name to index of the corresponding subgraph in the FlatBuffer - // model. - absl::flat_hash_map subgraph_index_map_; - absl::flat_hash_set enabled_op_types_; - - // Points to TensorFlow and TFLite dialects, respectively. nullptr if the - // dialect is not registered. - const Dialect* tf_dialect_; - const Dialect* tfl_dialect_; - - // The failed ops during legalization. - std::set failed_flex_ops_; - std::set failed_custom_ops_; -}; - -std::string Translator::UniqueName(mlir::Value val) { - return std::string(name_mapper_.GetUniqueName(val)); + if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) { + return emitError(loc, "parsing output array info failed ") + << output_arrays_flag, + nullptr; + } + return tflite::FlatBufferToMlir( + absl::string_view(input->getBufferStart(), input->getBufferSize()), + context, loc, use_external_constant, inputs, outputs, + experimental_prune_unreachable_nodes_unconditionally); } -Optional> Translator::BuildBuffer( - Operation* inst) { - ElementsAttr attr; - if (auto cst = dyn_cast(inst)) { - // ConstantOp have ElementAttr at this point due to validation of the TFLite - // module. 
- attr = cst.getValue().cast(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else { - return empty_buffer_; - } - - tensorflow::Tensor tensor; - auto status = tensorflow::ConvertToTensor(attr, &tensor); - if (!status.ok()) { - inst->emitError( - Twine("failed to convert value attribute to tensor with error: " + - status.ToString())); - return llvm::None; - } - - // TensorFlow and TensorFlow Lite use different string encoding formats. - // Convert to TensorFlow Lite format is it's a constant string tensor. - if (tensor.dtype() == tensorflow::DT_STRING) { - ::tflite::DynamicBuffer dynamic_buffer; - auto flat = tensor.flat<::tensorflow::tstring>(); - for (int i = 0; i < flat.size(); ++i) { - const auto& str = flat(i); - dynamic_buffer.AddString(str.c_str(), str.length()); - } - char* tensor_buffer; - int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer); - auto buffer_data = - builder_.CreateVector(reinterpret_cast(tensor_buffer), bytes); - free(tensor_buffer); - return tflite::CreateBuffer(builder_, buffer_data); - } - - absl::string_view tensor_data = tensor.tensor_data(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(tensor_data.data()), tensor_data.size()); - return tflite::CreateBuffer(builder_, buffer_data); -} - -Optional> Translator::BuildTensorFromType( - mlir::Type type, const std::string& name) { - auto tensor_type = type.cast(); - - if (!tensor_type.hasStaticShape()) { - return llvm::None; - } - llvm::ArrayRef shape_ref = tensor_type.getShape(); - std::vector shape(shape_ref.begin(), shape_ref.end()); - - auto element_type = tensor_type.getElementType(); - tflite::TensorType tflite_element_type = - GetTFLiteType(tensor_type.getElementType()).ValueOrDie(); - BufferOffset q_params; - auto qtype = element_type.dyn_cast(); - if (!qtype) { - return llvm::None; - } - q_params = tflite::CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, - builder_.CreateVector({static_cast(qtype.getScale())}), - builder_.CreateVector({qtype.getZeroPoint()})); - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - /*buffer=*/0, builder_.CreateString(name), q_params, - /*is_variable=*/false); -} - -Optional> Translator::BuildTensor( - Value value, const std::string& name, unsigned buffer_idx) { - auto type = value.getType().cast(); - - // TFLite requires tensor shape only for the inputs and constants. - // However, we output all known shapes for better round-tripping - auto check_shape = - [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { - auto is_out_of_range = [](int64_t dim) { - return dim > std::numeric_limits::max(); - }; - - if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) - return mlir::emitError( - value.getLoc(), - "result shape dimensions out of 32 bit int type range"); - - return mlir::success(); - }; - - std::vector shape; - std::vector shape_signature; - if (type.hasStaticShape()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (auto* inst = value.getDefiningOp()) { - if (IsConst(inst)) { - // Const op can have a result of dynamic shaped type (e.g. 
due to constant - // folding), but we can still derive the shape of a constant tensor for - // its attribute type. - mlir::Attribute tensor_attr = inst->getAttr("value"); - llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } - } else if (type.hasRank()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape.reserve(shape_ref.size()); - for (auto& dim : shape_ref) { - shape.push_back(dim == -1 ? 1 : dim); - } - shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); - } - - if (auto* inst = value.getDefiningOp()) { - if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); - } else if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); - } - } - - Type element_type = type.getElementType(); - tflite::TensorType tflite_element_type = - GetTFLiteType(type.getElementType()).ValueOrDie(); - - BufferOffset q_params; - if (auto qtype = element_type.dyn_cast()) { - q_params = tflite::CreateQuantizationParameters( - // TODO(fengliuai): min and max values are not stored in the - // quantized type, so both are set to 0. The model couldn't be imported - // to TensorFlow because of this. - builder_, /*min=*/0, /*max=*/0, - builder_.CreateVector({static_cast(qtype.getScale())}), - builder_.CreateVector({qtype.getZeroPoint()})); - } else if (auto qtype = - element_type - .dyn_cast()) { - std::vector scales(qtype.getScales().begin(), - qtype.getScales().end()); - q_params = tflite::CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), - builder_.CreateVector(qtype.getZeroPoints()), - tflite::QuantizationDetails_NONE, /*details=*/0, - qtype.getQuantizedDimension()); - } else { - q_params = tflite::CreateQuantizationParameters(builder_); - } - // Check if the value's uses includes an op and usage at an operand index - // marked as a stateful. If so, set the tensor's is_variable as true - // This is v1 ref variable semantics in the TFLite runtime. - bool is_variable = false; - for (auto& use : value.getUses()) { - is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); - if (is_variable) { - break; - } - } - - if (shape_signature.empty()) { - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable); - } else { - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 
0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable, /*sparsity=*/0, - /*shape_signature=*/builder_.CreateVector(shape_signature)); - } -} - -BufferOffset Translator::BuildIfOperator( - mlir::TF::IfOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); - int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); - int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); - auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, - else_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_IfOptions, - builtin_options); -} - -BufferOffset Translator::BuildWhileOperator( - mlir::TF::WhileOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); - int body_subgraph_index = subgraph_index_map_.at(op.body().str()); - auto builtin_options = tflite::CreateWhileOptions( - builder_, cond_subgraph_index, body_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_WhileOptions, - builtin_options); -} - -Optional> Translator::BuildWhileOperator( - mlir::TFL::WhileOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - auto get_call_index = [&](mlir::Block& b) -> Optional { - if (b.getOperations().size() != 2) return llvm::None; - if (auto call_op = dyn_cast(b.front())) - return subgraph_index_map_.at(call_op.callee().str()); - return llvm::None; - }; - auto body_subgraph_index = get_call_index(op.body().front()); - auto cond_subgraph_index = get_call_index(op.cond().front()); - if (!body_subgraph_index || !cond_subgraph_index) - return op.emitOpError("only single call cond/body while export supported"), - llvm::None; - auto builtin_options = - tflite::CreateWhileOptions(builder_, *cond_subgraph_index, - *body_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_WhileOptions, - builtin_options); -} - -template -BufferOffset Translator::BuildCustomOperator( - const CustomOptionType& custom_option, const std::string& opcode_name, - TFLOp op, const std::vector& operands, - const std::vector& results) { - std::vector custom_option_vector(sizeof(CustomOptionType)); - memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType)); - auto opcode_index = - GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM); - return tflite::CreateOperator( - builder_, opcode_index, builder_.CreateVector(operands), - builder_.CreateVector(results), tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - builder_.CreateVector(custom_option_vector), - tflite::CustomOptionsFormat_FLEXBUFFERS); -} - -BufferOffset Translator::BuildNumericVerifyOperator( - mlir::TFL::NumericVerifyOp op, const std::vector& operands, - const std::vector& results) { - float tolerance = op.tolerance().convertToFloat(); 
- return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); -} - -Optional> -Translator::BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, const std::vector& results) { - TfLiteTransposeConvParams conv_params; - conv_params.stride_height = op.stride_h().getSExtValue(); - conv_params.stride_width = op.stride_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - conv_params.padding = *padding; - return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxUnpooling2DOperator(Operation* inst, - mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, - results); - } - - return llvm::None; -} - -Optional Translator::CreateFlexOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } - - auto flex_builder = absl::make_unique(); - flex_builder->Vector([&]() { - flex_builder->String(node_def.op()); - flex_builder->String(node_def_str); - }); - flex_builder->Finish(); - return builder_.CreateVector(flex_builder->GetBuffer()); -} - -Optional Translator::CreateCustomOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } - auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); - return builder_.CreateVector(flex_builder->GetBuffer()); -} - -std::unique_ptr -Translator::CreateFlexBuilderWithNodeAttrs( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - auto flex_builder = absl::make_unique(); - size_t map_start = flex_builder->StartMap(); - for (const auto& pair : node_def.attr()) { - const char* key = pair.first.c_str(); - const auto& attr = pair.second; - switch (attr.value_case()) { - case ::tensorflow::AttrValue::kS: - flex_builder->String(key, attr.s()); - break; - case ::tensorflow::AttrValue::kType: { - auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); - if (status_or_tfl_type.ok()) { - flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); - } else { - emitWarning(loc, "ignoring unsupported tensorflow type: ") - << std::to_string(attr.type()); - } - break; - } - case ::tensorflow::AttrValue::kI: - flex_builder->Int(key, attr.i()); - break; - case ::tensorflow::AttrValue::kF: - flex_builder->Float(key, attr.f()); - break; - case ::tensorflow::AttrValue::kB: - flex_builder->Bool(key, attr.b()); - break; - case tensorflow::AttrValue::kList: - if (attr.list().s_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const std::string& v : 
attr.list().s()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else if (attr.list().i_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const int64_t v : attr.list().i()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else if (attr.list().f_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const float v : attr.list().f()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else { - emitWarning(loc, - "ignoring unsupported type in list attribute with key: ") - << key; - } - break; - default: - emitWarning(loc, "ignoring unsupported attribute type with key: ") - << key; - break; - } - } - flex_builder->EndMap(map_start); - flex_builder->Finish(); - return flex_builder; -} - -uint32_t Translator::GetOpcodeIndex(const std::string& op_name, - tflite::BuiltinOperator builtin) { - auto it = opcode_index_map_.insert({op_name, 0}); - - // If the insert succeeded, the opcode has not been created already. Create a - // new operator code and update its index value in the map. - if (it.second) { - it.first->second = opcodes_.size(); - auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM - ? builder_.CreateString(op_name) - : BufferOffset(); - // Use version 0 for builtin op. This is a way to serialize version field to - // flatbuffer (since 0 is non default) and it will be corrected later. - int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; - opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, - custom_code, op_version)); - } - return it.first->second; -} - -Optional> Translator::BuildOperator( - Operation* inst, const std::vector& operands, - const std::vector& results, - const std::vector& intermediates) { - const auto* dialect = inst->getDialect(); - if (!dialect) { - inst->emitOpError("dialect is not registered"); - return llvm::None; - } - - // If TFLite built in op, create operator as a builtin op. - if (dialect == tfl_dialect_) { - // Only if built-in TFLite op emission is enabled, would legalization have - // converted any TF->TFL. 
- if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { - return inst->emitOpError( - "is a TFLite builtin op but builtin emission is not enabled"), - llvm::None; - } - - auto builtin_code = GetBuiltinOpCode(inst); - if (!builtin_code) { - if (auto verify_op = dyn_cast(inst)) { - return BuildNumericVerifyOperator(verify_op, operands, results); - } - if (auto conv_transpose_bias_op = - dyn_cast(inst)) { - return BuildConvolution2DTransposeBiasOperator( - inst, conv_transpose_bias_op, operands, results); - } - if (auto max_pooling_with_arg_max_op = - dyn_cast(inst)) { - return BuildMaxPoolingWithArgMax2DOperator( - inst, max_pooling_with_arg_max_op, operands, results); - } - if (auto max_unpooling_op = dyn_cast(inst)) { - return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, - results); - } - if (auto whileOp = dyn_cast(inst)) { - if (inst->getNumOperands() != inst->getNumResults()) { - inst->emitOpError( - "number of operands and results don't match, only canonical " - "TFL While supported"); - return llvm::None; - } - return BuildWhileOperator(whileOp, operands, results); - } - - inst->emitOpError("is not a supported TFLite op"); - return llvm::None; - } - - std::string op_name = inst->getName().getStringRef().str(); - uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); - auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, - results, intermediates, &builder_); - if (!offset) { - inst->emitOpError("is not a supported TFLite op"); - } - return offset; - } - - if (dialect == tf_dialect_) { - std::string op_name; - if (auto ifOp = dyn_cast(inst)) { - return BuildIfOperator(ifOp, operands, results); - } else if (auto whileOp = dyn_cast(inst)) { - return BuildWhileOperator(whileOp, operands, results); - } - - CustomOptionsOffset custom_options; - - // Ops in TF dialect can either be custom ops or flex ops. - // The reason we go directly from TensorFlow dialect MLIR to tensorflow - // node instead of going to TF table gen'd ops via generated code is that - // we do not want to restrict custom and flex op conversion support to - // only those TF ops that are currently registered in MLIR. The current - // model is of an open op system. - // - // The following algorithm is followed: - // if flex is enabled and the op is whitelisted as flex - // we emit op as flex. - // if custom is enabled - // we emit the op as custom. - auto node_def = GetTensorFlowNodeDef(inst); - if (!node_def) { - return llvm::None; - } - - // Flex op case - // Eventually, the whitelist will go away and we will rely on some TF op - // trait (e.g. No side effect) to determine if it is a supported "Flex" - // op or not. - if (enabled_op_types_.contains(OpType::kSelectTf) && - IsWhitelistedFlexOp(node_def->op())) { - // Construct ops as flex op encoding TensorFlow node definition - // as custom options. - // Flex ops are named with the kFlexOpNamePrefix prefix to the actual - // TF op name. - op_name = std::string(kFlexOpNamePrefix) + node_def->op(); - if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else if (enabled_op_types_.contains(OpType::kCustomOp)) { - // Generic case of custom ops - write using flex buffers since that - // is the only custom options supported by TFLite today. 
- op_name = node_def->op(); - if (auto options = - CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else { - // Create description of operation that could not be converted. - const int kLargeElementsAttr = 16; - std::string op_str; - llvm::raw_string_ostream os(op_str); - inst->getName().print(os); - // Print out attributes except for large elementsattributes (which should - // rarely be the cause why the legalization didn't happen). - if (!inst->getAttrList().getAttrs().empty()) { - os << " {"; - bool first = true; - for (auto& named_attr : inst->getAttrList().getDictionary()) { - os << (!first ? ", " : ""); - first = false; - named_attr.first.print(os); - os << " = "; - if (auto element_attr = named_attr.second.dyn_cast()) { - if (element_attr.getNumElements() <= kLargeElementsAttr) { - element_attr.print(os); - } else { - os << ""; - } - } else { - named_attr.second.print(os); - } - } - os << "}"; - } - - // Insert failed op to `flex_ops` or `custom_ops`. - if (IsWhitelistedFlexOp(node_def->op())) { - failed_flex_ops_.insert(os.str()); - } else { - failed_custom_ops_.insert(os.str()); - } - return inst->emitOpError("is neither a custom op nor a flex op"), - llvm::None; - } - - uint32_t opcode_index = - GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - /*custom_options=*/custom_options, - tflite::CustomOptionsFormat_FLEXBUFFERS, - /*mutating_variable_inputs=*/0); - } - - return inst->emitOpError( - "is not any of a builtin TFLite op, a flex TensorFlow op or a " - "custom TensorFlow op"), - llvm::None; -} - -void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { - auto dict_attr = fn.getAttrOfType("tf.entry_function"); - if (!dict_attr) return; - - llvm::SmallVector input_names; - llvm::SmallVector output_names; - if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { - str.getValue().split(input_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - if (input_names.size() != fn.getNumArguments()) { - fn.emitWarning() << "invalid entry function specification"; - return; - } - for (auto it : llvm::enumerate(fn.getArguments())) { - name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); - } - *has_input_attr = true; - } - - if (auto str = - dict_attr.get("outputs").dyn_cast_or_null()) { - str.getValue().split(output_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - auto term = fn.getBlocks().back().getTerminator(); - if (output_names.size() != term->getNumOperands()) { - fn.emitWarning() << "output names (" << output_names.size() - << ") != terminator operands (" << term->getNumOperands() - << ")"; - return; - } - for (const auto& it : llvm::enumerate(term->getOperands())) { - name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); - } - } -} - -bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { - std::vector operand_indices; - if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; - return absl::c_find(operand_indices, operand_index) != operand_indices.end(); -} - -Optional> Translator::BuildSubGraph( - const std::string& name, Region* region) { - bool has_input_attr = false; - if (auto fn = dyn_cast(region->getParentOp())) { - InitializeNamesFromAttribute(fn, &has_input_attr); - 
} - std::vector> tensors; - llvm::DenseMap tensor_index_map; - - // Builds tensor and buffer for argument or operation result. Returns false - // on failure. - auto build_tensor_and_buffer = [&](Value value, const std::string& name) { - // NoneType represents optional and may be skipped here. - if (value.getType().isa()) { - return true; - } - - tensor_index_map.insert({value, tensors.size()}); - auto tensor_or = BuildTensor(value, name, buffers_.size()); - if (!tensor_or) return false; - tensors.push_back(*tensor_or); - - // TODO(ashwinm): Check if for stateful tensors, if it is also needed to - // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. - // This does not seem to affect runtime behavior for RNN/LSTM, but would be - // good for reducing memory footprint. - if (auto* inst = value.getDefiningOp()) { - auto buffer_or = BuildBuffer(inst); - if (!buffer_or) return false; - buffers_.push_back(*buffer_or); - } else { - buffers_.push_back(empty_buffer_); - } - return true; - }; - - std::vector> operators; - auto& bb = region->front(); - - // Main function's arguments are first passed to `input` op so they don't - // have associated tensor and buffer. Build FlatBuffer tensor and buffer for - // other functions. - for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { - mlir::BlockArgument arg = bb.getArgument(i); - std::string name; - if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg)); - if (name.empty()) name = absl::StrCat("arg", i); - if (!build_tensor_and_buffer(arg, name)) return llvm::None; - } - - bool failed_once = false; - for (auto& inst : bb) { - if (inst.isKnownTerminator()) break; - std::vector intermediates; - // Build intermediate tensors for tfl.lstm and insert these tensors into - // flatbuffer. - if (llvm::isa(inst)) { - std::vector intermediate_names = { - "input_to_input_intermediate", "input_to_forget_intermediate", - "input_to_cell_intermediate", "input_to_output_intermediate", - "effective_hidden_scale_intermediate"}; - for (const std::string& intermediate : intermediate_names) { - auto intermediate_attr = inst.getAttr(intermediate); - if (auto attr = intermediate_attr.dyn_cast_or_null()) { - Type qtype = attr.getValue(); - auto tensor_or = BuildTensorFromType( - qtype, name_mapper_.GetUniqueName(intermediate).str()); - if (!tensor_or.hasValue()) { - continue; - } else { - intermediates.push_back(tensors.size()); - tensors.push_back(tensor_or.getValue()); - } - } - } - } - - for (auto val : inst.getResults()) { - std::string name = UniqueName(val); - if (!build_tensor_and_buffer(val, name)) return llvm::None; - } - - // Skip constant ops as they don't represent a TFLite operator. - if (IsConst(&inst)) continue; - - // Fetch operand and result tensor indices. - std::vector operands; - operands.reserve(inst.getNumOperands()); - for (auto operand : inst.getOperands()) { - if (operand.getType().isa()) - operands.push_back(kTfLiteOptionalTensor); - else - operands.push_back(tensor_index_map.lookup(operand)); - } - std::vector results; - results.reserve(inst.getNumOperands()); - for (auto result : inst.getResults()) { - results.push_back(tensor_index_map.lookup(result)); - } - - if (auto tfl_operator = - BuildOperator(&inst, operands, results, intermediates)) - operators.push_back(*tfl_operator); - else - failed_once = true; - } - - if (failed_once) return llvm::None; - - // Get input and output tensor indices for the subgraph. 
- std::vector inputs, outputs; - for (auto arg : bb.getArguments()) { - inputs.push_back(tensor_index_map[arg]); - } - for (auto result : bb.getTerminator()->getOperands()) { - outputs.push_back(tensor_index_map[result]); - } - - return tflite::CreateSubGraph( - builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), - builder_.CreateVector(outputs), builder_.CreateVector(operators), - /*name=*/builder_.CreateString(name)); -} - -BufferOffset Translator::BuildMetadata(StringRef name, - StringRef content) { - auto buffer_index = buffers_.size(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(content.data()), content.size()); - buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); - return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); -} - -Optional>> -Translator::CreateMetadataVector() { - auto dict_attr = module_.getAttrOfType("tfl.metadata"); - std::vector> metadata; - if (dict_attr) { - for (const auto& named_attr : dict_attr) { - StringRef name = named_attr.first; - mlir::Attribute attr = named_attr.second; - if (auto content = attr.dyn_cast()) { - metadata.push_back(BuildMetadata(name, content.getValue())); - } else { - module_.emitError( - "all values in tfl.metadata's dictionary key-value pairs should be " - "string attributes"); - return llvm::None; - } - } - } - // Runtime version string is generated after we update the op - // versions. Here we put a 16-byte dummy string as a placeholder. We choose - // 16-byte because it's the alignment of buffers in flatbuffer, so it won't - // cause any waste of space if the actual string is shorter than 16 bytes. - metadata.push_back( - BuildMetadata("min_runtime_version", std::string(16, '\0'))); - return builder_.CreateVector(metadata); -} - -bool UpdateEntryFunction(ModuleOp module) { - if (module.lookupSymbol("main") != nullptr) { - // We already have an entry function. - return true; - } - - int entry_func_count = 0; - FuncOp entry_func = nullptr; - for (auto fn : module.getOps()) { - auto attrs = fn.getAttrOfType("tf.entry_function"); - if (attrs && !attrs.empty()) { - entry_func_count++; - entry_func = fn; - } - } - - // We should have one & only have one entry function. - if (entry_func_count != 1) return false; - - // Update the entry func to main. - entry_func.setName("main"); - return true; -} - -Optional Translator::Translate( - ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { - if (!UpdateEntryFunction(module)) return llvm::None; - if (!IsValidTFLiteMlirModule(module)) return llvm::None; - Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - return translator.TranslateInternal(); -} - -Optional Translator::TranslateInternal() { - // A list of named regions in the module with main function being the first in - // the list. The main function is required as the first subgraph in the model - // is entry point for the model. - std::vector> named_regions; - named_regions.reserve(std::distance(module_.begin(), module_.end())); - - int subgraph_idx = 0; - FuncOp main_fn = module_.lookupSymbol("main"); - subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back("main", &main_fn.getBody()); - // Walk over the module collection ops with functions and while ops. 
- module_.walk([&](FuncOp fn) { - if (fn != main_fn) { - subgraph_index_map_[fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back(fn.getName().str(), &fn.getBody()); - } - }); - - // Build subgraph for each of the named regions. - std::vector> subgraphs; - subgraphs.reserve(named_regions.size()); - int first_failed_func = -1; - for (auto it : llvm::enumerate(named_regions)) { - auto subgraph_or = BuildSubGraph(it.value().first, it.value().second); - if (!subgraph_or) { - if (first_failed_func == -1) - // Record the index of the first region that cannot be converted. - // Keep looping through all subgraphs in the module to make sure that - // we collect the list of missing ops from the entire module. - first_failed_func = it.index(); - } else { - subgraphs.push_back(*subgraph_or); - } - } - - if (first_failed_func != -1) { - std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t"); - std::string failed_custom_ops_list = - absl::StrJoin(failed_custom_ops_, "\n\t"); - std::string err; - if (!failed_flex_ops_list.empty()) - err += - "Ops that can be supported by the flex runtime (enabled via setting " - "the -emit-select-tf-ops flag):\n\t" + - failed_flex_ops_list; - if (!failed_custom_ops_list.empty()) - err += - "Ops that need custom implementation (enabled via setting the " - "-emit-custom-ops flag):\n\t" + - failed_custom_ops_list; - - auto& failed_region = named_regions[first_failed_func]; - return failed_region.second->getParentOp()->emitError() - << "failed while converting: '" << failed_region.first - << "': " << err, - llvm::None; - } - - std::string model_description; - if (auto attr = module_.getAttrOfType("tfl.description")) { - model_description = attr.getValue().str(); - } else { - model_description = "MLIR Converted."; - } - - // Build the model and finish the model building process. - auto description = builder_.CreateString(model_description.data()); - VectorBufferOffset metadata_buffer = 0; // Deprecated - auto metadata = CreateMetadataVector(); - if (!metadata) return llvm::None; - - auto model = tflite::CreateModel( - builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), - builder_.CreateVector(subgraphs), description, - builder_.CreateVector(buffers_), metadata_buffer, *metadata); - tflite::FinishModelBuffer(builder_, model); - tflite::UpdateOpVersion(builder_.GetBufferPointer()); - tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); - - // Return serialized string for the built FlatBuffer. - return std::string(reinterpret_cast(builder_.GetBufferPointer()), - builder_.GetSize()); -} - -} // namespace - -// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer -// format. Returns false on success. 
-// -// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting -// the following: -// -// * Quantization -// * Ops with variable tensors -// -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - OpOrArgNameMapper* op_or_arg_name_mapper) { - auto maybe_translated = - Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - if (!maybe_translated) return true; - *serialized_flatbuffer = std::move(*maybe_translated); - return false; -} - -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops) { - OpOrArgLocNameMapper op_or_arg_name_mapper; - return MlirToFlatBufferTranslateFunction( - module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); -} - -static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction( +static LogicalResult MlirToFlatBufferFileTranslateFunction( ModuleOp module, llvm::raw_ostream& output) { std::string serialized_flatbuffer; - std::unique_ptr op_or_arg_name_mapper; + std::unique_ptr op_or_arg_name_mapper; if (strip_debug_info) { op_or_arg_name_mapper = std::make_unique(); } else { - op_or_arg_name_mapper = std::make_unique(); + op_or_arg_name_mapper = + std::make_unique(); } if (tflite::MlirToFlatBufferTranslateFunction( module, &serialized_flatbuffer, emit_builtin_tflite_ops, @@ -1511,8 +162,18 @@ static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction( return mlir::failure(); output << serialized_flatbuffer; - return mlir::success(); + return success(); } +} // namespace + +static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( + "tflite-flatbuffer-to-mlir", + [](llvm::SourceMgr& source_mgr, MLIRContext* context) { + return FlatBufferFileToMlirTrans( + &source_mgr, context, use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); + }); static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate( "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction); +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index e635885801e..0d42fbb9646 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -34,8 +34,8 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/delegates/flex/delegate.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index f961b037a6c..6eb72dab2fc 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -23,8 +23,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index 806c0353ed9..a96c65cd450 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -23,8 +23,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index fd7d95e1e33..a178519e0fe 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -28,8 +28,8 @@ limitations under the License. 
#include "mlir/Support/FileUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index b1c6cbc8d82..7c0a91d6d4e 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -25,7 +25,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 00206373872..fe2cdb2748d 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -20,7 +20,8 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.h b/tensorflow/compiler/mlir/lite/utils/convert_type.h index c4d9f98a02c..3ae58d565e1 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.h +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.h @@ -18,7 +18,7 @@ limitations under the License. 
#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 7b088cad715..4c383beac52 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -811,7 +811,8 @@ cc_library( srcs = ["utils/error_util.cc"], hdrs = ["utils/error_util.h"], deps = [ - "//tensorflow/core:lib", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", ], diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc index 60646ae764e..5514a788996 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index abef0de4585..4feb3837357 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -21,7 +21,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" // Error utilities for MLIR when interacting with code using Status returns. 
namespace mlir { From 60d6ea479e2ceddbccec073f3aa2182aea3247b0 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Mon, 23 Mar 2020 18:36:05 -0700 Subject: [PATCH 478/492] TF2XLA: BatchMatMulV2: add adj_x/adj_y support PiperOrigin-RevId: 302565809 Change-Id: Ib325e819e7ce913bc59deed6aedc5a29e0a28344 --- .../compiler/mlir/xla/tests/legalize-tf.mlir | 37 +++++++++++++++++++ .../mlir/xla/transforms/legalize_tf.cc | 20 +++++----- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index f30bd961fca..d8a1a156b0c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -3755,3 +3755,40 @@ func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) return %0 : tensor } +// CHECK-LABEL: func @batchmatmulv2_adj_real +func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> { + // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>) -> tensor<5x2xf32> + // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<2x4xf32> + // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { + // CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>, + // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, + // CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>, + // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: }} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> + // CHECK: return [[BDST]] : tensor<5x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> + return %0 : tensor<5x4xf32> +} + +// CHECK-LABEL: func @batchmatmulv2_adj_complex +func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { + // CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<5x2xcomplex>) -> tensor<5x2xf32> + // CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<5x2xcomplex>) -> tensor<5x2xf32> + // CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.neg"([[LHSIM]]) : (tensor<5x2xf32>) -> tensor<5x2xf32> + // CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<5x2xf32>, tensor<5x2xf32>) -> tensor<5x2xcomplex> + // CHECK: [[RHSRE:%.+]] = "xla_hlo.real"(%arg1) : (tensor<2x4xcomplex>) -> tensor<2x4xf32> + // CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex>) -> tensor<2x4xf32> + // CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.neg"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32> + // CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex> + // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"([[LHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex>) -> tensor<5x2xcomplex> + // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"([[RHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex>) -> tensor<2x4xcomplex> + // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { + // CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>, + // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, + // 
CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>, + // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: }} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> + // CHECK: return [[BDST]] : tensor<5x4xcomplex> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> + return %0 : tensor<5x4xcomplex> +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 817dfb55ec9..65704ca8dec 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -1629,17 +1629,17 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op, PatternRewriter &rewriter) const override { - // TODO(silvasean): Handle adj_x/adj_y - // Should be able to just set the contracting_dimensions attribute - // appropriately. - // For complex types, need to do a complex conjugation. - if (op.adj_x() || op.adj_y()) return failure(); - Value lhs = op.x(); Value rhs = op.y(); auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); if (!lhs_type || !rhs_type) return failure(); + if (lhs_type.getElementType().isa() && op.adj_x()) { + lhs = rewriter.create(op.getLoc(), lhs_type, lhs); + } + if (rhs_type.getElementType().isa() && op.adj_y()) { + rhs = rewriter.create(op.getLoc(), rhs_type, rhs); + } // TODO(silvasean): Support dynamic shapes. if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) { return failure(); @@ -1654,10 +1654,10 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { int64_t rank = lhs_type.getRank(); auto batch_dimensions = GetI64ElementsAttr( llvm::to_vector<4>(llvm::seq(0, rank - 2)), &rewriter); - auto lhs_contracting_dimensions = - GetI64ElementsAttr(llvm::makeArrayRef({rank - 1}), &rewriter); - auto rhs_contracting_dimensions = - GetI64ElementsAttr(llvm::makeArrayRef({rank - 2}), &rewriter); + auto lhs_contracting_dimensions = GetI64ElementsAttr( + llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter); + auto rhs_contracting_dimensions = GetI64ElementsAttr( + llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter); auto dimension_numbers = DotDimensionNumbers::get( /*lhs_batching_dimensions=*/batch_dimensions, /*rhs_batching_dimensions=*/batch_dimensions, From e918c6e6fab5d0005fcde83d57e92b70343d3553 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Mon, 23 Mar 2020 18:51:46 -0700 Subject: [PATCH 479/492] Fixing a memory leak in Keras. Fixes: https://github.com/tensorflow/tensorflow/issues/37515 PiperOrigin-RevId: 302568217 Change-Id: I28d0eaf3602fea0461901680df24899f135ce649 --- tensorflow/python/keras/engine/data_adapter.py | 9 ++++++--- tensorflow/python/keras/utils/data_utils.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index b0741acfe30..4dfc28fd40f 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -912,6 +912,7 @@ class KerasSequenceAdapter(GeneratorDataAdapter): self._size = len(x) self._shuffle_sequence = shuffle self._keras_sequence = x + self._enqueuer = None super(KerasSequenceAdapter, self).__init__( x, shuffle=False, # Shuffle is handed in the _make_callable override. 
@@ -929,11 +930,11 @@ class KerasSequenceAdapter(GeneratorDataAdapter): max_queue_size): if workers > 1 or (workers > 0 and use_multiprocessing): def generator_fn(): - enqueuer = data_utils.OrderedEnqueuer( + self._enqueuer = data_utils.OrderedEnqueuer( x, use_multiprocessing=use_multiprocessing, shuffle=self._shuffle_sequence) - enqueuer.start(workers=workers, max_queue_size=max_queue_size) - return enqueuer.get() + self._enqueuer.start(workers=workers, max_queue_size=max_queue_size) + return self._enqueuer.get() else: def generator_fn(): order = range(len(x)) @@ -954,6 +955,8 @@ class KerasSequenceAdapter(GeneratorDataAdapter): return True def on_epoch_end(self): + if self._enqueuer: + self._enqueuer.stop() self._keras_sequence.on_epoch_end() diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index 5224356e877..73ffc19d293 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -678,7 +678,7 @@ class SequenceEnqueuer(object): for data in datas: # Use the inputs; training, evaluating, predicting. # ... stop sometime. - enqueuer.close() + enqueuer.stop() ``` The `enqueuer.get()` should be an infinite stream of datas. From dcf212b2a5e06558683367807301ca19412e9892 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Mon, 23 Mar 2020 19:05:36 -0700 Subject: [PATCH 480/492] Separate `model.cc` into `model_builder.cc` and `interpreter_builder.cc`. PiperOrigin-RevId: 302570223 Change-Id: I9b26c21dc7db0a1fb225986db20b1ed66f6fcc82 --- tensorflow/lite/BUILD | 5 +- .../lite/{model.cc => interpreter_builder.cc} | 160 +------------ tensorflow/lite/interpreter_builder.h | 104 ++++++++ tensorflow/lite/model.h | 222 +----------------- tensorflow/lite/model_builder.cc | 204 ++++++++++++++++ tensorflow/lite/model_builder.h | 179 ++++++++++++++ 6 files changed, 496 insertions(+), 378 deletions(-) rename tensorflow/lite/{model.cc => interpreter_builder.cc} (78%) create mode 100644 tensorflow/lite/interpreter_builder.h create mode 100644 tensorflow/lite/model_builder.cc create mode 100644 tensorflow/lite/model_builder.h diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index f0e110cfaff..a2d8b40bbce 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -72,6 +72,8 @@ FRAMEWORK_LIB_HDRS = [ "graph_info.h", "interpreter.h", "model.h", + "model_builder.h", + "interpreter_builder.h", "mutable_op_resolver.h", "op_resolver.h", "optional_debug_tools.h", @@ -222,7 +224,8 @@ cc_library( "core/subgraph.cc", "graph_info.cc", "interpreter.cc", - "model.cc", + "interpreter_builder.cc", + "model_builder.cc", "mutable_op_resolver.cc", "optional_debug_tools.cc", "stderr_reporter.cc", diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/interpreter_builder.cc similarity index 78% rename from tensorflow/lite/model.cc rename to tensorflow/lite/interpreter_builder.cc index 25f196d272b..ef8f5a8773a 100644 --- a/tensorflow/lite/model.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/interpreter_builder.h" #include #include @@ -37,6 +37,7 @@ limitations under the License. namespace tflite { namespace { + // Ensure that ErrorReporter is non-null. 
ErrorReporter* ValidateErrorReporter(ErrorReporter* e) { return e ? e : DefaultErrorReporter(); @@ -91,6 +92,7 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src, } return kTfLiteError; } + } // namespace const char* kEmptyTensorName = ""; @@ -114,162 +116,6 @@ __attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr; #endif -#ifndef TFLITE_MCU -// Loads a model from `filename`. If `mmap_file` is true then use mmap, -// otherwise make a copy of the model in a buffer. -std::unique_ptr GetAllocationFromFile(const char* filename, - bool mmap_file, - ErrorReporter* error_reporter, - bool use_nnapi) { - std::unique_ptr allocation; - if (mmap_file && MMAPAllocation::IsSupported()) { - allocation.reset(new MMAPAllocation(filename, error_reporter)); - } else { - allocation.reset(new FileCopyAllocation(filename, error_reporter)); - } - return allocation; -} - -std::unique_ptr FlatBufferModel::BuildFromFile( - const char* filename, ErrorReporter* error_reporter) { - error_reporter = ValidateErrorReporter(error_reporter); - - std::unique_ptr model; - auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, - error_reporter, /*use_nnapi=*/true); - model.reset(new FlatBufferModel(std::move(allocation), error_reporter)); - if (!model->initialized()) model.reset(); - return model; -} - -std::unique_ptr FlatBufferModel::VerifyAndBuildFromFile( - const char* filename, TfLiteVerifier* extra_verifier, - ErrorReporter* error_reporter) { - error_reporter = ValidateErrorReporter(error_reporter); - - std::unique_ptr model; - auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, - error_reporter, /*use_nnapi=*/true); - - flatbuffers::Verifier base_verifier( - reinterpret_cast(allocation->base()), - allocation->bytes()); - if (!VerifyModelBuffer(base_verifier)) { - TF_LITE_REPORT_ERROR(error_reporter, - "The model is not a valid Flatbuffer file"); - return nullptr; - } - - if (extra_verifier && - !extra_verifier->Verify(static_cast(allocation->base()), - allocation->bytes(), error_reporter)) { - return model; - } - model.reset(new FlatBufferModel(std::move(allocation), error_reporter)); - if (!model->initialized()) model.reset(); - return model; -} -#endif - -std::unique_ptr FlatBufferModel::BuildFromBuffer( - const char* caller_owned_buffer, size_t buffer_size, - ErrorReporter* error_reporter) { - error_reporter = ValidateErrorReporter(error_reporter); - - std::unique_ptr model; - std::unique_ptr allocation( - new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter)); - model.reset(new FlatBufferModel(std::move(allocation), error_reporter)); - if (!model->initialized()) model.reset(); - return model; -} - -std::unique_ptr FlatBufferModel::VerifyAndBuildFromBuffer( - const char* caller_owned_buffer, size_t buffer_size, - TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) { - error_reporter = ValidateErrorReporter(error_reporter); - - flatbuffers::Verifier base_verifier( - reinterpret_cast(caller_owned_buffer), buffer_size); - if (!VerifyModelBuffer(base_verifier)) { - TF_LITE_REPORT_ERROR(error_reporter, - "The model is not a valid Flatbuffer buffer"); - return nullptr; - } - - if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer, - buffer_size, error_reporter)) { - return nullptr; - } - - return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter); -} - -std::unique_ptr FlatBufferModel::BuildFromModel( - const tflite::Model* 
caller_owned_model_spec, - ErrorReporter* error_reporter) { - error_reporter = ValidateErrorReporter(error_reporter); - - std::unique_ptr model; - model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter)); - if (!model->initialized()) model.reset(); - return model; -} - -string FlatBufferModel::GetMinimumRuntime() const { - if (!model_ || !model_->metadata()) return ""; - - for (int i = 0; i < model_->metadata()->size(); ++i) { - auto metadata = model_->metadata()->Get(i); - if (metadata->name()->str() == "min_runtime_version") { - auto buf = metadata->buffer(); - auto* buffer = (*model_->buffers())[buf]; - auto* array = buffer->data(); - // Get the real length of the runtime string, since there might be - // trailing - // '\0's in the buffer. - for (int len = 0; len < array->size(); ++len) { - if (array->data()[len] == '\0') { - return string(reinterpret_cast(array->data()), len); - } - } - // If there is no '\0' in the buffer, this indicates that the flatbuffer - // is malformed. - TF_LITE_REPORT_ERROR( - error_reporter_, - "Min_runtime_version in model metadata is malformed"); - break; - } - } - return ""; -} - -bool FlatBufferModel::CheckModelIdentifier() const { - if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { - const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); - error_reporter_->Report( - "Model provided has model identifier '%c%c%c%c', should be '%s'\n", - ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); - return false; - } - return true; -} - -FlatBufferModel::FlatBufferModel(const Model* model, - ErrorReporter* error_reporter) - : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {} - -FlatBufferModel::FlatBufferModel(std::unique_ptr allocation, - ErrorReporter* error_reporter) - : error_reporter_(ValidateErrorReporter(error_reporter)), - allocation_(std::move(allocation)) { - if (!allocation_->valid() || !CheckModelIdentifier()) return; - - model_ = ::tflite::GetModel(allocation_->base()); -} - -FlatBufferModel::~FlatBufferModel() {} - namespace impl { InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, diff --git a/tensorflow/lite/interpreter_builder.h b/tensorflow/lite/interpreter_builder.h new file mode 100644 index 00000000000..1d150d6f1d4 --- /dev/null +++ b/tensorflow/lite/interpreter_builder.h @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// Deserialization infrastructure for tflite. Provides functionality +/// to go from a serialized tflite model in flatbuffer format to an +/// interpreter. 
+/// +#ifndef TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ +#define TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ + +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { + +namespace impl { + +/// Build an interpreter capable of interpreting `model`. +/// +/// model: A model whose lifetime must be at least as long as any +/// interpreter(s) created by the builder. In principle multiple interpreters +/// can be made from a single model. +/// op_resolver: An instance that implements the OpResolver interface, which +/// maps +/// custom op names and builtin op codes to op registrations. The lifetime +/// of the provided `op_resolver` object must be at least as long as the +/// InterpreterBuilder; unlike `model` and `error_reporter`, the `op_resolver` +/// does not need to exist for the duration of any created Interpreter +/// objects. +/// error_reporter: a functor that is called to report errors that handles +/// printf var arg semantics. The lifetime of the `error_reporter` object must +/// be greater than or equal to the Interpreter created by operator(). +/// +/// Returns a kTfLiteOk when successful and sets interpreter to a valid +/// Interpreter. Note: The user must ensure the model lifetime (and error +/// reporter, if provided) is at least as long as interpreter's lifetime. +class InterpreterBuilder { + public: + InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver); + /// Builds an interpreter given only the raw flatbuffer Model object (instead + /// of a FlatBufferModel). Mostly used for testing. + /// If `error_reporter` is null, then DefaultErrorReporter() is used. 
+ InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter = DefaultErrorReporter()); + ~InterpreterBuilder(); + InterpreterBuilder(const InterpreterBuilder&) = delete; + InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; + TfLiteStatus operator()(std::unique_ptr* interpreter); + TfLiteStatus operator()(std::unique_ptr* interpreter, + int num_threads); + + private: + TfLiteStatus BuildLocalIndexToRegistrationMapping(); + TfLiteStatus ParseNodes( + const flatbuffers::Vector>* operators, + Subgraph* subgraph); + TfLiteStatus ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Subgraph* subgraph); + TfLiteStatus ApplyDelegates(Interpreter* interpreter); + TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization, + TfLiteQuantization* quantization, + const std::vector& dims); + TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity, + TfLiteSparsity** sparsity); + + const ::tflite::Model* model_; + const OpResolver& op_resolver_; + ErrorReporter* error_reporter_; + + std::vector flatbuffer_op_index_to_registration_; + std::vector unresolved_custom_ops_; + std::vector flatbuffer_op_index_to_registration_types_; + const Allocation* allocation_ = nullptr; + + bool has_flex_op_ = false; +}; + +} // namespace impl + +} // namespace tflite + +#endif // TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h index fd196c049e9..1db7828f736 100644 --- a/tensorflow/lite/model.h +++ b/tensorflow/lite/model.h @@ -19,229 +19,11 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MODEL_H_ #define TENSORFLOW_LITE_MODEL_H_ -#include - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/mutable_op_resolver.h" -#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/model_builder.h" namespace tflite { -/// Abstract interface that verifies whether a given model is legit. -/// It facilitates the use-case to verify and build a model without loading it -/// twice. -class TfLiteVerifier { - public: - /// Returns true if the model is legit. - virtual bool Verify(const char* data, int length, - ErrorReporter* reporter) = 0; - virtual ~TfLiteVerifier() {} -}; - -/// An RAII object that represents a read-only tflite model, copied from disk, -/// or mmapped. This uses flatbuffers as the serialization format. -/// -/// NOTE: The current API requires that a FlatBufferModel instance be kept alive -/// by the client as long as it is in use by any dependent Interpreter -/// instances. -///
-/// using namespace tflite;
-/// StderrReporter error_reporter;
-/// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
-///                                             &error_reporter);
-/// MyOpResolver resolver;  // You need to subclass OpResolver to provide
-///                         // implementations.
-/// InterpreterBuilder builder(*model, resolver);
-/// std::unique_ptr<Interpreter> interpreter;
-/// if(builder(&interpreter) == kTfLiteOk) {
-///   .. run model inference with interpreter
-/// }
-/// 
-/// -/// OpResolver must be defined to provide your kernel implementations to the -/// interpreter. This is environment specific and may consist of just the -/// builtin ops, or some custom operators you defined to extend tflite. -class FlatBufferModel { - public: - /// Builds a model based on a file. - /// Caller retains ownership of `error_reporter` and must ensure its lifetime - /// is longer than the FlatBufferModel instance. - /// Returns a nullptr in case of failure. - static std::unique_ptr BuildFromFile( - const char* filename, - ErrorReporter* error_reporter = DefaultErrorReporter()); - - /// Verifies whether the content of the file is legit, then builds a model - /// based on the file. - /// The extra_verifier argument is an additional optional verifier for the - /// file contents. By default, we always check with tflite::VerifyModelBuffer. - /// If extra_verifier is supplied, the file contents is also checked against - /// the extra_verifier after the check against tflite::VerifyModelBuilder. - /// Caller retains ownership of `error_reporter` and must ensure its lifetime - /// is longer than the FlatBufferModel instance. - /// Returns a nullptr in case of failure. - static std::unique_ptr VerifyAndBuildFromFile( - const char* filename, TfLiteVerifier* extra_verifier = nullptr, - ErrorReporter* error_reporter = DefaultErrorReporter()); - - /// Builds a model based on a pre-loaded flatbuffer. - /// Caller retains ownership of the buffer and should keep it alive until - /// the returned object is destroyed. Caller also retains ownership of - /// `error_reporter` and must ensure its lifetime is longer than the - /// FlatBufferModel instance. - /// Returns a nullptr in case of failure. - /// NOTE: this does NOT validate the buffer so it should NOT be called on - /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case - static std::unique_ptr BuildFromBuffer( - const char* caller_owned_buffer, size_t buffer_size, - ErrorReporter* error_reporter = DefaultErrorReporter()); - - /// Verifies whether the content of the buffer is legit, then builds a model - /// based on the pre-loaded flatbuffer. - /// The extra_verifier argument is an additional optional verifier for the - /// buffer. By default, we always check with tflite::VerifyModelBuffer. If - /// extra_verifier is supplied, the buffer is checked against the - /// extra_verifier after the check against tflite::VerifyModelBuilder. The - /// caller retains ownership of the buffer and should keep it alive until the - /// returned object is destroyed. Caller retains ownership of `error_reporter` - /// and must ensure its lifetime is longer than the FlatBufferModel instance. - /// Returns a nullptr in case of failure. - static std::unique_ptr VerifyAndBuildFromBuffer( - const char* caller_owned_buffer, size_t buffer_size, - TfLiteVerifier* extra_verifier = nullptr, - ErrorReporter* error_reporter = DefaultErrorReporter()); - - /// Builds a model directly from a flatbuffer pointer - /// Caller retains ownership of the buffer and should keep it alive until the - /// returned object is destroyed. Caller retains ownership of `error_reporter` - /// and must ensure its lifetime is longer than the FlatBufferModel instance. - /// Returns a nullptr in case of failure. - static std::unique_ptr BuildFromModel( - const tflite::Model* caller_owned_model_spec, - ErrorReporter* error_reporter = DefaultErrorReporter()); - - // Releases memory or unmaps mmaped memory. 
- ~FlatBufferModel(); - - // Copying or assignment is disallowed to simplify ownership semantics. - FlatBufferModel(const FlatBufferModel&) = delete; - FlatBufferModel& operator=(const FlatBufferModel&) = delete; - - bool initialized() const { return model_ != nullptr; } - const tflite::Model* operator->() const { return model_; } - const tflite::Model* GetModel() const { return model_; } - ErrorReporter* error_reporter() const { return error_reporter_; } - const Allocation* allocation() const { return allocation_.get(); } - - // Returns the minimum runtime version from the flatbuffer. This runtime - // version encodes the minimum required interpreter version to run the - // flatbuffer model. If the minimum version can't be determined, an empty - // string will be returned. - // Note that the returned minimum version is a lower-bound but not a strict - // lower-bound; ops in the graph may not have an associated runtime version, - // in which case the actual required runtime might be greater than the - // reported minimum. - string GetMinimumRuntime() const; - - /// Returns true if the model identifier is correct (otherwise false and - /// reports an error). - bool CheckModelIdentifier() const; - - private: - /// Loads a model from a given allocation. FlatBufferModel will take over the - /// ownership of `allocation`, and delete it in destructor. The ownership of - /// `error_reporter`remains with the caller and must have lifetime at least - /// as much as FlatBufferModel. This is to allow multiple models to use the - /// same ErrorReporter instance. - FlatBufferModel(std::unique_ptr allocation, - ErrorReporter* error_reporter = DefaultErrorReporter()); - - /// Loads a model from Model flatbuffer. The `model` has to remain alive and - /// unchanged until the end of this flatbuffermodel's lifetime. - FlatBufferModel(const Model* model, ErrorReporter* error_reporter); - - /// Flatbuffer traverser pointer. (Model* is a pointer that is within the - /// allocated memory of the data allocated by allocation's internals. - const tflite::Model* model_ = nullptr; - /// The error reporter to use for model errors and subsequent errors when - /// the interpreter is created - ErrorReporter* error_reporter_; - /// The allocator used for holding memory of the model. Note that this will - /// be null if the client provides a tflite::Model directly. - std::unique_ptr allocation_; -}; - -namespace impl { - -/// Build an interpreter capable of interpreting `model`. -/// -/// model: A model whose lifetime must be at least as long as any -/// interpreter(s) created by the builder. In principle multiple interpreters -/// can be made from a single model. -/// op_resolver: An instance that implements the OpResolver interface, which -/// maps -/// custom op names and builtin op codes to op registrations. The lifetime -/// of the provided `op_resolver` object must be at least as long as the -/// InterpreterBuilder; unlike `model` and `error_reporter`, the `op_resolver` -/// does not need to exist for the duration of any created Interpreter -/// objects. -/// error_reporter: a functor that is called to report errors that handles -/// printf var arg semantics. The lifetime of the `error_reporter` object must -/// be greater than or equal to the Interpreter created by operator(). -/// -/// Returns a kTfLiteOk when successful and sets interpreter to a valid -/// Interpreter. Note: The user must ensure the model lifetime (and error -/// reporter, if provided) is at least as long as interpreter's lifetime. 
-class InterpreterBuilder { - public: - InterpreterBuilder(const FlatBufferModel& model, - const OpResolver& op_resolver); - /// Builds an interpreter given only the raw flatbuffer Model object (instead - /// of a FlatBufferModel). Mostly used for testing. - /// If `error_reporter` is null, then DefaultErrorReporter() is used. - InterpreterBuilder(const ::tflite::Model* model, - const OpResolver& op_resolver, - ErrorReporter* error_reporter = DefaultErrorReporter()); - ~InterpreterBuilder(); - InterpreterBuilder(const InterpreterBuilder&) = delete; - InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; - TfLiteStatus operator()(std::unique_ptr* interpreter); - TfLiteStatus operator()(std::unique_ptr* interpreter, - int num_threads); - - private: - TfLiteStatus BuildLocalIndexToRegistrationMapping(); - TfLiteStatus ParseNodes( - const flatbuffers::Vector>* operators, - Subgraph* subgraph); - TfLiteStatus ParseTensors( - const flatbuffers::Vector>* buffers, - const flatbuffers::Vector>* tensors, - Subgraph* subgraph); - TfLiteStatus ApplyDelegates(Interpreter* interpreter); - TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization, - TfLiteQuantization* quantization, - const std::vector& dims); - TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity, - TfLiteSparsity** sparsity); - - const ::tflite::Model* model_; - const OpResolver& op_resolver_; - ErrorReporter* error_reporter_; - - std::vector flatbuffer_op_index_to_registration_; - std::vector unresolved_custom_ops_; - std::vector flatbuffer_op_index_to_registration_types_; - const Allocation* allocation_ = nullptr; - - bool has_flex_op_ = false; -}; - -} // namespace impl - using InterpreterBuilder = impl::InterpreterBuilder; } // namespace tflite diff --git a/tensorflow/lite/model_builder.cc b/tensorflow/lite/model_builder.cc new file mode 100644 index 00000000000..784c39f00c8 --- /dev/null +++ b/tensorflow/lite/model_builder.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/model_builder.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/lite/allocation.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/util.h" +#include "tensorflow/lite/version.h" + +#if defined(TFLITE_ENABLE_DEFAULT_PROFILER) +#include "tensorflow/lite/profiling/platform_profiler.h" +#endif + +namespace tflite { + +namespace { + +// Ensure that ErrorReporter is non-null. +ErrorReporter* ValidateErrorReporter(ErrorReporter* e) { + return e ? e : DefaultErrorReporter(); +} + +} // namespace + +#ifndef TFLITE_MCU +// Loads a model from `filename`. 
If `mmap_file` is true then use mmap, +// otherwise make a copy of the model in a buffer. +std::unique_ptr GetAllocationFromFile(const char* filename, + bool mmap_file, + ErrorReporter* error_reporter, + bool use_nnapi) { + std::unique_ptr allocation; + if (mmap_file && MMAPAllocation::IsSupported()) { + allocation.reset(new MMAPAllocation(filename, error_reporter)); + } else { + allocation.reset(new FileCopyAllocation(filename, error_reporter)); + } + return allocation; +} + +std::unique_ptr FlatBufferModel::BuildFromFile( + const char* filename, ErrorReporter* error_reporter) { + error_reporter = ValidateErrorReporter(error_reporter); + + std::unique_ptr model; + auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, + error_reporter, /*use_nnapi=*/true); + model.reset(new FlatBufferModel(std::move(allocation), error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::VerifyAndBuildFromFile( + const char* filename, TfLiteVerifier* extra_verifier, + ErrorReporter* error_reporter) { + error_reporter = ValidateErrorReporter(error_reporter); + + std::unique_ptr model; + auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, + error_reporter, /*use_nnapi=*/true); + + flatbuffers::Verifier base_verifier( + reinterpret_cast(allocation->base()), + allocation->bytes()); + if (!VerifyModelBuffer(base_verifier)) { + TF_LITE_REPORT_ERROR(error_reporter, + "The model is not a valid Flatbuffer file"); + return nullptr; + } + + if (extra_verifier && + !extra_verifier->Verify(static_cast(allocation->base()), + allocation->bytes(), error_reporter)) { + return model; + } + model.reset(new FlatBufferModel(std::move(allocation), error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} +#endif + +std::unique_ptr FlatBufferModel::BuildFromBuffer( + const char* caller_owned_buffer, size_t buffer_size, + ErrorReporter* error_reporter) { + error_reporter = ValidateErrorReporter(error_reporter); + + std::unique_ptr model; + std::unique_ptr allocation( + new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter)); + model.reset(new FlatBufferModel(std::move(allocation), error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::VerifyAndBuildFromBuffer( + const char* caller_owned_buffer, size_t buffer_size, + TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) { + error_reporter = ValidateErrorReporter(error_reporter); + + flatbuffers::Verifier base_verifier( + reinterpret_cast(caller_owned_buffer), buffer_size); + if (!VerifyModelBuffer(base_verifier)) { + TF_LITE_REPORT_ERROR(error_reporter, + "The model is not a valid Flatbuffer buffer"); + return nullptr; + } + + if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer, + buffer_size, error_reporter)) { + return nullptr; + } + + return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter); +} + +std::unique_ptr FlatBufferModel::BuildFromModel( + const tflite::Model* caller_owned_model_spec, + ErrorReporter* error_reporter) { + error_reporter = ValidateErrorReporter(error_reporter); + + std::unique_ptr model; + model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +string FlatBufferModel::GetMinimumRuntime() const { + if (!model_ || !model_->metadata()) return ""; + + for (int i = 0; i < model_->metadata()->size(); ++i) { + auto metadata = 
model_->metadata()->Get(i); + if (metadata->name()->str() == "min_runtime_version") { + auto buf = metadata->buffer(); + auto* buffer = (*model_->buffers())[buf]; + auto* array = buffer->data(); + // Get the real length of the runtime string, since there might be + // trailing + // '\0's in the buffer. + for (int len = 0; len < array->size(); ++len) { + if (array->data()[len] == '\0') { + return string(reinterpret_cast(array->data()), len); + } + } + // If there is no '\0' in the buffer, this indicates that the flatbuffer + // is malformed. + TF_LITE_REPORT_ERROR( + error_reporter_, + "Min_runtime_version in model metadata is malformed"); + break; + } + } + return ""; +} + +bool FlatBufferModel::CheckModelIdentifier() const { + if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { + const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); + error_reporter_->Report( + "Model provided has model identifier '%c%c%c%c', should be '%s'\n", + ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); + return false; + } + return true; +} + +FlatBufferModel::FlatBufferModel(const Model* model, + ErrorReporter* error_reporter) + : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {} + +FlatBufferModel::FlatBufferModel(std::unique_ptr allocation, + ErrorReporter* error_reporter) + : error_reporter_(ValidateErrorReporter(error_reporter)), + allocation_(std::move(allocation)) { + if (!allocation_->valid() || !CheckModelIdentifier()) return; + + model_ = ::tflite::GetModel(allocation_->base()); +} + +FlatBufferModel::~FlatBufferModel() {} + +} // namespace tflite diff --git a/tensorflow/lite/model_builder.h b/tensorflow/lite/model_builder.h new file mode 100644 index 00000000000..ac05223b6a8 --- /dev/null +++ b/tensorflow/lite/model_builder.h @@ -0,0 +1,179 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// Deserialization infrastructure for tflite. Provides functionality +/// to go from a serialized tflite model in flatbuffer format to an +/// interpreter. +/// +#ifndef TENSORFLOW_LITE_MODEL_BUILDER_H_ +#define TENSORFLOW_LITE_MODEL_BUILDER_H_ + +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { + +/// Abstract interface that verifies whether a given model is legit. +/// It facilitates the use-case to verify and build a model without loading it +/// twice. +class TfLiteVerifier { + public: + /// Returns true if the model is legit. + virtual bool Verify(const char* data, int length, + ErrorReporter* reporter) = 0; + virtual ~TfLiteVerifier() {} +}; + +/// An RAII object that represents a read-only tflite model, copied from disk, +/// or mmapped. 
This uses flatbuffers as the serialization format. +/// +/// NOTE: The current API requires that a FlatBufferModel instance be kept alive +/// by the client as long as it is in use by any dependent Interpreter +/// instances. +///

+/// using namespace tflite;
+/// StderrReporter error_reporter;
+/// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
+///                                             &error_reporter);
+/// MyOpResolver resolver;  // You need to subclass OpResolver to provide
+///                         // implementations.
+/// InterpreterBuilder builder(*model, resolver);
+/// std::unique_ptr<Interpreter> interpreter;
+/// if(builder(&interpreter) == kTfLiteOk) {
+///   .. run model inference with interpreter
+/// }
+/// 
+/// +/// OpResolver must be defined to provide your kernel implementations to the +/// interpreter. This is environment specific and may consist of just the +/// builtin ops, or some custom operators you defined to extend tflite. +class FlatBufferModel { + public: + /// Builds a model based on a file. + /// Caller retains ownership of `error_reporter` and must ensure its lifetime + /// is longer than the FlatBufferModel instance. + /// Returns a nullptr in case of failure. + static std::unique_ptr BuildFromFile( + const char* filename, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + /// Verifies whether the content of the file is legit, then builds a model + /// based on the file. + /// The extra_verifier argument is an additional optional verifier for the + /// file contents. By default, we always check with tflite::VerifyModelBuffer. + /// If extra_verifier is supplied, the file contents is also checked against + /// the extra_verifier after the check against tflite::VerifyModelBuilder. + /// Caller retains ownership of `error_reporter` and must ensure its lifetime + /// is longer than the FlatBufferModel instance. + /// Returns a nullptr in case of failure. + static std::unique_ptr VerifyAndBuildFromFile( + const char* filename, TfLiteVerifier* extra_verifier = nullptr, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + /// Builds a model based on a pre-loaded flatbuffer. + /// Caller retains ownership of the buffer and should keep it alive until + /// the returned object is destroyed. Caller also retains ownership of + /// `error_reporter` and must ensure its lifetime is longer than the + /// FlatBufferModel instance. + /// Returns a nullptr in case of failure. + /// NOTE: this does NOT validate the buffer so it should NOT be called on + /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case + static std::unique_ptr BuildFromBuffer( + const char* caller_owned_buffer, size_t buffer_size, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + /// Verifies whether the content of the buffer is legit, then builds a model + /// based on the pre-loaded flatbuffer. + /// The extra_verifier argument is an additional optional verifier for the + /// buffer. By default, we always check with tflite::VerifyModelBuffer. If + /// extra_verifier is supplied, the buffer is checked against the + /// extra_verifier after the check against tflite::VerifyModelBuilder. The + /// caller retains ownership of the buffer and should keep it alive until the + /// returned object is destroyed. Caller retains ownership of `error_reporter` + /// and must ensure its lifetime is longer than the FlatBufferModel instance. + /// Returns a nullptr in case of failure. + static std::unique_ptr VerifyAndBuildFromBuffer( + const char* caller_owned_buffer, size_t buffer_size, + TfLiteVerifier* extra_verifier = nullptr, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + /// Builds a model directly from a flatbuffer pointer + /// Caller retains ownership of the buffer and should keep it alive until the + /// returned object is destroyed. Caller retains ownership of `error_reporter` + /// and must ensure its lifetime is longer than the FlatBufferModel instance. + /// Returns a nullptr in case of failure. + static std::unique_ptr BuildFromModel( + const tflite::Model* caller_owned_model_spec, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Releases memory or unmaps mmaped memory. 
+ ~FlatBufferModel(); + + // Copying or assignment is disallowed to simplify ownership semantics. + FlatBufferModel(const FlatBufferModel&) = delete; + FlatBufferModel& operator=(const FlatBufferModel&) = delete; + + bool initialized() const { return model_ != nullptr; } + const tflite::Model* operator->() const { return model_; } + const tflite::Model* GetModel() const { return model_; } + ErrorReporter* error_reporter() const { return error_reporter_; } + const Allocation* allocation() const { return allocation_.get(); } + + // Returns the minimum runtime version from the flatbuffer. This runtime + // version encodes the minimum required interpreter version to run the + // flatbuffer model. If the minimum version can't be determined, an empty + // string will be returned. + // Note that the returned minimum version is a lower-bound but not a strict + // lower-bound; ops in the graph may not have an associated runtime version, + // in which case the actual required runtime might be greater than the + // reported minimum. + string GetMinimumRuntime() const; + + /// Returns true if the model identifier is correct (otherwise false and + /// reports an error). + bool CheckModelIdentifier() const; + + private: + /// Loads a model from a given allocation. FlatBufferModel will take over the + /// ownership of `allocation`, and delete it in destructor. The ownership of + /// `error_reporter`remains with the caller and must have lifetime at least + /// as much as FlatBufferModel. This is to allow multiple models to use the + /// same ErrorReporter instance. + FlatBufferModel(std::unique_ptr allocation, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + /// Loads a model from Model flatbuffer. The `model` has to remain alive and + /// unchanged until the end of this flatbuffermodel's lifetime. + FlatBufferModel(const Model* model, ErrorReporter* error_reporter); + + /// Flatbuffer traverser pointer. (Model* is a pointer that is within the + /// allocated memory of the data allocated by allocation's internals. + const tflite::Model* model_ = nullptr; + /// The error reporter to use for model errors and subsequent errors when + /// the interpreter is created + ErrorReporter* error_reporter_; + /// The allocator used for holding memory of the model. Note that this will + /// be null if the client provides a tflite::Model directly. + std::unique_ptr allocation_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MODEL_BUILDER_H_ From 5522bfa37f967a3e7414a453227ea5d8d119a4d3 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 23 Mar 2020 19:11:07 -0700 Subject: [PATCH 481/492] Updated reduce_sum description to include comments in code sample PiperOrigin-RevId: 302570920 Change-Id: I21f2967ee52cd22c72b727e5a01af0ef50da7dc2 --- tensorflow/python/ops/math_ops.py | 39 ++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 9395016bd20..bf725b34e0b 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1712,20 +1712,41 @@ def reduce_sum(input_tensor, axis=None, keepdims=False, name=None): For example: - ```python - x = tf.constant([[1, 1, 1], [1, 1, 1]]) - tf.reduce_sum(x) # 6 - tf.reduce_sum(x, 0) # [2, 2, 2] - tf.reduce_sum(x, 1) # [3, 3] - tf.reduce_sum(x, 1, keepdims=True) # [[3], [3]] - tf.reduce_sum(x, [0, 1]) # 6 - ``` + >>> # x has a shape of (2, 3) (two rows and three columns): + >>> x = tf.constant([[1, 1, 1], [1, 1, 1]]) + >>> x.numpy() + array([[1, 1, 1], + [1, 1, 1]], dtype=int32) + >>> # sum all the elements + >>> # 1 + 1 + 1 + 1 + 1+ 1 = 6 + >>> tf.reduce_sum(x).numpy() + 6 + >>> # reduce along the first dimension + >>> # the result is [1, 1, 1] + [1, 1, 1] = [2, 2, 2] + >>> tf.reduce_sum(x, 0).numpy() + array([2, 2, 2], dtype=int32) + >>> # reduce along the second dimension + >>> # the result is [1, 1] + [1, 1] + [1, 1] = [3, 3] + >>> tf.reduce_sum(x, 1).numpy() + array([3, 3], dtype=int32) + >>> # keep the original dimensions + >>> tf.reduce_sum(x, 1, keepdims=True).numpy() + array([[3], + [3]], dtype=int32) + >>> # reduce along both dimensions + >>> # the result is 1 + 1 + 1 + 1 + 1 + 1 = 6 + >>> # or, equivalently, reduce along rows, then reduce the resultant array + >>> # [1, 1, 1] + [1, 1, 1] = [2, 2, 2] + >>> # 2 + 2 + 2 = 6 + >>> tf.reduce_sum(x, [0, 1]).numpy() + 6 + Args: input_tensor: The tensor to reduce. Should have numeric type. axis: The dimensions to reduce. If `None` (the default), reduces all dimensions. Must be in the range `[-rank(input_tensor), - rank(input_tensor))`. + rank(input_tensor)]`. keepdims: If true, retains reduced dimensions with length 1. name: A name for the operation (optional). From b009488e56ab7a58b0690f6779353c11e1dfaf0a Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Mon, 23 Mar 2020 19:13:13 -0700 Subject: [PATCH 482/492] Update comment to match w/ NHWC. PiperOrigin-RevId: 302571156 Change-Id: Idfd8ea934d4324ccd7aae5911f16e94229522105 --- tensorflow/lite/schema/schema_v0.fbs | 2 +- tensorflow/lite/schema/schema_v1.fbs | 2 +- tensorflow/lite/schema/schema_v2.fbs | 2 +- tensorflow/lite/schema/schema_v3.fbs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/schema/schema_v0.fbs b/tensorflow/lite/schema/schema_v0.fbs index a080bbdaab4..e543df613cb 100644 --- a/tensorflow/lite/schema/schema_v0.fbs +++ b/tensorflow/lite/schema/schema_v0.fbs @@ -35,7 +35,7 @@ table QuantizationParameters { table Tensor { // The tensor shape. The meaning of each entry is operator-specific but - // builtin ops use: [batch size, number of channels, height, width] (That's + // builtin ops use: [batch size, height, width, number of channels] (That's // Tensorflow's NHWC). 
shape:[int]; type:TensorType; diff --git a/tensorflow/lite/schema/schema_v1.fbs b/tensorflow/lite/schema/schema_v1.fbs index 779492e8ee3..d49ea8e7f05 100644 --- a/tensorflow/lite/schema/schema_v1.fbs +++ b/tensorflow/lite/schema/schema_v1.fbs @@ -40,7 +40,7 @@ table QuantizationParameters { table Tensor { // The tensor shape. The meaning of each entry is operator-specific but - // builtin ops use: [batch size, number of channels, height, width] (That's + // builtin ops use: [batch size, height, width, number of channels] (That's // Tensorflow's NHWC). shape:[int]; type:TensorType; diff --git a/tensorflow/lite/schema/schema_v2.fbs b/tensorflow/lite/schema/schema_v2.fbs index 94963a4f9c2..05464a7ea60 100644 --- a/tensorflow/lite/schema/schema_v2.fbs +++ b/tensorflow/lite/schema/schema_v2.fbs @@ -41,7 +41,7 @@ table QuantizationParameters { table Tensor { // The tensor shape. The meaning of each entry is operator-specific but - // builtin ops use: [batch size, number of channels, height, width] (That's + // builtin ops use: [batch size, height, width, number of channels] (That's // Tensorflow's NHWC). shape:[int]; type:TensorType; diff --git a/tensorflow/lite/schema/schema_v3.fbs b/tensorflow/lite/schema/schema_v3.fbs index 3b3c763ffbc..86d561a408a 100644 --- a/tensorflow/lite/schema/schema_v3.fbs +++ b/tensorflow/lite/schema/schema_v3.fbs @@ -47,7 +47,7 @@ table QuantizationParameters { table Tensor { // The tensor shape. The meaning of each entry is operator-specific but - // builtin ops use: [batch size, number of channels, height, width] (That's + // builtin ops use: [batch size, height, width, number of channels] (That's // Tensorflow's NHWC). shape:[int]; type:TensorType; From fd5b26a3e9120773c93fd6303404989615df2a7d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Mar 2020 19:18:04 -0700 Subject: [PATCH 483/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302571728 Change-Id: I03e48dfc97e368f82061eff529a38b6c8c021159 --- tensorflow/go/op/wrappers.go | 45 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 75d86f71b78..56a3aa205b9 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. 
-// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -14074,6 +14074,7 @@ func TridiagonalSolvePartialPivoting(value bool) TridiagonalSolveAttr { // On CPU, solution is computed via Gaussian elimination with or without partial // pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE // library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv +// Partial pivoting is not yet supported by XLA backends. // // Arguments: // diagonals: Tensor of shape `[..., 3, M]` whose innermost 2 dimensions represent the @@ -19036,7 +19037,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20107,7 +20108,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21279,7 +21280,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21987,7 +21988,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22183,7 +22184,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. 
// // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22252,7 +22253,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22367,7 +22368,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22426,7 +22427,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22600,7 +22601,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22977,7 +22978,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25320,7 +25321,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25383,7 +25384,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. 
Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25626,7 +25627,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26110,7 +26111,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40308,7 +40309,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45834,7 +45835,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46686,7 +46687,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46749,7 +46750,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 0de8de465c121a8497270c4176960ada54675c5b Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Mon, 23 Mar 2020 19:36:09 -0700 Subject: [PATCH 484/492] Move hardware to its own header file so it can be included by other files separately. PiperOrigin-RevId: 302573879 Change-Id: Ibcec7485fabe3e3df08cc1b486afd808b7ce2cf0 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../lite/experimental/estimators/estimator.h | 7 +----- .../lite/experimental/estimators/hardware.h | 25 +++++++++++++++++++ 3 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index c4314a86d92..3e86b1268b8 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -215,6 +215,7 @@ cc_library( "utils/attribute_utils.cc", ], hdrs = [ + "experimental/estimators/hardware.h", "ir/tfl_ops.h", "transforms/passes.h", "utils/attribute_utils.h", diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h b/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h index 7d58fc41ab3..c4a509945fa 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h @@ -17,14 +17,9 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" -namespace hardware { -// Empty classes that represents hardware types. -class CPU {}; -class GPU {}; -} // namespace hardware - template class TFLiteCostEstimator { public: diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h b/tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h new file mode 100644 index 00000000000..c5fffa6f3d7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_ + +namespace hardware { +// Empty classes that represents hardware types. +class CPU {}; +class GPU {}; +} // namespace hardware + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_ From b7f82785aded23264044cac4ade88f773ed57ec7 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 23 Mar 2020 19:49:20 -0700 Subject: [PATCH 485/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302575337 Change-Id: Ieebd6f603a02906ca1b1a5f41568ccf1aabae29d --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 56a3aa205b9..d8f12fab3d2 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19037,7 +19037,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20108,7 +20108,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21280,7 +21280,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21988,7 +21988,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22184,7 +22184,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22253,7 +22253,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22368,7 +22368,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22427,7 +22427,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22601,7 +22601,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22978,7 +22978,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25321,7 +25321,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25384,7 +25384,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25627,7 +25627,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26111,7 +26111,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40309,7 +40309,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45835,7 +45835,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. 
Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46687,7 +46687,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46750,7 +46750,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From f1efe7d5684bfe368993d99a281d23b682190b18 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Mon, 23 Mar 2020 20:11:17 -0700 Subject: [PATCH 486/492] add gpu target to avg_pool PiperOrigin-RevId: 302578305 Change-Id: I22c469b634feb30c4b094cb8704031c14e827ea1 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../estimators/gpu_estimator.h.inc | 31 +++++++++++++++++++ tensorflow/compiler/mlir/lite/ir/tfl_ops.h | 2 ++ tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 5 ++- 4 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 3e86b1268b8..61919204f9a 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -206,6 +206,7 @@ cc_library( name = "tensorflow_lite", srcs = [ "experimental/estimators/estimator.h", + "experimental/estimators/gpu_estimator.h.inc", "ir/tfl_ops.cc", "ir/tfl_ops.cc.inc", "ir/tfl_ops.h.inc", diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc new file mode 100644 index 00000000000..c75e84f09f2 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_ + +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_ diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 3755bf490b9..baef9a41e3a 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -53,6 +53,8 @@ class TensorFlowLiteDialect : public Dialect { #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" +// Include all specializes estimators below this line +#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc" } // end namespace TFL } // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index c90fdfbfe1c..4f560913593 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -480,7 +480,10 @@ Note this is a custom op that is not supported in the standard runtime. } def TFL_AveragePool2DOp: - TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> { + TFL_Op<"average_pool_2d", + [NoSideEffect, + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Average_pool_2d operator"; let description = [{ From 32a62ac58bc248195b01cf0ca293a426a8b1bced Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Mon, 23 Mar 2020 21:34:39 -0700 Subject: [PATCH 487/492] Fix Kokoro iOS breakage. PiperOrigin-RevId: 302588422 Change-Id: Ida7270db05a141247930c9c92fe22f9cb9c39505 --- tensorflow/lite/delegates/gpu/metal/api.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index 6abcbcaed4f..9232b527af3 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -51,6 +51,7 @@ namespace tflite { namespace gpu { namespace metal { namespace { + bool IsWidthBroadcastedForSecondInput( const std::vector>*>& inputs) { return inputs.size() == 2 && @@ -232,7 +233,7 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, broadcast); } else { - return absl::UnimplementedError( + return UnimplementedError( "No support of multiply with more than 2 inputs"); } } From 7247fefebbbae4ba02a4aa01d18c503719eb72c8 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Mon, 23 Mar 2020 21:35:06 -0700 Subject: [PATCH 488/492] Utilize weak symbol to apply XNNPACK delegate by default in TFLite. 
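[Editor's note] The change below relies on the weak-symbol override pattern: the core library ships a weak default that returns no delegate, and an optional object file that defines the same symbol strongly is picked up automatically at link time. A minimal sketch of that pattern follows; it uses the GCC/Clang weak attribute and purely illustrative names, and is not taken from the TensorFlow Lite sources.

```c++
// core.cc -- weak default, used only when no other object defines the symbol.
#include <cstdio>
#include <memory>

struct FakeDelegate { const char* name; };
using FakeDelegatePtr = std::unique_ptr<FakeDelegate>;

__attribute__((weak)) FakeDelegatePtr AcquireFakeDelegate() {
  return nullptr;  // no optional dependency linked in
}

int main() {
  if (auto delegate = AcquireFakeDelegate()) {
    std::printf("delegate injected: %s\n", delegate->name);
  } else {
    std::printf("no delegate available, running the default path\n");
  }
  return 0;
}

// plugin.cc -- linking this translation unit in overrides the weak default,
// so the delegate is "injected" without touching any call sites:
//
//   FakeDelegatePtr AcquireFakeDelegate() {
//     return FakeDelegatePtr(new FakeDelegate{"injected"});
//   }
```

MSVC has no equivalent of the weak attribute, which is why the patch below guards the mechanism behind TFLITE_HAS_ATTRIBUTE_WEAK and tags the new test as not buildable on Windows.
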
PiperOrigin-RevId: 302588479 Change-Id: I01fa8964ae4f487224f8da2b49f09360aea0dcf8 --- tensorflow/lite/BUILD | 41 ++++++++++++++++++ tensorflow/lite/core/macros.h | 19 +++++++++ tensorflow/lite/interpreter_builder.cc | 53 +++++++++++++++-------- tensorflow/lite/interpreter_builder.h | 3 +- tensorflow/lite/model_xnnpack_test.cc | 59 ++++++++++++++++++++++++++ tensorflow/lite/tflite_with_xnnpack.cc | 30 +++++++++++++ 6 files changed, 186 insertions(+), 19 deletions(-) create mode 100644 tensorflow/lite/model_xnnpack_test.cc create mode 100644 tensorflow/lite/tflite_with_xnnpack.cc diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index a2d8b40bbce..f3d0494f15a 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -308,6 +308,16 @@ cc_library( ], ) +cc_library( + name = "tflite_with_xnnpack", + srcs = ["tflite_with_xnnpack.cc"], + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + deps = [ + "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + ], +) + cc_test( name = "string_util_test", size = "small", @@ -435,6 +445,37 @@ tf_cc_test( ], ) +# Test model framework with the XNNPACK delegate. +cc_test( + name = "model_xnnpack_test", + size = "small", + srcs = [ + "model_xnnpack_test.cc", + ], + data = [ + "testdata/multi_add.bin", + ], + tags = [ + "no_windows", # No weak symbols with MSVC. + "tflite_not_portable_android", + "tflite_not_portable_ios", + ], + deps = [ + ":framework", + ":tflite_with_xnnpack", + ":util", + ":version", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/profiling:platform_profiler", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test OpResolver. cc_test( name = "mutable_op_resolver_test", diff --git a/tensorflow/lite/core/macros.h b/tensorflow/lite/core/macros.h index 5ff00e4814a..034ad8daac5 100644 --- a/tensorflow/lite/core/macros.h +++ b/tensorflow/lite/core/macros.h @@ -32,4 +32,23 @@ limitations under the License. #define TFLITE_EXPECT_TRUE(cond) (cond) #endif +// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but +// we avoid the absl dependency for binary size reasons. +#ifdef __has_attribute +#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x) +#else +#define TFLITE_HAS_ATTRIBUTE(x) 0 +#endif + +#if (TFLITE_HAS_ATTRIBUTE(weak) || \ + (defined(__GNUC__) && !defined(__clang__))) && \ + !(defined(__llvm__) && defined(_WIN32)) && !defined(__MINGW32__) +#undef TFLITE_ATTRIBUTE_WEAK +#define TFLITE_ATTRIBUTE_WEAK __attribute__((weak)) +#define TFLITE_HAS_ATTRIBUTE_WEAK 1 +#else +#define TFLITE_ATTRIBUTE_WEAK +#define TFLITE_HAS_ATTRIBUTE_WEAK 0 +#endif + #endif // TENSORFLOW_LITE_CORE_MACROS_H_ diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index ef8f5a8773a..5d7807cd291 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -97,23 +97,25 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src, const char* kEmptyTensorName = ""; -// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but -// we avoid the absl dependency for binary size reasons. 
-#ifdef __has_attribute -#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x) -#else -#define TFLITE_HAS_ATTRIBUTE(x) 0 -#endif +#if TFLITE_HAS_ATTRIBUTE_WEAK +// Using weak symbols to create a delegate allows automatic injection of the +// delegate simply by adding it as a dependency. -#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__)) -// Using weak symbols for the flex delegate allows automatic injection of the -// delegate simply by adding it as a dependency. See also the strong override in +// For flex delegate, see also the strong override in // lite/delegates/flex/delegate.cc. -__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { +TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { + return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); +} + +// For XNNPACK delegate, see also the strong override in +// lite/enable_xnnpack_delegate.cc. +TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireXNNPACKDelegate( + int num_threads) { return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); } #else Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr; +Interpreter::TfLiteDelegatePtr (*AcquireXNNPACKDelegate)(int) = nullptr; #endif namespace impl { @@ -415,6 +417,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors( return kEmptyTensorName; }; + num_fp32_tensors_ = 0; for (int i = 0; i < tensors->size(); ++i) { const auto* tensor = tensors->Get(i); std::vector dims = FlatBufferIntArrayToVector(tensor->shape()); @@ -425,6 +428,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; continue; } + if (type == kTfLiteFloat32) { + ++num_fp32_tensors_; + } auto get_readonly_data = [&](const char** buffer_data, size_t* buffer_size) { // TODO(aselle): Check what happens if we have an unspecified size @@ -507,12 +513,23 @@ TfLiteStatus InterpreterBuilder::ParseTensors( return status; } -TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) { - // Apply Flex delegate if applicable. - if (!has_flex_op_ || AcquireFlexDelegate == nullptr) { - return kTfLiteOk; - } else if (auto flex_delegate = AcquireFlexDelegate()) { - return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate)); +TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter, + int num_threads) { + // First, apply XNNPACK delegate if applicable. + if (AcquireXNNPACKDelegate && num_fp32_tensors_ > 0) { + if (auto xnnpack_delegate = AcquireXNNPACKDelegate(num_threads)) { + // The execution will fall back to default implementation if the XNNPACK + // delegate fails to be applied. Therefore, we ignore the return status + // here and let it fall through the rest of the code. + interpreter->ModifyGraphWithDelegate(std::move(xnnpack_delegate)); + } + } + + // Secondly, apply Flex delegate if applicable. 
+ if (has_flex_op_ && AcquireFlexDelegate) { + if (auto flex_delegate = AcquireFlexDelegate()) { + return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate)); + } } return kTfLiteOk; @@ -625,7 +642,7 @@ TfLiteStatus InterpreterBuilder::operator()( modified_subgraph->SetVariables(std::move(variables)); } - if (ApplyDelegates(interpreter->get()) != kTfLiteOk) + if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk) return cleanup_and_error(); return kTfLiteOk; diff --git a/tensorflow/lite/interpreter_builder.h b/tensorflow/lite/interpreter_builder.h index 1d150d6f1d4..1b8ae5a8e68 100644 --- a/tensorflow/lite/interpreter_builder.h +++ b/tensorflow/lite/interpreter_builder.h @@ -78,7 +78,7 @@ class InterpreterBuilder { const flatbuffers::Vector>* buffers, const flatbuffers::Vector>* tensors, Subgraph* subgraph); - TfLiteStatus ApplyDelegates(Interpreter* interpreter); + TfLiteStatus ApplyDelegates(Interpreter* interpreter, int num_threads); TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization, TfLiteQuantization* quantization, const std::vector& dims); @@ -95,6 +95,7 @@ class InterpreterBuilder { const Allocation* allocation_ = nullptr; bool has_flex_op_ = false; + int num_fp32_tensors_ = 0; }; } // namespace impl diff --git a/tensorflow/lite/model_xnnpack_test.cc b/tensorflow/lite/model_xnnpack_test.cc new file mode 100644 index 00000000000..9c06147f602 --- /dev/null +++ b/tensorflow/lite/model_xnnpack_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/model.h" + +#include + +#include +#include "tensorflow/lite/core/macros.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/util.h" + +namespace tflite { + +TEST(FloatModel, WithXnnpackDelegate) { + // Note: this graph will be fully delegated by the XNNPACK delegate. 
+ auto model = FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/multi_add.bin"); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + ASSERT_EQ(InterpreterBuilder(*model, + ops::builtin::BuiltinOpResolver{})(&interpreter), + kTfLiteOk); + ASSERT_TRUE(interpreter); + + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); + +#if TFLITE_HAS_ATTRIBUTE_WEAK + // As the graph is fully delegated by XNNPACK delegate, we will expect the + // following: + EXPECT_EQ(1, interpreter->execution_plan().size()); + int first_node_id = interpreter->execution_plan()[0]; + const auto& first_node_reg = + interpreter->node_and_registration(first_node_id)->second; + const std::string op_name = GetOpNameByRegistration(first_node_reg); + EXPECT_EQ("DELEGATE TfLiteXNNPackDelegate", op_name); +#endif +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/tflite_with_xnnpack.cc b/tensorflow/lite/tflite_with_xnnpack.cc new file mode 100644 index 00000000000..c8c2c2e02c1 --- /dev/null +++ b/tensorflow/lite/tflite_with_xnnpack.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" + +namespace tflite { +// Corresponding weak declaration found in lite/model.cc. +std::unique_ptr +AcquireXNNPACKDelegate(int num_threads) { + auto opts = TfLiteXNNPackDelegateOptionsDefault(); + // Note that we don't want to use the thread pool for num_threads == 1. + opts.num_threads = num_threads > 1 ? num_threads : 0; + return std::unique_ptr( + TfLiteXNNPackDelegateCreate(&opts), TfLiteXNNPackDelegateDelete); +} +} // namespace tflite From be9d33754d4f7d8016a863675d37809215ff3dab Mon Sep 17 00:00:00 2001 From: Tiezhen WANG Date: Tue, 24 Mar 2020 00:09:09 -0700 Subject: [PATCH 489/492] TFLM: Move forward with the original CL and fix bluepill test by adding -fno-threadsafe-statics flag. 
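[Editor's note] The fully_connected change in this patch moves per-op state out of Eval and into the kernel's Init/Prepare phases. A rough sketch of that lifecycle follows, assuming the TFLite Micro headers of this era; the OpData field and the computations are placeholders, not the real kernel.

```c++
#include <cstdint>

#include "tensorflow/lite/c/common.h"

namespace {

struct OpData {
  int32_t output_multiplier;  // placeholder for precomputed quantization state
};

// Init: allocate OpData once from the persistent arena; it lives as long as
// the interpreter, so Eval never allocates.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  OpData* data = nullptr;
  if (context->AllocatePersistentBuffer(context, sizeof(OpData),
                                        reinterpret_cast<void**>(&data)) !=
          kTfLiteOk ||
      data == nullptr) {
    return nullptr;
  }
  return data;
}

// Prepare: validate tensors and fill OpData; runs once when the graph is
// prepared, using the pointer Init stored in node->user_data.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* data = reinterpret_cast<OpData*>(node->user_data);
  data->output_multiplier = 0;  // derived from tensor params in a real kernel
  return kTfLiteOk;
}

// Eval: only reads the precomputed OpData on every invocation.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* data = reinterpret_cast<const OpData*>(node->user_data);
  (void)data;
  return kTfLiteOk;
}

}  // namespace
```
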
PiperOrigin-RevId: 302606302
Change-Id: Iaff9548f5aa7bbdfc81b981b300098f5c3ed8dea
---
 .../lite/micro/kernels/fully_connected.cc     |  25 ++++-
 .../micro/kernels/fully_connected_test.cc     |   1 -
 tensorflow/lite/micro/test_helpers.h          |   3 -
 tensorflow/lite/micro/testing/BUILD           |   1 +
 tensorflow/lite/micro/testing/test_utils.cc   | 103 ++++++++++++++++--
 tensorflow/lite/micro/testing/test_utils.h    |  63 ++++++++++-
 .../tools/make/targets/bluepill_makefile.inc  |   1 +
 7 files changed, 176 insertions(+), 21 deletions(-)

diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc
index 64bf788f538..91df80b328c 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected.cc
@@ -71,18 +71,35 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
 }  // namespace
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  return nullptr;
+  OpData* data = nullptr;
+  TfLiteStatus status = context->AllocatePersistentBuffer(
+      context, sizeof(OpData), reinterpret_cast<void**>(&data));
+  if (status != kTfLiteOk || data == nullptr) {
+    return nullptr;
+  }
+  return data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+  auto* params =
+      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                      "Hybrid models are not supported on TFLite Micro.");
+
+  TfLiteType data_type = input->type;
+  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
+                                        filter, bias, output, data));
+
   return kTfLiteOk;
 }
 
@@ -178,11 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-  TfLiteType data_type = input->type;
-  OpData local_data_object;
-  OpData* data = &local_data_object;
-  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
-                                        filter, bias, output, data));
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
   // Checks in Prepare ensure input, output and filter types are all the same.
   switch (input->type) {
diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc
index 0859e4af591..4687ae89108 100644
--- a/tensorflow/lite/micro/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc
@@ -49,7 +49,6 @@ void TestFullyConnectedFloat(
 
   TfLiteContext context;
   PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
-
   ::tflite::ops::micro::AllOpsResolver resolver;
   const TfLiteRegistration* registration =
       resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1);
diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h
index 010e1f9e336..f4e7fa8dfba 100644
--- a/tensorflow/lite/micro/test_helpers.h
+++ b/tensorflow/lite/micro/test_helpers.h
@@ -58,9 +58,6 @@ CreateFlatbufferBuffers();
 
 // Performs a simple string comparison without requiring standard C library.
 int TestStrcmp(const char* a, const char* b);
 
-// Wrapper to forward kernel errors to the interpreter's error reporter.
-void ReportOpError(struct TfLiteContext* context, const char* format, ...);
-
 void PopulateContext(TfLiteTensor* tensors, int tensors_size,
                      TfLiteContext* context);
 
diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD
index 01bdffc6892..42f25f0e8b0 100644
--- a/tensorflow/lite/micro/testing/BUILD
+++ b/tensorflow/lite/micro/testing/BUILD
@@ -17,6 +17,7 @@ cc_library(
     deps = [
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/core/api",
+       "//tensorflow/lite/kernels/internal:compatibility",
        "//tensorflow/lite/micro:micro_framework",
        "//tensorflow/lite/micro:micro_utils",
    ],
diff --git a/tensorflow/lite/micro/testing/test_utils.cc b/tensorflow/lite/micro/testing/test_utils.cc
index 9f7803fcf62..5fd0161d621 100644
--- a/tensorflow/lite/micro/testing/test_utils.cc
+++ b/tensorflow/lite/micro/testing/test_utils.cc
@@ -15,24 +15,107 @@ limitations under the License.
 
 #include "tensorflow/lite/micro/testing/test_utils.h"
 
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+
 namespace tflite {
 namespace testing {
 
+TfLiteStatus FakeAllocator::AllocatePersistentBuffer(size_t bytes, void** ptr) {
+  uint8_t* addr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment);
+  *ptr = addr;
+  return kTfLiteOk;
+}
+
+TfLiteStatus FakeAllocator::RequestScratchBufferInArena(int node_idx,
+                                                        size_t bytes,
+                                                        int* buffer_idx) {
+  if (scratch_buffers_count_ >= max_scratch_buffers_count_) {
+    return kTfLiteError;
+  }
+  uint8_t* ptr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment);
+  scratch_buffers_[scratch_buffers_count_] = ptr;
+  *buffer_idx = scratch_buffers_count_;
+  scratch_buffers_count_++;
+  return kTfLiteOk;
+}
+
+void FakeAllocator::Reset() {
+  // Get a fresh memory allocator.
+  memory_allocator_ = CreateInPlaceSimpleMemoryAllocator(arena_, arena_size_);
+  TFLITE_DCHECK_NE(memory_allocator_, nullptr);
+
+  // Allocate enough space to hold pointers to the scratch buffers.
+  scratch_buffers_ =
+      reinterpret_cast<uint8_t**>(memory_allocator_->AllocateFromTail(
+          sizeof(uint8_t*) * max_scratch_buffers_count_, alignof(uint8_t*)));
+  TFLITE_DCHECK_NE(scratch_buffers_, nullptr);
+
+  scratch_buffers_count_ = 0;
+}
+
+void* FakeAllocator::GetScratchBuffer(int buffer_idx) {
+  if (buffer_idx < 0 || buffer_idx >= scratch_buffers_count_) {
+    return nullptr;
+  }
+  return scratch_buffers_[buffer_idx];
+}
+
+TfLiteStatus FakeContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx,
+                                                         size_t bytes,
+                                                         void** ptr) {
+  return reinterpret_cast<FakeContextHelper*>(ctx->impl_)
+      ->allocator_->AllocatePersistentBuffer(bytes, ptr);
+}
+
+TfLiteStatus FakeContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx,
+                                                            size_t bytes,
+                                                            int* buffer_idx) {
+  FakeContextHelper* helper = reinterpret_cast<FakeContextHelper*>(ctx->impl_);
+  // FakeAllocator doesn't do memory reusing so it doesn't need node_idx to
+  // calculate the lifetime of the scratch buffer.
+  int node_idx = -1;
+  return helper->allocator_->RequestScratchBufferInArena(node_idx, bytes,
+                                                         buffer_idx);
+}
+
+void* FakeContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
+  return reinterpret_cast<FakeContextHelper*>(ctx->impl_)
+      ->allocator_->GetScratchBuffer(buffer_idx);
+}
+
+void FakeContextHelper::ReportOpError(struct TfLiteContext* context,
+                                      const char* format, ...)
{
+  FakeContextHelper* helper = static_cast<FakeContextHelper*>(context->impl_);
+  va_list args;
+  va_start(args, format);
+  TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args);
+  va_end(args);
+}
+
+namespace {
+constexpr size_t kArenaSize = 10000;
+constexpr int kMaxScratchBufferCount = 32;
+uint8_t arena[kArenaSize];
+}  // namespace
+
 // TODO(b/141330728): Move this method elsewhere as part clean up.
 void PopulateContext(TfLiteTensor* tensors, int tensors_size,
                      ErrorReporter* error_reporter, TfLiteContext* context) {
+  // This should be a large enough arena for each test case.
+  static FakeAllocator allocator(arena, kArenaSize, kMaxScratchBufferCount);
+  static FakeContextHelper helper(error_reporter, &allocator);
+  // Reset the allocator so that it's ready for another test.
+  allocator.Reset();
+
+  *context = {};
+  context->recommended_num_threads = 1;
   context->tensors_size = tensors_size;
   context->tensors = tensors;
-  context->impl_ = static_cast<void*>(error_reporter);
-  context->GetExecutionPlan = nullptr;
-  context->ResizeTensor = nullptr;
-  context->ReportError = ReportOpError;
-  context->AddTensors = nullptr;
-  context->GetNodeAndRegistration = nullptr;
-  context->ReplaceNodeSubsetsWithDelegateKernels = nullptr;
-  context->recommended_num_threads = 1;
-  context->GetExternalContext = nullptr;
-  context->SetExternalContext = nullptr;
+  context->impl_ = static_cast<void*>(&helper);
+  context->AllocatePersistentBuffer = helper.AllocatePersistentBuffer;
+  context->RequestScratchBufferInArena = helper.RequestScratchBufferInArena;
+  context->GetScratchBuffer = helper.GetScratchBuffer;
+  context->ReportError = helper.ReportOpError;
 
   for (int i = 0; i < tensors_size; ++i) {
     if (context->tensors[i].is_variable) {
diff --git a/tensorflow/lite/micro/testing/test_utils.h b/tensorflow/lite/micro/testing/test_utils.h
index 7aa1e9d488f..f7f5dff6bb1 100644
--- a/tensorflow/lite/micro/testing/test_utils.h
+++ b/tensorflow/lite/micro/testing/test_utils.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/api/tensor_utils.h"
 #include "tensorflow/lite/micro/micro_utils.h"
+#include "tensorflow/lite/micro/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/test_helpers.h"
 #include "tensorflow/lite/micro/testing/micro_test.h"
 
@@ -95,7 +96,67 @@ inline int32_t F2Q32(const float value, const float scale) {
   return static_cast<int32_t>(quantized);
 }
 
-// TODO(b/141330728): Move this method elsewhere as part clean up.
+// A fake version of MemoryAllocator that allocates everything from the tail
+// without static memory planning or reusing.
+// TODO(b/150260678): Consider splitting this into its own file and inherit from
+// the same public interface as MicroAllocator.
+class FakeAllocator {
+ public:
+  FakeAllocator(uint8_t* arena, size_t arena_size,
+                size_t max_scratch_buffers_count)
+      : arena_(arena),
+        arena_size_(arena_size),
+        max_scratch_buffers_count_(max_scratch_buffers_count) {
+    Reset();
+  }
+
+  TfLiteStatus AllocatePersistentBuffer(size_t bytes, void** ptr);
+  TfLiteStatus RequestScratchBufferInArena(int node_idx, size_t bytes,
+                                           int* buffer_idx);
+  void* GetScratchBuffer(int buffer_idx);
+
+  // Reset the allocator to the initial state.
+  void Reset();
+
+ private:
+  uint8_t* arena_;
+  size_t arena_size_;
+  size_t max_scratch_buffers_count_;
+
+  SimpleMemoryAllocator* memory_allocator_;
+  // An array of buffer pointers.
+ uint8_t** scratch_buffers_; + size_t scratch_buffers_count_ = 0; + static constexpr size_t kBufferAlignment = 16; +}; + +// A fake implementation of ContextHelper. Instead of forwarding requests to +// MicroAllocator, it calls into FakeAllocator. +// PopulateContext will point context->impl_ to an instance of this class. +// TODO(b/150260678): Consider moving this into the same file as FakeAllocator. +class FakeContextHelper { + public: + explicit FakeContextHelper(ErrorReporter* error_reporter, + FakeAllocator* allocator) + : allocator_(allocator), error_reporter_(error_reporter) {} + + static TfLiteStatus AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes, + void** ptr); + + static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx, + size_t bytes, + int* buffer_idx); + + static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx); + + static void ReportOpError(struct TfLiteContext* context, const char* format, + ...); + + private: + FakeAllocator* allocator_; + ErrorReporter* error_reporter_; +}; + void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context); diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc index 878067cf083..29a49288081 100644 --- a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc @@ -38,6 +38,7 @@ ifeq ($(TARGET), bluepill) -Wno-unused-parameter \ -Wno-write-strings \ -fno-delete-null-pointer-checks \ + -fno-threadsafe-statics \ -fomit-frame-pointer \ -fpermissive \ -nostdlib \ From 282828af67de29d13dd2c69d96413c030b02543c Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 24 Mar 2020 00:56:11 -0700 Subject: [PATCH 490/492] Fix path, it's talking about this test. fixes #37480 PiperOrigin-RevId: 302612383 Change-Id: If6cdb1149b08909d878b7808db26e64179a1a7e2 --- tensorflow/compiler/aot/codegen_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index babbd7fb2f5..26d160a4cb4 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -154,7 +154,7 @@ static void CompareWithGoldenFile( // To update the golden file, flip update_golden to true and run the // following: // bazel test --test_strategy=local \ - // third_party/tensorflow/compiler/aot:codegen_test + // "third_party/tensorflow/compiler/aot:codegen_test" const bool update_golden = false; string golden_file_name = GetDataDependencyFilepath(tensorflow_relative_golden_file_name); From fb739579c38594fdd3a333af25aff072e35f5c0c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Mar 2020 01:46:04 -0700 Subject: [PATCH 491/492] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 302618975 Change-Id: I6f5bebcba8fd2380d4693a68695e89ce3f73c9a1 --- tensorflow/go/op/wrappers.go | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index d8f12fab3d2..56a3aa205b9 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12021,7 +12021,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. 
-// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12032,7 +12032,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12250,7 +12250,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12261,7 +12261,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19037,7 +19037,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20108,7 +20108,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21280,7 +21280,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21988,7 +21988,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22184,7 +22184,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22253,7 +22253,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22368,7 +22368,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22427,7 +22427,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22601,7 +22601,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22978,7 +22978,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25321,7 +25321,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. 
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25384,7 +25384,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25627,7 +25627,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26111,7 +26111,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40309,7 +40309,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45835,7 +45835,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46687,7 +46687,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46750,7 +46750,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. 
Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value From 99e754b3a189eefab15fdbf326115d312e44fc7b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Mar 2020 02:02:39 -0700 Subject: [PATCH 492/492] compat: Update forward compatibility horizon to 2020-03-24 PiperOrigin-RevId: 302620825 Change-Id: Iff2e34fcae81ab4a9533f8900e25b2ec35b069c3 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index da05db6f7f4..5d7ee54a469 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 23) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 24) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None
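
The fully_connected.cc change in PATCH 489 above follows the usual TFLite Micro kernel lifecycle: per-op state is allocated once in Init through the context's AllocatePersistentBuffer callback, filled in during Prepare, and only read in Eval. The following is a rough sketch of that pattern, not code from the patches; OpData here is a hypothetical stand-in for whatever a real kernel caches.

  #include "tensorflow/lite/c/common.h"

  // Hypothetical per-op state; a real kernel would cache quantization
  // parameters, multipliers, etc. computed in Prepare.
  struct OpData {
    int32_t output_activation_min;
    int32_t output_activation_max;
  };

  void* Init(TfLiteContext* context, const char* buffer, size_t length) {
    // Memory obtained here lives as long as the interpreter, so Eval never
    // has to allocate.
    OpData* data = nullptr;
    if (context->AllocatePersistentBuffer(context, sizeof(OpData),
                                          reinterpret_cast<void**>(&data)) !=
            kTfLiteOk ||
        data == nullptr) {
      return nullptr;
    }
    return data;
  }

  TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    // Compute anything expensive once and cache it in the persistent buffer.
    OpData* data = reinterpret_cast<OpData*>(node->user_data);
    data->output_activation_min = -128;
    data->output_activation_max = 127;
    return kTfLiteOk;
  }

  TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    // Eval only reads the cached values; no per-invocation allocation.
    const OpData* data = reinterpret_cast<const OpData*>(node->user_data);
    return data != nullptr ? kTfLiteOk : kTfLiteError;
  }

The FakeAllocator and FakeContextHelper added to testing/test_utils.* in the same patch exist precisely so unit tests can service these AllocatePersistentBuffer and RequestScratchBufferInArena calls from a small static arena instead of the full MicroAllocator.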